source: trunk/yat/classifier/EnsembleBuilder.h @ 1121

Last change on this file since 1121 was 1121, checked in by Peter, 14 years ago

fixes #308

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 5.7 KB
Line 
1#ifndef _theplu_yat_classifier_ensemblebuilder_
2#define _theplu_yat_classifier_ensemblebuilder_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007, 2008 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
29#include "FeatureSelector.h"
30#include "Sampler.h"
31#include "SubsetGenerator.h"
32#include "yat/statistics/Averager.h"
33#include "yat/utility/Matrix.h"
34
35#include <vector>
36
37namespace theplu {
38namespace yat {
39namespace classifier { 
40
41  ///
42  /// @brief Class for ensembles of supervised classifiers
43  ///
44  template <class Classifier, class Data>
45  class EnsembleBuilder
46  {
47 
48  public:
49    typedef Classifier classifier_type;
50    typedef Data data_type;
51
52    ///
53    /// Constructor.
54    ///
55    EnsembleBuilder(const Classifier&, const Data&, const Sampler&);
56
57    ///
58    /// Constructor.
59    ///
60    EnsembleBuilder(const Classifier&, const Data&, const Sampler&, 
61                    FeatureSelector&);
62
63    ///
64    /// Destructor.
65    ///
66    virtual ~EnsembleBuilder(void);
67
68    ///
69    /// Generate ensemble. Function trains each member of the Ensemble.
70    ///
71    void build(void);
72
73    ///
74    /// @Return classifier
75    ///
76    const Classifier& classifier(size_t i) const;
77     
78    ///
79    /// @Return Number of classifiers in ensemble
80    ///
81    u_long size(void) const;
82
83    ///
84    /// @brief Generate validation data for ensemble
85    ///
86    /// validate()[i][j] return averager for class @a i for sample @a j
87    ///
88    const std::vector<std::vector<statistics::Averager> >& validate(void);
89   
90    /**
91       Predict a dataset using the ensemble.
92       
93       If @a data is a KernelLookup each column should correspond to a
94       test sample and each row should correspond to a training
95       sample. More exactly row \f$ i \f$ in @a data should correspond
96       to the same sample as row/column \f$ i \f$ in the training
97       kernel corresponds to.
98    */
99    void predict(const Data& data, 
100                 std::vector<std::vector<statistics::Averager> > &);
101
102  private:
103    // no copying
104    EnsembleBuilder(const EnsembleBuilder&);
105    const EnsembleBuilder& operator=(const EnsembleBuilder&);
106   
107
108    const Classifier& mother_;
109    SubsetGenerator<Data>* subset_;
110    std::vector<Classifier*> classifier_;
111    std::vector<std::vector<statistics::Averager> > validation_result_;
112
113  };
114 
115
116  // implementation
117
118  template <class C, class D> 
119  EnsembleBuilder<C,D>::EnsembleBuilder(const C& sc, const D& data,
120                                        const Sampler& sampler) 
121    : mother_(sc),subset_(new SubsetGenerator<D>(sampler,data))
122  {
123  }
124
125
126  template <class C, class D> 
127  EnsembleBuilder<C, D>::EnsembleBuilder(const C& sc, const D& data, 
128                                         const Sampler& sampler,
129                                         FeatureSelector& fs) 
130    : mother_(sc),
131      subset_(new SubsetGenerator<D>(sampler,data,fs))
132  {
133  }
134
135
136  template <class C, class D> 
137  EnsembleBuilder<C, D>::~EnsembleBuilder(void) 
138  {
139    for(size_t i=0; i<classifier_.size(); i++)
140      delete classifier_[i];
141    delete subset_;
142  }
143
144
145  template <class C, class D> 
146  void EnsembleBuilder<C, D>::build(void) 
147  {
148    for(u_long i=0; i<subset_->size();++i) {
149      C* classifier = mother_.make_classifier(subset_->training_data(i), 
150                                              subset_->training_target(i));
151      classifier->train();
152      classifier_.push_back(classifier);
153    }   
154  }
155
156
157  template <class C, class D> 
158  const C& EnsembleBuilder<C, D>::classifier(size_t i) const
159  {
160    return *(classifier_[i]);
161  }
162
163
164  template <class C, class D> 
165  u_long EnsembleBuilder<C, D>::size(void) const
166  {
167    return classifier_.size();
168  }
169
170
171  template <class C, class D> 
172  void EnsembleBuilder<C, D>::predict
173  (const D& data, std::vector<std::vector<statistics::Averager> >& result)
174  {
175    result.clear();
176    result.reserve(subset_->target().nof_classes());   
177    for(size_t i=0; i<subset_->target().nof_classes();i++)
178      result.push_back(std::vector<statistics::Averager>(data.columns()));
179   
180    utility::Matrix prediction; 
181
182    for(u_long k=0;k<subset_->size();++k) {       
183      const D* sub_data =
184        data.selected(subset_->training_features(k));
185      assert(sub_data);
186      classifier(k).predict(*sub_data,prediction);
187      delete sub_data;
188    }
189
190    for(size_t i=0; i<prediction.rows();i++) 
191      for(size_t j=0; j<prediction.columns();j++) 
192        result[i][j].add(prediction(i,j));   
193  }
194
195 
196  template <class C, class D> 
197  const std::vector<std::vector<statistics::Averager> >& 
198  EnsembleBuilder<C, D>::validate(void)
199  {
200    validation_result_.clear();
201
202    validation_result_.reserve(subset_->target().nof_classes());   
203    for(size_t i=0; i<subset_->target().nof_classes();i++)
204      validation_result_.push_back(std::vector<statistics::Averager>(subset_->target().size()));
205   
206    utility::Matrix prediction; 
207    for(u_long k=0;k<subset_->size();k++) {
208      classifier(k).predict(subset_->validation_data(k),prediction);
209     
210      // map results to indices of samples in training + validation data set
211      for(size_t i=0; i<prediction.rows();i++) 
212        for(size_t j=0; j<prediction.columns();j++) {
213          validation_result_[i][subset_->validation_index(k)[j]].
214            add(prediction(i,j));
215        }           
216    }
217    return validation_result_;
218  }
219
220}}} // of namespace classifier, yat, and theplu
221
222#endif
Note: See TracBrowser for help on using the repository browser.