source: trunk/c++_tools/classifier/SubsetGenerator.cc @ 637

Last change on this file since 637 was 637, checked in by Peter, 15 years ago

fixes #122

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 6.5 KB
Line 
1// $Id$
2
3
4#include <c++_tools/classifier/SubsetGenerator.h>
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/FeatureSelector.h>
7#include <c++_tools/classifier/KernelLookup.h>
8#include <c++_tools/classifier/MatrixLookup.h>
9#include <c++_tools/classifier/MatrixLookupWeighted.h>
10#include <c++_tools/classifier/Target.h>
11
12#include <algorithm>
13#include <cassert>
14#include <utility>
15#include <typeinfo>
16#include <vector>
17
18namespace theplu {
19namespace classifier { 
20
21  SubsetGenerator::SubsetGenerator(const Sampler& sampler, 
22                                   const DataLookup2D& data)
23    : f_selector_(NULL), sampler_(sampler), state_(0), weighted_(false)
24  { 
25    assert(target().size()==data.columns());
26
27    training_data_.reserve(sampler_.size());
28    validation_data_.reserve(sampler_.size());
29    for (size_t i=0; i<sampler_.size(); ++i){
30      // Dynamically allocated. Must be deleted in destructor.
31      training_data_.push_back(data.training_data(sampler.training_index(i)));
32      validation_data_.push_back(data.validation_data(sampler.training_index(i),
33                                                      sampler.validation_index(i)));
34
35      training_target_.push_back(Target(target(),sampler.training_index(i)));
36      validation_target_.push_back(Target(target(),
37                                          sampler.validation_index(i)));
38      assert(training_data_.size()==i+1);
39      assert(training_target_.size()==i+1);
40      assert(validation_data_.size()==i+1);
41      assert(validation_target_.size()==i+1);
42    }
43
44    // No feature selection, hence features same for all partitions
45    // and can be stored in features_[0]
46    features_.resize(1);
47    features_[0].reserve(data.rows());
48    for (size_t i=0; i<data.rows(); ++i)
49      features_[0].push_back(i);
50
51    assert(training_data_.size()==size());
52    assert(training_target_.size()==size());
53    assert(validation_data_.size()==size());
54    assert(validation_target_.size()==size());
55  }
56
57
58  SubsetGenerator::SubsetGenerator(const Sampler& sampler, 
59                                   const DataLookup2D& data, 
60                                   FeatureSelector& fs)
61    : f_selector_(&fs), sampler_(sampler), state_(0), weighted_(false)
62  { 
63    assert(target().size()==data.columns());
64
65    features_.reserve(size());
66    training_data_.reserve(size());
67    validation_data_.reserve(size());
68
69    // Taking care of three different case.
70    // We start with the case of MatrixLookup
71    const MatrixLookup* ml = dynamic_cast<const MatrixLookup*>(&data);
72    if (ml){
73      for (reset(); more(); next()){
74     
75        // training data with no feature selection
76        const MatrixLookup* train_data_all_feat = 
77          ml->training_data(training_index());
78        // use these data to create feature selection
79        f_selector_->update(*train_data_all_feat, training_target());
80        // get features
81        features_.push_back(f_selector_->features());
82        delete train_data_all_feat;
83       
84        // Dynamically allocated. Must be deleted in destructor.
85        training_data_.push_back(ml->training_data(features_.back(), 
86                                                    training_index()));
87        validation_data_.push_back(ml->validation_data(features_.back(),
88                                                        training_index(), 
89                                                        validation_index()));
90
91        training_target_.push_back(Target(target(),training_index()));
92        validation_target_.push_back(Target(target(),validation_index()));
93      }
94    }
95    else {
96      // Second the case of MatrixLookupWeighted
97      const MatrixLookupWeighted* ml = 
98        dynamic_cast<const MatrixLookupWeighted*>(&data);
99      if (ml){
100        for (reset(); more(); next()){
101     
102          // training data with no feature selection
103          const MatrixLookupWeighted* train_data_all_feat = 
104            ml->training_data(training_index());
105          // use these data to create feature selection
106          f_selector_->update(*train_data_all_feat, training_target());
107          // get features
108          features_.push_back(f_selector_->features());
109          delete train_data_all_feat;
110         
111          // Dynamically allocated. Must be deleted in destructor.
112          training_data_.push_back(ml->training_data(features_.back(), 
113                                                     training_index()));
114          validation_data_.push_back(ml->validation_data(features_.back(),
115                                                         training_index(), 
116                                                         validation_index()));
117         
118          training_target_.push_back(Target(target(),training_index()));
119          validation_target_.push_back(Target(target(),validation_index()));
120        }
121      }
122      else {
123        // Third the case of MatrixLookupWeighted
124        const KernelLookup* kernel = dynamic_cast<const KernelLookup*>(&data);
125        if (kernel){
126          for (reset(); more(); next()){
127            const DataLookup2D* matrix = kernel->data();
128            // dynamically allocated must be deleted
129            const DataLookup2D* training_matrix = 
130              matrix->training_data(training_index());
131           
132            if (matrix->weighted()){
133              const MatrixLookupWeighted& ml = 
134                dynamic_cast<const MatrixLookupWeighted&>(*matrix);
135              f_selector_->update(MatrixLookupWeighted(ml,training_index(), 
136                                                       false), 
137                                  training_target());
138            }
139            else {
140              const MatrixLookup& ml = 
141                dynamic_cast<const MatrixLookup&>(*matrix);
142              f_selector_->update(MatrixLookup(ml,training_index(), false), 
143                                  training_target());
144            } 
145           
146            features_.push_back(f_selector_->features());
147            const KernelLookup* kl = kernel->selected(features_.back());
148            assert(training_matrix);
149            delete training_matrix;
150                     
151            // Dynamically allocated. Must be deleted in destructor.
152            training_data_.push_back(kl->training_data(features_.back(), 
153                                                       training_index()));
154            validation_data_.push_back(kl->validation_data(features_.back(),
155                                                           training_index(), 
156                                                           validation_index()));
157           
158            training_target_.push_back(Target(target(),training_index()));
159            validation_target_.push_back(Target(target(),validation_index()));
160            assert(kl);
161            delete kl;
162          }
163        }
164        else {
165        std::cerr << "Sorry, your type of DataLookup2D (" 
166                  << typeid(data).name() << ")\nis not supported in " 
167                  << "SubsetGenerator with\nFeatureSelection\n";
168        exit(-1);
169        }
170      }
171    }
172    assert(training_data_.size()==size());
173    assert(training_target_.size()==size());
174    assert(validation_data_.size()==size());
175    assert(validation_target_.size()==size());
176    reset();
177  }
178
179
180  SubsetGenerator::~SubsetGenerator()
181  {
182    assert(training_data_.size()==validation_data_.size());
183    for (size_t i=0; i<training_data_.size(); i++) 
184      delete training_data_[i];
185    for (size_t i=0; i<validation_data_.size(); i++) 
186      delete validation_data_[i];
187  }
188
189}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.