source: trunk/lib/classifier/CrossSplitter.cc @ 517

Last change on this file since 517 was 517, checked in by Peter, 16 years ago

Fixed a bug in DataLookup2D and created a test for EnsembleBuilder.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.5 KB
Line 
1// $Id: CrossSplitter.cc 517 2006-02-21 16:35:34Z peter $
2
3
4#include <c++_tools/classifier/CrossSplitter.h>
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/random/random.h>
8
9#include <algorithm>
10#include <cassert>
11#include <utility>
12#include <vector>
13
14namespace theplu {
15namespace classifier { 
16
17  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
18                               const size_t N, const size_t k)
19    : k_(k), state_(0), target_(target)
20  { 
21    assert(target.size()>1);
22    assert(target.size()==data.columns());
23    std::vector<std::pair<size_t,size_t> > v;
24    for (size_t i=0; i<target.size(); i++)
25      v.push_back(std::make_pair(target(i),i));
26    // sorting with respect to class
27    std::sort(v.begin(),v.end());
28   
29    // my_begin[i] is index of first sample of class i
30    std::vector<size_t> my_begin;
31    my_begin.reserve(target.nof_classes());
32    my_begin.push_back(0);
33    for (size_t i=1; i<target.size(); i++)
34      while (v[i].first > my_begin.size()-1)
35        my_begin.push_back(i);
36    my_begin.push_back(target.size());
37
38    random::DiscreteUniform rnd;
39
40    for (size_t i=0; i<N; ) {
41      // shuffle indices within class each class
42      for (size_t j=0; j<target.nof_classes(); j++)
43        random_shuffle(v.begin()+my_begin[j],v.begin()+my_begin[j+1],rnd);
44     
45      for (size_t part=0; part<k && i<N; i++, part++) {
46        std::vector<size_t> training_index;
47        std::vector<size_t> validation_index;
48        for (size_t j=0; j<v.size(); j++) {
49          if (j%k==part)
50            validation_index.push_back(v[j].second);
51          else
52            training_index.push_back(v[j].second);
53        }
54
55        // Dynamically allocated. Must be deleted in destructor.
56        training_data_.push_back(data.training_data(training_index));
57        validation_data_.push_back(data.validation_data(training_index, 
58                                                        validation_index));
59
60        training_target_.push_back(Target(target,training_index));
61        validation_target_.push_back(Target(target,validation_index));
62        training_index_.push_back(training_index);
63        validation_index_.push_back(validation_index);
64
65      }
66    }
67    assert(training_data_.size()==N);
68    assert(training_target_.size()==N);
69    assert(training_index_.size()==N);
70    assert(validation_data_.size()==N);
71    assert(validation_target_.size()==N);
72    assert(validation_index_.size()==N);
73  }
74
75  CrossSplitter::~CrossSplitter()
76  {
77    assert(training_data_.size()==validation_data_.size());
78    for (size_t i=0; i<training_data_.size(); i++) 
79      delete training_data_[i];
80    for (size_t i=0; i<validation_data_.size(); i++) 
81      delete validation_data_[i];
82  }
83
84}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.