source: trunk/yat/classifier/EnsembleBuilder.h @ 1206

Last change on this file since 1206 was 1206, checked in by Peter, 14 years ago

fixes #345

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 7.2 KB
Line 
1#ifndef _theplu_yat_classifier_ensemblebuilder_
2#define _theplu_yat_classifier_ensemblebuilder_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007, 2008 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
29#include "FeatureSelector.h"
30#include "Sampler.h"
31#include "SubsetGenerator.h"
32#include "yat/statistics/Averager.h"
33#include "yat/utility/Matrix.h"
34
35#include <vector>
36
37namespace theplu {
38namespace yat {
39namespace classifier { 
40
41  ///
42  /// @brief Class for ensembles of supervised classifiers
43  ///
44  template <class Classifier, class Data>
45  class EnsembleBuilder
46  {
47  public:
48    /**
49       \brief Type of classifier that ensemble is built on.
50     */
51    typedef Classifier classifier_type;
52
53    /**
54       Type of container used for storing data e.g. MatrixLookup or KernelLookup
55     */
56    typedef Data data_type;
57
58    ///
59    /// Constructor.
60    ///
61    EnsembleBuilder(const Classifier&, const Data&, const Sampler&);
62
63    ///
64    /// Constructor.
65    ///
66    EnsembleBuilder(const Classifier&, const Data&, const Sampler&, 
67                    FeatureSelector&);
68
69    ///
70    /// Destructor.
71    ///
72    virtual ~EnsembleBuilder(void);
73
74    ///
75    /// Generate ensemble. Function trains each member of the Ensemble.
76    ///
77    void build(void);
78
79    ///
80    /// @return classifier
81    ///
82    const Classifier& classifier(size_t i) const;
83     
84    ///
85    /// @return Number of classifiers in ensemble
86    ///
87    u_long size(void) const;
88
89    ///
90    /// @brief Generate validation data for ensemble
91    ///
92    /// validate()[i][j] return averager for class @a i for sample @a j
93    ///
94    const std::vector<std::vector<statistics::Averager> >& validate(void);
95   
96    /**
97       Predict a dataset using the ensemble.
98       
99       If @a data is a KernelLookup each column should correspond to a
100       test sample and each row should correspond to a training
101       sample. More exactly row \f$ i \f$ in @a data should correspond
102       to the same sample as row/column \f$ i \f$ in the training
103       kernel corresponds to.
104    */
105    void predict(const Data& data, 
106                 std::vector<std::vector<statistics::Averager> > &);
107
108  private:
109    // no copying
110    EnsembleBuilder(const EnsembleBuilder&);
111    const EnsembleBuilder& operator=(const EnsembleBuilder&);
112   
113
114    const Classifier& mother_;
115    SubsetGenerator<Data>* subset_;
116    std::vector<Classifier*> classifier_;
117    KernelLookup test_data(const KernelLookup&, size_t k);
118    MatrixLookup test_data(const MatrixLookup&, size_t k);
119    MatrixLookupWeighted test_data(const MatrixLookupWeighted&, size_t k);
120    std::vector<std::vector<statistics::Averager> > validation_result_;
121
122  };
123 
124
125  // implementation
126
127  template <class C, class D> 
128  EnsembleBuilder<C,D>::EnsembleBuilder(const C& sc, const D& data,
129                                        const Sampler& sampler) 
130    : mother_(sc),subset_(new SubsetGenerator<D>(sampler,data))
131  {
132  }
133
134
135  template <class C, class D> 
136  EnsembleBuilder<C, D>::EnsembleBuilder(const C& sc, const D& data, 
137                                         const Sampler& sampler,
138                                         FeatureSelector& fs) 
139    : mother_(sc),
140      subset_(new SubsetGenerator<D>(sampler,data,fs))
141  {
142  }
143
144
145  template <class C, class D> 
146  EnsembleBuilder<C, D>::~EnsembleBuilder(void) 
147  {
148    for(size_t i=0; i<classifier_.size(); i++)
149      delete classifier_[i];
150    delete subset_;
151  }
152
153
154  template <class C, class D> 
155  void EnsembleBuilder<C, D>::build(void) 
156  {
157    for(u_long i=0; i<subset_->size();++i) {
158      C* classifier = mother_.make_classifier();
159      classifier->train(subset_->training_data(i), 
160                        subset_->training_target(i));
161      classifier_.push_back(classifier);
162    }   
163  }
164
165
166  template <class C, class D> 
167  const C& EnsembleBuilder<C, D>::classifier(size_t i) const
168  {
169    return *(classifier_[i]);
170  }
171
172
173  template <class C, class D> 
174  void EnsembleBuilder<C, D>::predict
175  (const D& data, std::vector<std::vector<statistics::Averager> >& result)
176  {
177    result.clear();
178    result.reserve(subset_->target().nof_classes());   
179    for(size_t i=0; i<subset_->target().nof_classes();i++)
180      result.push_back(std::vector<statistics::Averager>(data.columns()));
181   
182    utility::Matrix prediction; 
183
184    for(u_long k=0;k<subset_->size();++k) {       
185      D sub_data =  test_data(data, k);
186      classifier(k).predict(sub_data,prediction);
187    }
188
189    for(size_t i=0; i<prediction.rows();i++) 
190      for(size_t j=0; j<prediction.columns();j++) 
191        result[i][j].add(prediction(i,j));   
192  }
193
194 
195  template <class C, class D> 
196  u_long EnsembleBuilder<C, D>::size(void) const
197  {
198    return classifier_.size();
199  }
200
201
202  template <class C, class D> 
203  MatrixLookup EnsembleBuilder<C, D>::test_data(const MatrixLookup& data, 
204                                                size_t k)
205  {
206    return MatrixLookup(data, subset_->training_features(k), true);
207  }
208 
209
210  template <class C, class D> 
211  MatrixLookupWeighted
212  EnsembleBuilder<C, D>::test_data(const MatrixLookupWeighted& data, size_t k)
213  {
214    return MatrixLookupWeighted(data, subset_->training_features(k), true);
215  }
216 
217
218  template <class C, class D> 
219  KernelLookup
220  EnsembleBuilder<C, D>::test_data(const KernelLookup& kernel, size_t k)
221  {
222    // weighted case
223    if (kernel.weighted()){
224      assert(false);
225      // no feature selection
226      if (kernel.data_weighted().rows()==subset_->training_features(k).size())
227        return KernelLookup(kernel, subset_->training_index(k), true);
228      MatrixLookupWeighted mlw = test_data(kernel.data_weighted(), k);
229      return subset_->training_data(k).test_kernel(mlw);
230
231    }
232    // unweighted case
233
234    // no feature selection
235    if (kernel.data().rows()==subset_->training_features(k).size())
236      return KernelLookup(kernel, subset_->training_index(k), true);
237   
238    // feature selection
239    return subset_->training_data(k).test_kernel(test_data(kernel.data(),k));
240  }
241 
242
243  template <class C, class D> 
244  const std::vector<std::vector<statistics::Averager> >& 
245  EnsembleBuilder<C, D>::validate(void)
246  {
247    validation_result_.clear();
248
249    validation_result_.reserve(subset_->target().nof_classes());   
250    for(size_t i=0; i<subset_->target().nof_classes();i++)
251      validation_result_.push_back(std::vector<statistics::Averager>(subset_->target().size()));
252   
253    utility::Matrix prediction; 
254    for(u_long k=0;k<subset_->size();k++) {
255      classifier(k).predict(subset_->validation_data(k),prediction);
256     
257      // map results to indices of samples in training + validation data set
258      for(size_t i=0; i<prediction.rows();i++) 
259        for(size_t j=0; j<prediction.columns();j++) {
260          validation_result_[i][subset_->validation_index(k)[j]].
261            add(prediction(i,j));
262        }           
263    }
264    return validation_result_;
265  }
266
267}}} // of namespace classifier, yat, and theplu
268
269#endif
Note: See TracBrowser for help on using the repository browser.