source: trunk/lib/classifier/CrossSplitter.cc @ 509

Last change on this file since 509 was 509, checked in by Peter, 16 years ago

added test for target
redesign crossSplitter
added two class function in Target

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.5 KB
Line 
1// $Id: CrossSplitter.cc 509 2006-02-18 13:47:32Z peter $
2
3
4#include <c++_tools/classifier/CrossSplitter.h>
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/random/random.h>
8
9#include <algorithm>
10#include <cassert>
11#include <utility>
12#include <vector>
13
14namespace theplu {
15namespace classifier { 
16
17  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
18                               const size_t N, const size_t k)
19    : k_(k), state_(0), target_(target)
20  { 
21    assert(target.size()>1);
22    std::vector<std::pair<size_t,size_t> > v;
23    for (size_t i=0; i<target.size(); i++)
24      v.push_back(std::make_pair(target(i),i));
25    // sorting with respect to class
26    std::sort(v.begin(),v.end());
27   
28    // my_begin[i] is index of first sample of class i
29    std::vector<size_t> my_begin;
30    my_begin.reserve(target.nof_classes());
31    my_begin.push_back(0);
32    for (size_t i=1; i<target.size(); i++)
33      while (v[i].first > my_begin.size()-1)
34        my_begin.push_back(i);
35    my_begin.push_back(target.size());
36
37    random::DiscreteUniform rnd;
38
39    for (size_t i=0; i<N; ) {
40
41      // shuffle indices within class each class
42      for (size_t j=0; j<target.nof_classes(); j++)
43        random_shuffle(v.begin()+my_begin[j],v.begin()+my_begin[j+1],rnd);
44     
45      for (size_t part=0; part<k && i<N; i++, part++) {
46       
47        std::vector<size_t> training_index;
48        std::vector<size_t> validation_index;
49
50        for (size_t j=0; j<v.size(); j++) {
51          if (j%k==part)
52            validation_index.push_back(v[j].second);
53          else
54            training_index.push_back(v[j].second);
55        }
56
57        // Dynamically allocated. Must be deleted in destructor.
58        training_data_.push_back(data.training_data(training_index));
59        validation_data_.push_back(data.validation_data(training_index, 
60                                                        validation_index));
61
62        training_target_.push_back(Target(target,training_index));
63        validation_target_.push_back(Target(target,validation_index));
64        training_index_.push_back(training_index);
65        validation_index_.push_back(validation_index);
66
67      }
68    }
69    assert(training_data_.size()==N);
70    assert(training_target_.size()==N);
71    assert(training_index_.size()==N);
72    assert(validation_data_.size()==N);
73    assert(validation_target_.size()==N);
74    assert(validation_index_.size()==N);
75  }
76
77  CrossSplitter::~CrossSplitter()
78  {
79    assert(training_data_.size()==validation_data_.size());
80    for (size_t i=0; i<training_data_.size(); i++) 
81      delete training_data_[i];
82    for (size_t i=0; i<validation_data_.size(); i++) 
83      delete validation_data_[i];
84  }
85
86}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.