source: trunk/yat/classifier/EnsembleBuilder.h @ 1090

Last change on this file since 1090 was 1088, checked in by Peter, 13 years ago

Closes #247. Removed IteratorWeighted; iterators over weighted containers can instead use Iterator with a special Policy.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 5.7 KB
Line 
1#ifndef _theplu_yat_classifier_ensemblebuilder_
2#define _theplu_yat_classifier_ensemblebuilder_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007, 2008 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
29#include "FeatureSelector.h"
30#include "Sampler.h"
31#include "SubsetGenerator.h"
32#include "yat/statistics/Averager.h"
33
34#include <vector>
35
36namespace theplu {
37namespace yat {
38namespace classifier { 
39
40  ///
41  /// @brief Class for ensembles of supervised classifiers
42  ///
43  template <class Classifier, class Data>
44  class EnsembleBuilder
45  {
46 
47  public:
48    typedef Classifier classifier_type;
49    typedef Data data_type;
50
51    ///
52    /// Constructor.
53    ///
54    EnsembleBuilder(const Classifier&, const Data&, const Sampler&);
55
56    ///
57    /// Constructor.
58    ///
59    EnsembleBuilder(const Classifier&, const Data&, const Sampler&, 
60                    FeatureSelector&);
61
62    ///
63    /// Destructor.
64    ///
65    virtual ~EnsembleBuilder(void);
66
67    ///
68    /// Generate ensemble. Function trains each member of the Ensemble.
69    ///
70    void build(void);
71
72    ///
73    /// @Return classifier
74    ///
75    const Classifier& classifier(size_t i) const;
76     
77    ///
78    /// @Return Number of classifiers in ensemble
79    ///
80    u_long size(void) const;
81
82    ///
83    /// @brief Generate validation data for ensemble
84    ///
85    /// validate()[i][j] return averager for class @a i for sample @a j
86    ///
87    const std::vector<std::vector<statistics::Averager> >& validate(void);
88   
89    /**
90       Predict a dataset using the ensemble.
91       
92       If @a data is a KernelLookup each column should correspond to a
93       test sample and each row should correspond to a training
94       sample. More exactly row \f$ i \f$ in @a data should correspond
95       to the same sample as row/column \f$ i \f$ in the training
96       kernel corresponds to.
97    */
98    void predict(const Data& data, 
99                 std::vector<std::vector<statistics::Averager> > &);
100
101  private:
102    // no copying
103    EnsembleBuilder(const EnsembleBuilder&);
104    const EnsembleBuilder& operator=(const EnsembleBuilder&);
105   
106
107    const Classifier& mother_;
108    SubsetGenerator<Data>* subset_;
109    std::vector<Classifier*> classifier_;
110    std::vector<std::vector<statistics::Averager> > validation_result_;
111
112  };
113 
114
115  // implementation
116
117  template <class C, class D> 
118  EnsembleBuilder<C,D>::EnsembleBuilder(const C& sc, const D& data,
119                                        const Sampler& sampler) 
120    : mother_(sc),subset_(new SubsetGenerator<D>(sampler,data))
121  {
122  }
123
124
125  template <class C, class D> 
126  EnsembleBuilder<C, D>::EnsembleBuilder(const C& sc, const D& data, 
127                                         const Sampler& sampler,
128                                         FeatureSelector& fs) 
129    : mother_(sc),
130      subset_(new SubsetGenerator<D>(sampler,data,fs))
131  {
132  }
133
134
135  template <class C, class D> 
136  EnsembleBuilder<C, D>::~EnsembleBuilder(void) 
137  {
138    for(size_t i=0; i<classifier_.size(); i++)
139      delete classifier_[i];
140    delete subset_;
141  }
142
143
144  template <class C, class D> 
145  void EnsembleBuilder<C, D>::build(void) 
146  {
147    for(u_long i=0; i<subset_->size();++i) {
148      C* classifier = mother_.make_classifier(subset_->training_data(i), 
149                                              subset_->training_target(i));
150      classifier->train();
151      classifier_.push_back(classifier);
152    }   
153  }
154
155
156  template <class C, class D> 
157  const C& EnsembleBuilder<C, D>::classifier(size_t i) const
158  {
159    return *(classifier_[i]);
160  }
161
162
163  template <class C, class D> 
164  u_long EnsembleBuilder<C, D>::size(void) const
165  {
166    return classifier_.size();
167  }
168
169
170  template <class C, class D> 
171  void EnsembleBuilder<C, D>::predict
172  (const D& data, std::vector<std::vector<statistics::Averager> >& result)
173  {
174    result.clear();
175    result.reserve(subset_->target().nof_classes());   
176    for(size_t i=0; i<subset_->target().nof_classes();i++)
177      result.push_back(std::vector<statistics::Averager>(data.columns()));
178   
179    utility::matrix prediction; 
180
181    for(u_long k=0;k<subset_->size();++k) {       
182      const D* sub_data =
183        data.selected(subset_->training_features(k));
184      assert(sub_data);
185      classifier(k).predict(*sub_data,prediction);
186      delete sub_data;
187    }
188
189    for(size_t i=0; i<prediction.rows();i++) 
190      for(size_t j=0; j<prediction.columns();j++) 
191        result[i][j].add(prediction(i,j));   
192  }
193
194 
195  template <class C, class D> 
196  const std::vector<std::vector<statistics::Averager> >& 
197  EnsembleBuilder<C, D>::validate(void)
198  {
199    validation_result_.clear();
200
201    validation_result_.reserve(subset_->target().nof_classes());   
202    for(size_t i=0; i<subset_->target().nof_classes();i++)
203      validation_result_.push_back(std::vector<statistics::Averager>(subset_->target().size()));
204   
205    utility::matrix prediction; 
206    for(u_long k=0;k<subset_->size();k++) {
207      classifier(k).predict(subset_->validation_data(k),prediction);
208     
209      // map results to indices of samples in training + validation data set
210      for(size_t i=0; i<prediction.rows();i++) 
211        for(size_t j=0; j<prediction.columns();j++) {
212          validation_result_[i][subset_->validation_index(k)[j]].
213            add(prediction(i,j));
214        }           
215    }
216    return validation_result_;
217  }
218
219}}} // of namespace classifier, yat, and theplu
220
221#endif
Note: See TracBrowser for help on using the repository browser.