source: trunk/yat/classifier/EnsembleBuilder.h @ 1227

Last change on this file since 1227 was 1227, checked in by Peter, 14 years ago

fixes #341 and #93

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 7.3 KB
Line 
1#ifndef _theplu_yat_classifier_ensemblebuilder_
2#define _theplu_yat_classifier_ensemblebuilder_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007, 2008 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
29#include "FeatureSelector.h"
30#include "Sampler.h"
31#include "SubsetGenerator.h"
32#include "yat/statistics/Averager.h"
33#include "yat/utility/Matrix.h"
34
35#include <vector>
36
37namespace theplu {
38namespace yat {
39namespace classifier { 
40
41  ///
42  /// @brief Class for ensembles of supervised classifiers
43  ///
44  template <class Classifier, class Data>
45  class EnsembleBuilder
46  {
47  public:
48    /**
49       \brief Type of classifier that ensemble is built on.
50     */
51    typedef Classifier classifier_type;
52
53    /**
54       Type of container used for storing data. Must be MatrixLookup,
55       MatrixLookupWeighted, or KernelLookup
56     */
57    typedef Data data_type;
58
59    ///
60    /// Constructor.
61    ///
62    EnsembleBuilder(const Classifier&, const Data&, const Sampler&);
63
64    ///
65    /// Constructor.
66    ///
67    EnsembleBuilder(const Classifier&, const Data&, const Sampler&, 
68                    FeatureSelector&);
69
70    ///
71    /// Destructor.
72    ///
73    virtual ~EnsembleBuilder(void);
74
75    /**
76       \brief Generate ensemble.
77       
78       Function trains each member of the Ensemble.
79    */
80    void build(void);
81
82    ///
83    /// @return ith classifier
84    ///
85    const Classifier& classifier(size_t i) const;
86     
87    ///
88    /// @return Number of classifiers in ensemble. Prior build(void)
89    /// is issued size is zero.
90    ///
91    u_long size(void) const;
92
93    ///
94    /// @brief Generate validation data for ensemble
95    ///
96    /// validate()[i][j] return averager for class @a i for sample @a j
97    ///
98    const std::vector<std::vector<statistics::Averager> >& validate(void);
99   
100    /**
101       Predict a dataset using the ensemble.
102       
103       If @a data is a KernelLookup each column should correspond to a
104       test sample and each row should correspond to a training
105       sample. More exactly row \f$ i \f$ in @a data should correspond
106       to the same sample as row/column \f$ i \f$ in the training
107       kernel corresponds to.
108    */
109    void predict(const Data& data, 
110                 std::vector<std::vector<statistics::Averager> > &);
111
112  private:
113    // no copying
114    EnsembleBuilder(const EnsembleBuilder&);
115    const EnsembleBuilder& operator=(const EnsembleBuilder&);
116   
117
118    const Classifier& mother_;
119    SubsetGenerator<Data>* subset_;
120    std::vector<Classifier*> classifier_;
121    KernelLookup test_data(const KernelLookup&, size_t k);
122    MatrixLookup test_data(const MatrixLookup&, size_t k);
123    MatrixLookupWeighted test_data(const MatrixLookupWeighted&, size_t k);
124    std::vector<std::vector<statistics::Averager> > validation_result_;
125
126  };
127 
128
129  // implementation
130
131  template <class C, class D> 
132  EnsembleBuilder<C,D>::EnsembleBuilder(const C& sc, const D& data,
133                                        const Sampler& sampler) 
134    : mother_(sc),subset_(new SubsetGenerator<D>(sampler,data))
135  {
136  }
137
138
139  template <class C, class D> 
140  EnsembleBuilder<C, D>::EnsembleBuilder(const C& sc, const D& data, 
141                                         const Sampler& sampler,
142                                         FeatureSelector& fs) 
143    : mother_(sc),
144      subset_(new SubsetGenerator<D>(sampler,data,fs))
145  {
146  }
147
148
149  template <class C, class D> 
150  EnsembleBuilder<C, D>::~EnsembleBuilder(void) 
151  {
152    for(size_t i=0; i<classifier_.size(); i++)
153      delete classifier_[i];
154    delete subset_;
155  }
156
157
158  template <class C, class D> 
159  void EnsembleBuilder<C, D>::build(void) 
160  {
161    if (classifier_.empty()){
162      for(u_long i=0; i<subset_->size();++i) {
163        C* classifier = mother_.make_classifier();
164        classifier->train(subset_->training_data(i), 
165                          subset_->training_target(i));
166        classifier_.push_back(classifier);
167      }   
168    }
169  }
170
171
172  template <class C, class D> 
173  const C& EnsembleBuilder<C, D>::classifier(size_t i) const
174  {
175    return *(classifier_[i]);
176  }
177
178
179  template <class C, class D> 
180  void EnsembleBuilder<C, D>::predict
181  (const D& data, std::vector<std::vector<statistics::Averager> >& result)
182  {
183    result = std::vector<std::vector<statistics::Averager> >
184      (subset_->target().nof_classes(), 
185       std::vector<statistics::Averager>(data.columns()));
186   
187    utility::Matrix prediction; 
188
189    for(u_long k=0;k<size();++k) {       
190      D sub_data =  test_data(data, k);
191      classifier(k).predict(sub_data,prediction);
192    }
193
194    for(size_t i=0; i<prediction.rows();i++) 
195      for(size_t j=0; j<prediction.columns();j++) 
196        result[i][j].add(prediction(i,j));   
197  }
198
199 
200  template <class C, class D> 
201  u_long EnsembleBuilder<C, D>::size(void) const
202  {
203    return classifier_.size();
204  }
205
206
207  template <class C, class D> 
208  MatrixLookup EnsembleBuilder<C, D>::test_data(const MatrixLookup& data, 
209                                                size_t k)
210  {
211    return MatrixLookup(data, subset_->training_features(k), true);
212  }
213 
214
215  template <class C, class D> 
216  MatrixLookupWeighted
217  EnsembleBuilder<C, D>::test_data(const MatrixLookupWeighted& data, size_t k)
218  {
219    return MatrixLookupWeighted(data, subset_->training_features(k), true);
220  }
221 
222
223  template <class C, class D> 
224  KernelLookup
225  EnsembleBuilder<C, D>::test_data(const KernelLookup& kernel, size_t k)
226  {
227    // weighted case
228    if (kernel.weighted()){
229      assert(false);
230      // no feature selection
231      if (kernel.data_weighted().rows()==subset_->training_features(k).size())
232        return KernelLookup(kernel, subset_->training_index(k), true);
233      MatrixLookupWeighted mlw = test_data(kernel.data_weighted(), k);
234      return subset_->training_data(k).test_kernel(mlw);
235
236    }
237    // unweighted case
238
239    // no feature selection
240    if (kernel.data().rows()==subset_->training_features(k).size())
241      return KernelLookup(kernel, subset_->training_index(k), true);
242   
243    // feature selection
244    return subset_->training_data(k).test_kernel(test_data(kernel.data(),k));
245  }
246 
247
248  template <class C, class D> 
249  const std::vector<std::vector<statistics::Averager> >& 
250  EnsembleBuilder<C, D>::validate(void)
251  {
252    // Don't recalculate validation_result_
253    if (!validation_result_.empty())
254      return validation_result_;
255
256    validation_result_ = std::vector<std::vector<statistics::Averager> >
257      (subset_->target().nof_classes(), 
258       std::vector<statistics::Averager>(subset_->target().size()));
259
260    utility::Matrix prediction; 
261    for(u_long k=0;k<size();k++) {
262      classifier(k).predict(subset_->validation_data(k),prediction);
263     
264      // map results to indices of samples in training + validation data set
265      for(size_t i=0; i<prediction.rows();i++) 
266        for(size_t j=0; j<prediction.columns();j++) {
267          validation_result_[i][subset_->validation_index(k)[j]].
268            add(prediction(i,j));
269        }           
270    }
271    return validation_result_;
272  }
273
274}}} // of namespace classifier, yat, and theplu
275
276#endif
Note: See TracBrowser for help on using the repository browser.