source: trunk/c++_tools/classifier/CrossSplitter.cc @ 608

Last change on this file since 608 was 608, checked in by Peter, 15 years ago

set properties

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 6.7 KB
Line 
1// $Id$
2
3
4#include <c++_tools/classifier/CrossSplitter.h>
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/FeatureSelector.h>
7#include <c++_tools/classifier/Target.h>
8#include <c++_tools/random/random.h>
9
10#include <algorithm>
11#include <cassert>
12#include <utility>
13#include <vector>
14
15namespace theplu {
16namespace classifier { 
17
18  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
19                               const size_t N, const size_t k)
20    : k_(k), state_(0), target_(target), weighted_(false)
21  { 
22    assert(target.size()>1);
23    assert(target.size()==data.columns());
24
25    build(target, N, k);
26
27    for (size_t i=0; i<N; i++){
28     
29      // Dynamically allocated. Must be deleted in destructor.
30      training_data_.push_back(data.training_data(training_index_[i]));
31      training_weight_.push_back
32        (new MatrixLookup(training_data_.back()->rows(),
33                          training_data_.back()->columns(),1));
34      validation_data_.push_back(data.validation_data(training_index_[i], 
35                                                    validation_index_[i]));
36      validation_weight_.push_back
37        (new MatrixLookup(validation_data_.back()->rows(),
38                          validation_data_.back()->columns(),1));
39
40
41      training_target_.push_back(Target(target,training_index_[i]));
42      validation_target_.push_back(Target(target,validation_index_[i]));
43    }
44
45    // No feature selection, hence features same for all partitions
46    // and can be stored in features_[0]
47    features_.resize(1);
48    features_[0].reserve(data.rows());
49    for (size_t i=0; i<data.rows(); ++i)
50      features_[0].push_back(i);
51
52    assert(training_data_.size()==N);
53    assert(training_weight_.size()==N);
54    assert(training_target_.size()==N);
55    assert(validation_data_.size()==N);
56    assert(validation_weight_.size()==N);
57    assert(validation_target_.size()==N);
58  }
59
60  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
61                               const MatrixLookup& weight,
62                               const size_t N, const size_t k)
63    : k_(k), state_(0), target_(target), weighted_(true)
64  { 
65    assert(target.size()>1);
66    assert(target.size()==data.columns());
67
68    build(target, N, k);
69
70    for (size_t i=0; i<N; i++){
71     
72      // Dynamically allocated. Must be deleted in destructor.
73      training_data_.push_back(data.training_data(training_index_[i]));
74      validation_data_.push_back(data.validation_data(training_index_[i], 
75                                                    validation_index_[i]));
76      training_weight_.push_back(weight.training_data(training_index_[i]));
77      validation_weight_.push_back(weight.validation_data(training_index_[i], 
78                                                          validation_index_[i]));
79
80
81      training_target_.push_back(Target(target,training_index_[i]));
82      validation_target_.push_back(Target(target,validation_index_[i]));
83    }
84    assert(training_data_.size()==N);
85    assert(training_weight_.size()==N);
86    assert(training_target_.size()==N);
87    assert(validation_data_.size()==N);
88    assert(validation_weight_.size()==N);
89    assert(validation_target_.size()==N);
90  }
91
92  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
93                               const size_t N, const size_t k, 
94                               FeatureSelector& fs)
95    : k_(k), state_(0), target_(target), weighted_(false), f_selector_(&fs)
96  { 
97    assert(target.size()>1);
98    assert(target.size()==data.columns());
99
100    build(target, N, k);
101    features_.reserve(N);
102    training_data_.reserve(N);
103    training_weight_.reserve(N);
104    validation_data_.reserve(N);
105    validation_weight_.reserve(N);
106     
107    for (size_t i=0; i<N; i++){
108     
109      // training data with no feature selection
110      const DataLookup2D* train_data_all_feat = 
111        data.training_data(training_index_[i]);
112      // use these data to create feature selection
113      f_selector_->update(*train_data_all_feat, training_target_[i]);
114      // get features
115      features_.push_back(f_selector_->features());
116      delete train_data_all_feat;
117
118      // Dynamically allocated. Must be deleted in destructor.
119      training_data_.push_back(data.training_data(features_[i], 
120                                                  training_index_[i]));
121      training_weight_.push_back
122        (new MatrixLookup(training_data_.back()->rows(),
123                          training_data_.back()->columns(),1));
124      validation_data_.push_back(data.validation_data(features_[i],
125                                                      training_index_[i], 
126                                                      validation_index_[i]));
127      validation_weight_.push_back
128        (new MatrixLookup(validation_data_.back()->rows(),
129                          validation_data_.back()->columns(),1));
130
131
132      training_target_.push_back(Target(target,training_index_[i]));
133      validation_target_.push_back(Target(target,validation_index_[i]));
134    }
135
136    // No feature selection, hence features same for all partitions
137    // and can be stored in features_[0]
138    features_.resize(1);
139    features_[0].reserve(data.rows());
140    for (size_t i=0; i<data.rows(); ++i)
141      features_[0].push_back(i);
142
143    assert(training_data_.size()==N);
144    assert(training_weight_.size()==N);
145    assert(training_target_.size()==N);
146    assert(validation_data_.size()==N);
147    assert(validation_weight_.size()==N);
148    assert(validation_target_.size()==N);
149  }
150
151  CrossSplitter::~CrossSplitter()
152  {
153    assert(training_data_.size()==validation_data_.size());
154    for (size_t i=0; i<training_data_.size(); i++) 
155      delete training_data_[i];
156    for (size_t i=0; i<validation_data_.size(); i++) 
157      delete validation_data_[i];
158    for (size_t i=0; i<training_weight_.size(); i++) 
159      delete training_weight_[i];
160    for (size_t i=0; i<validation_weight_.size(); i++) 
161      delete validation_weight_[i];
162  }
163
164  void CrossSplitter::build(const Target& target, size_t N, size_t k)
165  {
166    std::vector<std::pair<size_t,size_t> > v;
167    for (size_t i=0; i<target.size(); i++)
168      v.push_back(std::make_pair(target(i),i));
169    // sorting with respect to class
170    std::sort(v.begin(),v.end());
171   
172    // my_begin[i] is index of first sample of class i
173    std::vector<size_t> my_begin;
174    my_begin.reserve(target.nof_classes());
175    my_begin.push_back(0);
176    for (size_t i=1; i<target.size(); i++)
177      while (v[i].first > my_begin.size()-1)
178        my_begin.push_back(i);
179    my_begin.push_back(target.size());
180
181    random::DiscreteUniform rnd;
182
183    for (size_t i=0; i<N; ) {
184      // shuffle indices within class each class
185      for (size_t j=0; j<target.nof_classes(); j++)
186        random_shuffle(v.begin()+my_begin[j],v.begin()+my_begin[j+1],rnd);
187     
188      for (size_t part=0; part<k && i<N; i++, part++) {
189        std::vector<size_t> training_index;
190        std::vector<size_t> validation_index;
191        for (size_t j=0; j<v.size(); j++) {
192          if (j%k==part)
193            validation_index.push_back(v[j].second);
194          else
195            training_index.push_back(v[j].second);
196        }
197
198        training_index_.push_back(training_index);
199        validation_index_.push_back(validation_index);
200      }
201    }
202    assert(training_index_.size()==N);
203    assert(validation_index_.size()==N);
204}
205
206}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.