source: trunk/yat/classifier/SubsetGenerator.cc @ 704

Last change on this file since 704 was 704, checked in by Markus Ringnér, 15 years ago

Fixes #104. Also fixed inline bug in Averager.h

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 7.3 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
#include "SubsetGenerator.h"
#include "DataLookup2D.h"
#include "FeatureSelector.h"
#include "KernelLookup.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "Target.h"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <typeinfo>
#include <utility>
#include <vector>
37
38namespace theplu {
39namespace yat {
40namespace classifier { 
41
42  SubsetGenerator::SubsetGenerator(const Sampler& sampler, 
43                                   const DataLookup2D& data)
44    : f_selector_(NULL), sampler_(sampler), weighted_(false)
45  { 
46    assert(target().size()==data.columns());
47
48    training_data_.reserve(sampler_.size());
49    validation_data_.reserve(sampler_.size());
50    for (size_t i=0; i<sampler_.size(); ++i){
51      // Dynamically allocated. Must be deleted in destructor.
52      training_data_.push_back(data.training_data(sampler.training_index(i)));
53      validation_data_.push_back(data.validation_data(sampler.training_index(i),
54                                                      sampler.validation_index(i)));
55
56      training_target_.push_back(Target(target(),sampler.training_index(i)));
57      validation_target_.push_back(Target(target(),
58                                          sampler.validation_index(i)));
59      assert(training_data_.size()==i+1);
60      assert(training_target_.size()==i+1);
61      assert(validation_data_.size()==i+1);
62      assert(validation_target_.size()==i+1);
63    }
64
65    // No feature selection, hence features same for all partitions
66    // and can be stored in features_[0]
67    features_.resize(1);
68    features_[0].reserve(data.rows());
69    for (size_t i=0; i<data.rows(); ++i)
70      features_[0].push_back(i);
71
72    assert(training_data_.size()==size());
73    assert(training_target_.size()==size());
74    assert(validation_data_.size()==size());
75    assert(validation_target_.size()==size());
76  }
77
78
79  SubsetGenerator::SubsetGenerator(const Sampler& sampler, 
80                                   const DataLookup2D& data, 
81                                   FeatureSelector& fs)
82    : f_selector_(&fs), sampler_(sampler), weighted_(false)
83  { 
84    assert(target().size()==data.columns());
85
86    features_.reserve(size());
87    training_data_.reserve(size());
88    validation_data_.reserve(size());
89
90    // Taking care of three different case.
91    // We start with the case of MatrixLookup
92    const MatrixLookup* ml = dynamic_cast<const MatrixLookup*>(&data);
93    if (ml){
94      for (size_t k=0; k<size(); k++){
95     
96        training_target_.push_back(Target(target(),training_index(k)));
97        validation_target_.push_back(Target(target(),validation_index(k)));
98        // training data with no feature selection
99        const MatrixLookup* train_data_all_feat = 
100          ml->training_data(training_index(k));
101        // use these data to create feature selection
102        assert(train_data_all_feat);
103        f_selector_->update(*train_data_all_feat, training_target(k));
104        // get features
105        features_.push_back(f_selector_->features());
106        assert(train_data_all_feat);
107        delete train_data_all_feat;
108       
109        // Dynamically allocated. Must be deleted in destructor.
110        training_data_.push_back(new MatrixLookup(*ml,features_.back(), 
111                                                  training_index(k)));
112        validation_data_.push_back(new MatrixLookup(*ml,features_.back(), 
113                                                    validation_index(k)));     
114      }
115    }
116    else {
117      // Second the case of MatrixLookupWeighted
118      const MatrixLookupWeighted* ml = 
119        dynamic_cast<const MatrixLookupWeighted*>(&data);
120      if (ml){       
121        for (u_long k=0; k<size(); k++){
122          training_target_.push_back(Target(target(),training_index(k)));
123          validation_target_.push_back(Target(target(),validation_index(k)));
124          // training data with no feature selection
125          const MatrixLookupWeighted* train_data_all_feat = 
126            ml->training_data(training_index(k));
127          // use these data to create feature selection
128          f_selector_->update(*train_data_all_feat, training_target(k));
129          // get features
130          features_.push_back(f_selector_->features());
131          delete train_data_all_feat;
132         
133          // Dynamically allocated. Must be deleted in destructor.
134          training_data_.push_back(new MatrixLookupWeighted(*ml,
135                                                            features_.back(), 
136                                                            training_index(k)
137                                                            ));
138          validation_data_.push_back(new MatrixLookupWeighted(*ml,
139                                                              features_.back(), 
140                                                              validation_index(k)
141                                                              ));     
142        }
143      }
144      else {
145        // Third the case of MatrixLookupWeighted
146        const KernelLookup* kernel = dynamic_cast<const KernelLookup*>(&data);
147        if (kernel){
148          for (u_long k=0; k<size(); k++){
149            training_target_.push_back(Target(target(),training_index(k)));
150            validation_target_.push_back(Target(target(),validation_index(k)));
151            const DataLookup2D* matrix = kernel->data();
152            // dynamically allocated must be deleted
153            const DataLookup2D* training_matrix = 
154              matrix->training_data(training_index(k));
155            if (matrix->weighted()){
156              const MatrixLookupWeighted& ml = 
157                dynamic_cast<const MatrixLookupWeighted&>(*matrix);
158              f_selector_->update(MatrixLookupWeighted(ml,training_index(k),false), 
159                                  training_target(k));
160            }
161            else {
162              const MatrixLookup& ml = 
163                dynamic_cast<const MatrixLookup&>(*matrix);
164              f_selector_->update(MatrixLookup(ml,training_index(k), false), 
165                                  training_target(k));
166            } 
167            std::vector<size_t> dummie=f_selector_->features();
168            features_.push_back(dummie);
169            //features_.push_back(f_selector_->features());
170            assert(kernel);
171            const KernelLookup* kl = kernel->selected(features_.back());
172            assert(training_matrix);
173            delete training_matrix;
174                     
175            // Dynamically allocated. Must be deleted in destructor.
176            training_data_.push_back(kl->training_data(training_index(k)));
177            validation_data_.push_back(kl->validation_data(training_index(k), 
178                                                           validation_index(k)));
179            assert(kl);
180            delete kl;
181          }
182        }
183        else {
184        std::cerr << "Sorry, your type of DataLookup2D (" 
185                  << typeid(data).name() << ")\nis not supported in " 
186                  << "SubsetGenerator with\nFeatureSelection\n";
187        exit(-1);
188        }
189      }
190    }
191    assert(training_data_.size()==size());
192    assert(training_target_.size()==size());
193    assert(validation_data_.size()==size());
194    assert(validation_target_.size()==size());
195  }
196
197
198  SubsetGenerator::~SubsetGenerator()
199  {
200    assert(training_data_.size()==validation_data_.size());
201    for (size_t i=0; i<training_data_.size(); i++) 
202      delete training_data_[i];
203    for (size_t i=0; i<validation_data_.size(); i++) 
204      delete validation_data_[i];
205  }
206
207}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.