source: trunk/yat/classifier/EnsembleBuilder.h @ 1954

Last change on this file since 1954 was 1954, checked in by Jari Häkkinen, 14 years ago

Merged patch release 0.5.3 to the trunk. Delta 0.5.3 - 0.5.2

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 7.8 KB
Line 
1#ifndef _theplu_yat_classifier_ensemblebuilder_
2#define _theplu_yat_classifier_ensemblebuilder_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér
8  Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
9  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
10  Copyright (C) 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
11  Copyright (C) 2009 Jari Häkkinen
12
13  This file is part of the yat library, http://dev.thep.lu.se/yat
14
15  The yat library is free software; you can redistribute it and/or
16  modify it under the terms of the GNU General Public License as
17  published by the Free Software Foundation; either version 3 of the
18  License, or (at your option) any later version.
19
20  The yat library is distributed in the hope that it will be useful,
21  but WITHOUT ANY WARRANTY; without even the implied warranty of
22  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
23  General Public License for more details.
24
25  You should have received a copy of the GNU General Public License
26  along with yat. If not, see <http://www.gnu.org/licenses/>.
27*/
28
#include "FeatureSelector.h"
#include "Sampler.h"
#include "SubsetGenerator.h"
#include "yat/statistics/Averager.h"
#include "yat/utility/Matrix.h"

#include <cassert>
#include <vector>
36
37namespace theplu {
38namespace yat {
39namespace classifier { 
40
41  ///
42  /// @brief Class for ensembles of supervised classifiers
43  ///
44  template <class Classifier, class Data>
45  class EnsembleBuilder
46  {
47  public:
48    /**
49       \brief Type of classifier that ensemble is built on.
50     */
51    typedef Classifier classifier_type;
52
53    /**
54       Type of container used for storing data. Must be MatrixLookup,
55       MatrixLookupWeighted, or KernelLookup
56     */
57    typedef Data data_type;
58
59    ///
60    /// Constructor.
61    ///
62    EnsembleBuilder(const Classifier&, const Data&, const Sampler&);
63
64    ///
65    /// Constructor.
66    ///
67    EnsembleBuilder(const Classifier&, const Data&, const Sampler&, 
68                    FeatureSelector&);
69
70    ///
71    /// Destructor.
72    ///
73    virtual ~EnsembleBuilder(void);
74
75    /**
76       \brief Generate ensemble.
77       
78       Function trains each member of the Ensemble.
79    */
80    void build(void);
81
82    ///
83    /// @return ith classifier
84    ///
85    const Classifier& classifier(size_t i) const;
86     
87    ///
88    /// @return Number of classifiers in ensemble. Prior build(void)
89    /// is issued size is zero.
90    ///
91    unsigned long size(void) const;
92
93    ///
94    /// @brief Generate validation data for ensemble
95    ///
96    /// validate()[i][j] return averager for class @a i for sample @a j
97    ///
98    const std::vector<std::vector<statistics::Averager> >& validate(void);
99   
100    /**
101       Predict a dataset using the ensemble.
102       
103       If @a data is a KernelLookup each column should correspond to a
104       test sample and each row should correspond to a training
105       sample. More exactly row \f$ i \f$ in @a data should correspond
106       to the same sample as row/column \f$ i \f$ in the training
107       kernel corresponds to.
108    */
109    void predict(const Data& data, 
110                 std::vector<std::vector<statistics::Averager> > &);
111
112  private:
113    // no copying
114    EnsembleBuilder(const EnsembleBuilder&);
115    const EnsembleBuilder& operator=(const EnsembleBuilder&);
116   
117
118    const Classifier& mother_;
119    SubsetGenerator<Data>* subset_;
120    std::vector<Classifier*> classifier_;
121    KernelLookup test_data(const KernelLookup&, size_t k);
122    MatrixLookup test_data(const MatrixLookup&, size_t k);
123    MatrixLookupWeighted test_data(const MatrixLookupWeighted&, size_t k);
124    std::vector<std::vector<statistics::Averager> > validation_result_;
125
126  };
127 
128
129  // implementation
130
131  template <class Classifier, class Data>
132  EnsembleBuilder<Classifier, Data>::EnsembleBuilder(const Classifier& sc,
133                                                     const Data& data,
134                                                     const Sampler& sampler)
135    : mother_(sc),subset_(new SubsetGenerator<Data>(sampler,data))
136  {
137  }
138
139
140  template <class Classifier, class Data>
141  EnsembleBuilder<Classifier, Data>::EnsembleBuilder(const Classifier& sc,
142                                                     const Data& data,
143                                                     const Sampler& sampler,
144                                                     FeatureSelector& fs)
145    : mother_(sc),
146      subset_(new SubsetGenerator<Data>(sampler,data,fs))
147  {
148  }
149
150
151  template <class Classifier, class Data>
152  EnsembleBuilder<Classifier, Data>::~EnsembleBuilder(void)
153  {
154    for(size_t i=0; i<classifier_.size(); i++)
155      delete classifier_[i];
156    delete subset_;
157  }
158
159
160  template <class Classifier, class Data>
161  void EnsembleBuilder<Classifier, Data>::build(void)
162  {
163    if (classifier_.empty()){
164      for(unsigned long i=0; i<subset_->size();++i) {
165        Classifier* classifier = mother_.make_classifier();
166        classifier->train(subset_->training_data(i), 
167                          subset_->training_target(i));
168        classifier_.push_back(classifier);
169      }   
170    }
171  }
172
173
174  template <class Classifier, class Data>
175  const Classifier& EnsembleBuilder<Classifier, Data>::classifier(size_t i) const
176  {
177    return *(classifier_[i]);
178  }
179
180
181  template <class Classifier, class Data>
182  void EnsembleBuilder<Classifier, Data>::predict
183  (const Data& data, std::vector<std::vector<statistics::Averager> >& result)
184  {
185    result = std::vector<std::vector<statistics::Averager> >
186      (subset_->target().nof_classes(), 
187       std::vector<statistics::Averager>(data.columns()));
188   
189    utility::Matrix prediction; 
190
191    for(unsigned long k=0;k<size();++k) {       
192      Data sub_data = test_data(data, k);
193      classifier(k).predict(sub_data,prediction);
194    }
195
196    for(size_t i=0; i<prediction.rows();i++) 
197      for(size_t j=0; j<prediction.columns();j++) 
198        result[i][j].add(prediction(i,j));   
199  }
200
201 
202  template <class Classifier, class Data>
203  unsigned long EnsembleBuilder<Classifier, Data>::size(void) const
204  {
205    return classifier_.size();
206  }
207
208
209  template <class Classifier, class Data>
210  MatrixLookup EnsembleBuilder<Classifier,
211                               Data>::test_data(const MatrixLookup& data,
212                                                size_t k)
213  {
214    return MatrixLookup(data, subset_->training_features(k), true);
215  }
216 
217
218  template <class Classifier, class Data>
219  MatrixLookupWeighted
220  EnsembleBuilder<Classifier, Data>::test_data(const MatrixLookupWeighted& data,
221                                               size_t k)
222  {
223    return MatrixLookupWeighted(data, subset_->training_features(k), true);
224  }
225 
226
227  template <class Classifier, class Data>
228  KernelLookup
229  EnsembleBuilder<Classifier, Data>::test_data(const KernelLookup& kernel,
230                                               size_t k)
231  {
232    // weighted case
233    if (kernel.weighted()){
234      assert(false);
235      // no feature selection
236      if (kernel.data_weighted().rows()==subset_->training_features(k).size())
237        return KernelLookup(kernel, subset_->training_index(k), true);
238      MatrixLookupWeighted mlw = test_data(kernel.data_weighted(), k);
239      return subset_->training_data(k).test_kernel(mlw);
240
241    }
242    // unweighted case
243
244    // no feature selection
245    if (kernel.data().rows()==subset_->training_features(k).size())
246      return KernelLookup(kernel, subset_->training_index(k), true);
247   
248    // feature selection
249    return subset_->training_data(k).test_kernel(test_data(kernel.data(),k));
250  }
251 
252
253  template <class Classifier, class Data>
254  const std::vector<std::vector<statistics::Averager> >& 
255  EnsembleBuilder<Classifier, Data>::validate(void)
256  {
257    // Don't recalculate validation_result_
258    if (!validation_result_.empty())
259      return validation_result_;
260
261    validation_result_ = std::vector<std::vector<statistics::Averager> >
262      (subset_->target().nof_classes(), 
263       std::vector<statistics::Averager>(subset_->target().size()));
264
265    utility::Matrix prediction; 
266    for(unsigned long k=0;k<size();k++) {
267      classifier(k).predict(subset_->validation_data(k),prediction);
268     
269      // map results to indices of samples in training + validation data set
270      for(size_t i=0; i<prediction.rows();i++) 
271        for(size_t j=0; j<prediction.columns();j++) {
272          validation_result_[i][subset_->validation_index(k)[j]].
273            add(prediction(i,j));
274        }           
275    }
276    return validation_result_;
277  }
278
279}}} // of namespace classifier, yat, and theplu
280
281#endif
Note: See TracBrowser for help on using the repository browser.