source: trunk/lib/classifier/CrossSplitter.cc @ 541

Last change on this file since 541 was 541, checked in by Peter, 16 years ago

weights supported in CrossSplitter?

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 4.3 KB
Line 
1// $Id: CrossSplitter.cc 541 2006-03-05 15:20:58Z peter $
2
3
4#include <c++_tools/classifier/CrossSplitter.h>
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/random/random.h>
8
9#include <algorithm>
10#include <cassert>
11#include <utility>
12#include <vector>
13
14namespace theplu {
15namespace classifier { 
16
17  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
18                               const size_t N, const size_t k)
19    : k_(k), state_(0), target_(target)
20  { 
21    assert(target.size()>1);
22    assert(target.size()==data.columns());
23
24    build(target, N, k);
25
26    for (size_t i=0; i<N; i++){
27     
28      // Dynamically allocated. Must be deleted in destructor.
29      training_data_.push_back(data.training_data(training_index_[i]));
30      training_weight_.push_back
31        (new MatrixLookup(training_data_.back()->rows(),
32                          training_data_.back()->columns(),1));
33      validation_data_.push_back(data.validation_data(training_index_[i], 
34                                                    validation_index_[i]));
35      validation_weight_.push_back
36        (new MatrixLookup(validation_data_.back()->rows(),
37                          validation_data_.back()->columns(),1));
38
39
40      training_target_.push_back(Target(target,training_index_[i]));
41      validation_target_.push_back(Target(target,validation_index_[i]));
42    }
43    assert(training_data_.size()==N);
44    assert(training_weight_.size()==N);
45    assert(training_target_.size()==N);
46    assert(validation_data_.size()==N);
47    assert(validation_weight_.size()==N);
48    assert(validation_target_.size()==N);
49  }
50
51  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
52                               const MatrixLookup& weight,
53                               const size_t N, const size_t k)
54    : k_(k), state_(0), target_(target)
55  { 
56    assert(target.size()>1);
57    assert(target.size()==data.columns());
58
59    build(target, N, k);
60
61    for (size_t i=0; i<N; i++){
62     
63      // Dynamically allocated. Must be deleted in destructor.
64      training_data_.push_back(data.training_data(training_index_[i]));
65      validation_data_.push_back(data.validation_data(training_index_[i], 
66                                                    validation_index_[i]));
67      training_weight_.push_back(weight.training_data(training_index_[i]));
68      validation_weight_.push_back(weight.validation_data(training_index_[i], 
69                                                          validation_index_[i]));
70
71
72      training_target_.push_back(Target(target,training_index_[i]));
73      validation_target_.push_back(Target(target,validation_index_[i]));
74    }
75    assert(training_data_.size()==N);
76    assert(training_weight_.size()==N);
77    assert(training_target_.size()==N);
78    assert(validation_data_.size()==N);
79    assert(validation_weight_.size()==N);
80    assert(validation_target_.size()==N);
81  }
82
83  CrossSplitter::~CrossSplitter()
84  {
85    assert(training_data_.size()==validation_data_.size());
86    for (size_t i=0; i<training_data_.size(); i++) 
87      delete training_data_[i];
88    for (size_t i=0; i<validation_data_.size(); i++) 
89      delete validation_data_[i];
90    for (size_t i=0; i<training_weight_.size(); i++) 
91      delete training_weight_[i];
92    for (size_t i=0; i<validation_weight_.size(); i++) 
93      delete validation_weight_[i];
94  }
95
96  void CrossSplitter::build(const Target& target, size_t N, size_t k)
97  {
98    std::vector<std::pair<size_t,size_t> > v;
99    for (size_t i=0; i<target.size(); i++)
100      v.push_back(std::make_pair(target(i),i));
101    // sorting with respect to class
102    std::sort(v.begin(),v.end());
103   
104    // my_begin[i] is index of first sample of class i
105    std::vector<size_t> my_begin;
106    my_begin.reserve(target.nof_classes());
107    my_begin.push_back(0);
108    for (size_t i=1; i<target.size(); i++)
109      while (v[i].first > my_begin.size()-1)
110        my_begin.push_back(i);
111    my_begin.push_back(target.size());
112
113    random::DiscreteUniform rnd;
114
115    for (size_t i=0; i<N; ) {
116      // shuffle indices within class each class
117      for (size_t j=0; j<target.nof_classes(); j++)
118        random_shuffle(v.begin()+my_begin[j],v.begin()+my_begin[j+1],rnd);
119     
120      for (size_t part=0; part<k && i<N; i++, part++) {
121        std::vector<size_t> training_index;
122        std::vector<size_t> validation_index;
123        for (size_t j=0; j<v.size(); j++) {
124          if (j%k==part)
125            validation_index.push_back(v[j].second);
126          else
127            training_index.push_back(v[j].second);
128        }
129
130        training_index_.push_back(training_index);
131        validation_index_.push_back(validation_index);
132      }
133    }
134    assert(training_index_.size()==N);
135    assert(validation_index_.size()==N);
136}
137
138}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.