source: trunk/yat/classifier/EnsembleBuilder.h @ 2077

Last change on this file since 2077 was 2077, checked in by Peter, 14 years ago

refs #567 using YAT_ASSERT in header rather than assert

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 7.8 KB
Line 
1#ifndef _theplu_yat_classifier_ensemblebuilder_
2#define _theplu_yat_classifier_ensemblebuilder_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér
8  Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
9  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
10  Copyright (C) 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
11  Copyright (C) 2009 Jari Häkkinen
12
13  This file is part of the yat library, http://dev.thep.lu.se/yat
14
15  The yat library is free software; you can redistribute it and/or
16  modify it under the terms of the GNU General Public License as
17  published by the Free Software Foundation; either version 3 of the
18  License, or (at your option) any later version.
19
20  The yat library is distributed in the hope that it will be useful,
21  but WITHOUT ANY WARRANTY; without even the implied warranty of
22  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
23  General Public License for more details.
24
25  You should have received a copy of the GNU General Public License
26  along with yat. If not, see <http://www.gnu.org/licenses/>.
27*/
28
29#include "FeatureSelector.h"
30#include "Sampler.h"
31#include "SubsetGenerator.h"
32#include "yat/statistics/Averager.h"
33#include "yat/utility/Matrix.h"
34#include "yat/utility/yat_assert.h"
35
36#include <vector>
37
38namespace theplu {
39namespace yat {
40namespace classifier { 
41
42  ///
43  /// @brief Class for ensembles of supervised classifiers
44  ///
45  template <class Classifier, class Data>
46  class EnsembleBuilder
47  {
48  public:
49    /**
50       \brief Type of classifier that ensemble is built on.
51     */
52    typedef Classifier classifier_type;
53
54    /**
55       Type of container used for storing data. Must be MatrixLookup,
56       MatrixLookupWeighted, or KernelLookup
57     */
58    typedef Data data_type;
59
60    ///
61    /// Constructor.
62    ///
63    EnsembleBuilder(const Classifier&, const Data&, const Sampler&);
64
65    ///
66    /// Constructor.
67    ///
68    EnsembleBuilder(const Classifier&, const Data&, const Sampler&, 
69                    FeatureSelector&);
70
71    ///
72    /// Destructor.
73    ///
74    virtual ~EnsembleBuilder(void);
75
76    /**
77       \brief Generate ensemble.
78       
79       Function trains each member of the Ensemble.
80    */
81    void build(void);
82
83    ///
84    /// @return ith classifier
85    ///
86    const Classifier& classifier(size_t i) const;
87     
88    ///
89    /// @return Number of classifiers in ensemble. Prior build(void)
90    /// is issued size is zero.
91    ///
92    unsigned long size(void) const;
93
94    ///
95    /// @brief Generate validation data for ensemble
96    ///
97    /// validate()[i][j] return averager for class @a i for sample @a j
98    ///
99    const std::vector<std::vector<statistics::Averager> >& validate(void);
100   
101    /**
102       Predict a dataset using the ensemble.
103       
104       If @a data is a KernelLookup each column should correspond to a
105       test sample and each row should correspond to a training
106       sample. More exactly row \f$ i \f$ in @a data should correspond
107       to the same sample as row/column \f$ i \f$ in the training
108       kernel corresponds to.
109    */
110    void predict(const Data& data, 
111                 std::vector<std::vector<statistics::Averager> > &);
112
113  private:
114    // no copying
115    EnsembleBuilder(const EnsembleBuilder&);
116    const EnsembleBuilder& operator=(const EnsembleBuilder&);
117   
118
119    const Classifier& mother_;
120    SubsetGenerator<Data>* subset_;
121    std::vector<Classifier*> classifier_;
122    KernelLookup test_data(const KernelLookup&, size_t k);
123    MatrixLookup test_data(const MatrixLookup&, size_t k);
124    MatrixLookupWeighted test_data(const MatrixLookupWeighted&, size_t k);
125    std::vector<std::vector<statistics::Averager> > validation_result_;
126
127  };
128 
129
130  // implementation
131
132  template <class Classifier, class Data>
133  EnsembleBuilder<Classifier, Data>::EnsembleBuilder(const Classifier& sc,
134                                                     const Data& data,
135                                                     const Sampler& sampler)
136    : mother_(sc),subset_(new SubsetGenerator<Data>(sampler,data))
137  {
138  }
139
140
141  template <class Classifier, class Data>
142  EnsembleBuilder<Classifier, Data>::EnsembleBuilder(const Classifier& sc,
143                                                     const Data& data,
144                                                     const Sampler& sampler,
145                                                     FeatureSelector& fs)
146    : mother_(sc),
147      subset_(new SubsetGenerator<Data>(sampler,data,fs))
148  {
149  }
150
151
152  template <class Classifier, class Data>
153  EnsembleBuilder<Classifier, Data>::~EnsembleBuilder(void)
154  {
155    for(size_t i=0; i<classifier_.size(); i++)
156      delete classifier_[i];
157    delete subset_;
158  }
159
160
161  template <class Classifier, class Data>
162  void EnsembleBuilder<Classifier, Data>::build(void)
163  {
164    if (classifier_.empty()){
165      for(unsigned long i=0; i<subset_->size();++i) {
166        Classifier* classifier = mother_.make_classifier();
167        classifier->train(subset_->training_data(i), 
168                          subset_->training_target(i));
169        classifier_.push_back(classifier);
170      }   
171    }
172  }
173
174
175  template <class Classifier, class Data>
176  const Classifier& EnsembleBuilder<Classifier, Data>::classifier(size_t i) const
177  {
178    return *(classifier_[i]);
179  }
180
181
182  template <class Classifier, class Data>
183  void EnsembleBuilder<Classifier, Data>::predict
184  (const Data& data, std::vector<std::vector<statistics::Averager> >& result)
185  {
186    result = std::vector<std::vector<statistics::Averager> >
187      (subset_->target().nof_classes(), 
188       std::vector<statistics::Averager>(data.columns()));
189   
190    utility::Matrix prediction; 
191
192    for(unsigned long k=0;k<size();++k) {       
193      Data sub_data = test_data(data, k);
194      classifier(k).predict(sub_data,prediction);
195    }
196
197    for(size_t i=0; i<prediction.rows();i++) 
198      for(size_t j=0; j<prediction.columns();j++) 
199        result[i][j].add(prediction(i,j));   
200  }
201
202 
203  template <class Classifier, class Data>
204  unsigned long EnsembleBuilder<Classifier, Data>::size(void) const
205  {
206    return classifier_.size();
207  }
208
209
210  template <class Classifier, class Data>
211  MatrixLookup EnsembleBuilder<Classifier,
212                               Data>::test_data(const MatrixLookup& data,
213                                                size_t k)
214  {
215    return MatrixLookup(data, subset_->training_features(k), true);
216  }
217 
218
219  template <class Classifier, class Data>
220  MatrixLookupWeighted
221  EnsembleBuilder<Classifier, Data>::test_data(const MatrixLookupWeighted& data,
222                                               size_t k)
223  {
224    return MatrixLookupWeighted(data, subset_->training_features(k), true);
225  }
226 
227
228  template <class Classifier, class Data>
229  KernelLookup
230  EnsembleBuilder<Classifier, Data>::test_data(const KernelLookup& kernel,
231                                               size_t k)
232  {
233    // weighted case
234    if (kernel.weighted()){
235      YAT_ASSERT(false);
236      // no feature selection
237      if (kernel.data_weighted().rows()==subset_->training_features(k).size())
238        return KernelLookup(kernel, subset_->training_index(k), true);
239      MatrixLookupWeighted mlw = test_data(kernel.data_weighted(), k);
240      return subset_->training_data(k).test_kernel(mlw);
241
242    }
243    // unweighted case
244
245    // no feature selection
246    if (kernel.data().rows()==subset_->training_features(k).size())
247      return KernelLookup(kernel, subset_->training_index(k), true);
248   
249    // feature selection
250    return subset_->training_data(k).test_kernel(test_data(kernel.data(),k));
251  }
252 
253
254  template <class Classifier, class Data>
255  const std::vector<std::vector<statistics::Averager> >& 
256  EnsembleBuilder<Classifier, Data>::validate(void)
257  {
258    // Don't recalculate validation_result_
259    if (!validation_result_.empty())
260      return validation_result_;
261
262    validation_result_ = std::vector<std::vector<statistics::Averager> >
263      (subset_->target().nof_classes(), 
264       std::vector<statistics::Averager>(subset_->target().size()));
265
266    utility::Matrix prediction; 
267    for(unsigned long k=0;k<size();k++) {
268      classifier(k).predict(subset_->validation_data(k),prediction);
269     
270      // map results to indices of samples in training + validation data set
271      for(size_t i=0; i<prediction.rows();i++) 
272        for(size_t j=0; j<prediction.columns();j++) {
273          validation_result_[i][subset_->validation_index(k)[j]].
274            add(prediction(i,j));
275        }           
276    }
277    return validation_result_;
278  }
279
280}}} // of namespace classifier, yat, and theplu
281
282#endif
Note: See TracBrowser for help on using the repository browser.