source: trunk/lib/classifier/CrossSplitter.cc @ 514

Last change on this file since 514 was 514, checked in by Peter, 16 years ago

generalised binary functionality in Target

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.5 KB
Line 
1// $Id: CrossSplitter.cc 514 2006-02-20 09:45:34Z peter $
2
3
4#include <c++_tools/classifier/CrossSplitter.h>
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/random/random.h>
8
9#include <algorithm>
10#include <cassert>
11#include <utility>
12#include <vector>
13
14namespace theplu {
15namespace classifier { 
16
17  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
18                               const size_t N, const size_t k)
19    : k_(k), state_(0), target_(target)
20  { 
21    assert(target.size()>1);
22    std::vector<std::pair<size_t,size_t> > v;
23    for (size_t i=0; i<target.size(); i++)
24      v.push_back(std::make_pair(target(i),i));
25    // sorting with respect to class
26    std::sort(v.begin(),v.end());
27   
28    // my_begin[i] is index of first sample of class i
29    std::vector<size_t> my_begin;
30    my_begin.reserve(target.nof_classes());
31    my_begin.push_back(0);
32    for (size_t i=1; i<target.size(); i++)
33      while (v[i].first > my_begin.size()-1)
34        my_begin.push_back(i);
35    my_begin.push_back(target.size());
36
37    random::DiscreteUniform rnd;
38
39    for (size_t i=0; i<N; ) {
40
41      // shuffle indices within class each class
42      for (size_t j=0; j<target.nof_classes(); j++)
43        random_shuffle(v.begin()+my_begin[j],v.begin()+my_begin[j+1],rnd);
44     
45      for (size_t part=0; part<k && i<N; i++, part++) {
46        std::vector<size_t> training_index;
47        std::vector<size_t> validation_index;
48        for (size_t j=0; j<v.size(); j++) {
49          if (j%k==part)
50            validation_index.push_back(v[j].second);
51          else
52            training_index.push_back(v[j].second);
53        }
54
55        // Dynamically allocated. Must be deleted in destructor.
56        training_data_.push_back(data.training_data(training_index));
57        validation_data_.push_back(data.validation_data(training_index, 
58                                                        validation_index));
59
60        training_target_.push_back(Target(target,training_index));
61        validation_target_.push_back(Target(target,validation_index));
62        training_index_.push_back(training_index);
63        validation_index_.push_back(validation_index);
64
65      }
66    }
67    assert(training_data_.size()==N);
68    assert(training_target_.size()==N);
69    assert(training_index_.size()==N);
70    assert(validation_data_.size()==N);
71    assert(validation_target_.size()==N);
72    assert(validation_index_.size()==N);
73  }
74
75  CrossSplitter::~CrossSplitter()
76  {
77    assert(training_data_.size()==validation_data_.size());
78    for (size_t i=0; i<training_data_.size(); i++) 
79      delete training_data_[i];
80    for (size_t i=0; i<validation_data_.size(); i++) 
81      delete validation_data_[i];
82  }
83
84}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.