source: trunk/lib/classifier/CrossSplitter.cc @ 505

Last change on this file since 505 was 505, checked in by Markus Ringnér, 16 years ago

Continuing with adding validation functionality to EnsembleBuilder?. Structure is there but also bugs

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.4 KB
Line 
1// $Id: CrossSplitter.cc 505 2006-02-02 16:34:38Z markus $
2
3
4#include <c++_tools/classifier/CrossSplitter.h>
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/random/random.h>
8
9#include <vector>
10
11namespace theplu {
12namespace classifier { 
13
14  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
15                               const size_t N, const size_t k)
16    : k_(k), state_(0), target_(target)
17  { 
18    std::vector<size_t> index_pos;
19    std::vector<size_t> index_neg;
20
21    for (size_t i=0; i<target.size(); i++){
22      if (target(i)==1)
23        index_pos.push_back(i);
24      else
25        index_neg.push_back(i);
26    }
27
28    std::vector<size_t> part_pos(index_pos.size()); // [0,k-1]
29    for (size_t i=0; i<part_pos.size(); i++)
30      part_pos[i] = int(i*k/part_pos.size());
31
32    std::vector<size_t> part_neg(index_neg.size()); // [0,k-1]
33    for (size_t i=0; i<part_pos.size(); i++)
34      part_neg[i] = int(i*k/part_neg.size());
35
36    random::DiscreteUniform rnd;
37
38
39    for (size_t i=0; i<N; ) {
40      random_shuffle(index_neg.begin(), index_neg.end(), rnd);
41      random_shuffle(index_pos.begin(), index_pos.end(), rnd);
42     
43      std::vector<size_t> training_index;
44      std::vector<size_t> validation_index;
45
46      for (size_t part=0; part<k && i<N; i++, part++) {
47       
48        training_index.clear();
49        training_index.clear();
50        for (size_t j=0; j<index_neg.size(); j++) {
51          if (part_neg[j]==part)
52            validation_index.push_back(index_neg[j]);
53          else
54            training_index.push_back(index_neg[j]);
55        }
56        for (size_t j=0; j<index_pos.size(); j++) {
57          if (part_pos[j]==part)
58            validation_index.push_back(index_pos[j]);
59          else
60            training_index.push_back(index_pos[j]);
61        }
62
63        // Dynamically allocated. Must be deleted in destructor.
64        training_data_.push_back(data.training_data(training_index));
65        validation_data_.push_back(data.validation_data(training_index, 
66                                                        validation_index));
67       
68        training_target_.push_back(Target(target,training_index));
69        validation_target_.push_back(Target(target,validation_index));
70       
71        training_index_.push_back(training_index);
72        validation_index_.push_back(validation_index);
73      }
74    }
75   
76  }
77
78  CrossSplitter::~CrossSplitter()
79  {
80    for (size_t i=0; i<training_data_.size(); i++) 
81      delete training_data_[i];
82    for (size_t i=0; i<validation_data_.size(); i++) 
83      delete validation_data_[i];
84  }
85
86}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.