source: trunk/c++_tools/classifier/SubsetGenerator.cc @ 636

Last change on this file since 636 was 636, checked in by Peter, 15 years ago

closes #120 removed weights from SubsetGenerator?

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 6.6 KB
Line 
1// $Id$
2
3
4#include <c++_tools/classifier/SubsetGenerator.h>
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/FeatureSelector.h>
7#include <c++_tools/classifier/KernelLookup.h>
8#include <c++_tools/classifier/MatrixLookup.h>
9#include <c++_tools/classifier/MatrixLookupWeighted.h>
10#include <c++_tools/classifier/Target.h>
11
12#include <algorithm>
13#include <cassert>
14#include <utility>
15#include <typeinfo>
16#include <vector>
17
18namespace theplu {
19namespace classifier { 
20
21  SubsetGenerator::SubsetGenerator(const Sampler& sampler, 
22                                   const DataLookup2D& data)
23    : f_selector_(NULL), sampler_(sampler), state_(0), weighted_(false)
24  { 
25    assert(target().size()==data.columns());
26
27    training_data_.reserve(sampler_.size());
28    validation_data_.reserve(sampler_.size());
29    for (size_t i=0; i<sampler_.size(); ++i){
30      // Dynamically allocated. Must be deleted in destructor.
31      training_data_.push_back(data.training_data(sampler.training_index(i)));
32      validation_data_.push_back(data.validation_data(sampler.training_index(i),
33                                                      sampler.validation_index(i)));
34
35      training_target_.push_back(Target(target(),sampler.training_index(i)));
36      validation_target_.push_back(Target(target(),
37                                          sampler.validation_index(i)));
38      assert(training_data_.size()==i+1);
39      assert(training_target_.size()==i+1);
40      assert(validation_data_.size()==i+1);
41      assert(validation_target_.size()==i+1);
42    }
43
44    // No feature selection, hence features same for all partitions
45    // and can be stored in features_[0]
46    features_.resize(1);
47    features_[0].reserve(data.rows());
48    for (size_t i=0; i<data.rows(); ++i)
49      features_[0].push_back(i);
50
51    assert(training_data_.size()==size());
52    assert(training_target_.size()==size());
53    assert(validation_data_.size()==size());
54    assert(validation_target_.size()==size());
55  }
56
57
58  SubsetGenerator::SubsetGenerator(const Sampler& sampler, 
59                                   const DataLookup2D& data, 
60                                   FeatureSelector& fs)
61    : f_selector_(&fs), sampler_(sampler), state_(0), weighted_(false)
62  { 
63    assert(target().size()==data.columns());
64
65    features_.reserve(size());
66    training_data_.reserve(size());
67    validation_data_.reserve(size());
68
69    // Taking care of three different case.
70    // We start with the case of MatrixLookup
71    const MatrixLookup* ml = dynamic_cast<const MatrixLookup*>(&data);
72    if (ml){
73      for (reset(); more(); next()){
74     
75        // training data with no feature selection
76        const MatrixLookup* train_data_all_feat = 
77          ml->training_data(training_index());
78        // use these data to create feature selection
79        f_selector_->update(*train_data_all_feat, training_target());
80        // get features
81        features_.push_back(f_selector_->features());
82        delete train_data_all_feat;
83       
84        // Dynamically allocated. Must be deleted in destructor.
85        training_data_.push_back(ml->training_data(features_.back(), 
86                                                    training_index()));
87        validation_data_.push_back(ml->validation_data(features_.back(),
88                                                        training_index(), 
89                                                        validation_index()));
90
91        training_target_.push_back(Target(target(),training_index()));
92        validation_target_.push_back(Target(target(),validation_index()));
93      }
94    }
95    else {
96      // Second the case of MatrixLookupWeighted
97      const MatrixLookupWeighted* ml = 
98        dynamic_cast<const MatrixLookupWeighted*>(&data);
99      if (ml){
100        for (reset(); more(); next()){
101     
102          // training data with no feature selection
103          const MatrixLookupWeighted* train_data_all_feat = 
104            ml->training_data(training_index());
105          // use these data to create feature selection
106          f_selector_->update(*train_data_all_feat, training_target());
107          // get features
108          features_.push_back(f_selector_->features());
109          delete train_data_all_feat;
110         
111          // Dynamically allocated. Must be deleted in destructor.
112          training_data_.push_back(ml->training_data(features_.back(), 
113                                                     training_index()));
114          validation_data_.push_back(ml->validation_data(features_.back(),
115                                                         training_index(), 
116                                                         validation_index()));
117         
118          training_target_.push_back(Target(target(),training_index()));
119          validation_target_.push_back(Target(target(),validation_index()));
120        }
121      }
122      else {
123        // Third the case of MatrixLookupWeighted
124        const KernelLookup* kernel =  dynamic_cast<const KernelLookup*>(&data);
125        if (kernel){
126          for (reset(); more(); next()){
127            const KernelLookup* kl=NULL;
128            if (kernel->weighted()){
129              std::cerr << "Feature selection with weighted Kernel not " 
130                        << "implemented.\nPlease see http://lev.thep.lu."
131                        << "se/trac/c++_tools/ticket/116\n";
132              exit(-1);
133            }
134            else {
135              const DataLookup2D* matrix = kernel->data();
136              const DataLookup2D* training_matrix = 
137                matrix->training_data(training_index());
138              if (kernel->weighted()){
139                const MatrixLookupWeighted& ml = 
140                  dynamic_cast<const MatrixLookupWeighted&>(*training_matrix);
141                f_selector_->update(ml, training_target());
142              }
143              else {
144                const MatrixLookup& ml = 
145                  dynamic_cast<const MatrixLookup&>(*training_matrix);
146                f_selector_->update(ml, training_target());
147              } 
148
149              features_.push_back(f_selector_->features());
150              kl = kernel->selected(features_.back());
151              delete matrix;
152              delete training_matrix;
153            }
154           
155            // Dynamically allocated. Must be deleted in destructor.
156            training_data_.push_back(kl->training_data(features_.back(), 
157                                                       training_index()));
158            validation_data_.push_back(kl->validation_data(features_.back(),
159                                                           training_index(), 
160                                                           validation_index()));
161           
162            training_target_.push_back(Target(target(),training_index()));
163            validation_target_.push_back(Target(target(),validation_index()));
164            if (kl)
165              delete kl;
166          }
167        }
168        else {
169        std::cerr << "Sorry, your type of DataLookup2D " << typeid(data).name() 
170                  << "is not supported in FeatureSelection\n";
171        exit(-1);
172        }
173      }
174    }
175    assert(training_data_.size()==size());
176    assert(training_target_.size()==size());
177    assert(validation_data_.size()==size());
178    assert(validation_target_.size()==size());
179    reset();
180  }
181
182
183  SubsetGenerator::~SubsetGenerator()
184  {
185    assert(training_data_.size()==validation_data_.size());
186    for (size_t i=0; i<training_data_.size(); i++) 
187      delete training_data_[i];
188    for (size_t i=0; i<validation_data_.size(); i++) 
189      delete validation_data_[i];
190  }
191
192}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.