source: trunk/yat/classifier/EnsembleBuilder.cc @ 866

Last change on this file since 866 was 866, checked in by Markus Ringnér, 14 years ago

refs #236. Fixed the EnsembleBuilder::predict bug for matrix-based classifiers; the bug is not yet fixed for kernel-based classifiers.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 4.5 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005 Markus Ringnér
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "EnsembleBuilder.h"
27#include "DataLookup2D.h"
28#include "FeatureSelector.h"
29#include "KernelLookup.h"
30#include "MatrixLookup.h"
31#include "MatrixLookupWeighted.h"
32#include "Sampler.h"
33#include "SubsetGenerator.h"
34#include "SupervisedClassifier.h"
35#include "Target.h"
36#include "yat/utility/matrix.h"
37
38namespace theplu {
39namespace yat {
40namespace classifier {
41
42  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
43                                   const Sampler& sampler) 
44    : mother_(sc),subset_(new SubsetGenerator(sampler,sc.data()))
45  {
46  }
47
48  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
49                                   const Sampler& sampler,
50                                   FeatureSelector& fs) 
51    : mother_(sc),subset_(new SubsetGenerator(sampler,sc.data(),fs))
52  {
53  }
54
55  EnsembleBuilder::~EnsembleBuilder(void) 
56  {
57    for(size_t i=0; i<classifier_.size(); i++)
58      delete classifier_[i];
59    delete subset_;
60  }
61
62  void EnsembleBuilder::build(void) 
63  {
64    for(u_long i=0; i<subset_->size();++i) {
65      SupervisedClassifier* classifier=
66        mother_.make_classifier(subset_->training_data(i), 
67                                subset_->training_target(i));
68      classifier->train();
69      classifier_.push_back(classifier);
70    }   
71  }
72
73
74  const SupervisedClassifier& EnsembleBuilder::classifier(size_t i) const
75  {
76    return *(classifier_[i]);
77  }
78
79
80  u_long EnsembleBuilder::size(void) const
81  {
82    return classifier_.size();
83  }
84
85
86  void EnsembleBuilder::predict
87  (const DataLookup2D& data, 
88   std::vector<std::vector<statistics::Averager> >& result)
89  {
90    result.clear();
91    result.reserve(subset_->target().nof_classes());   
92    for(size_t i=0; i<subset_->target().nof_classes();i++)
93      result.push_back(std::vector<statistics::Averager>(data.columns()));
94   
95    utility::matrix prediction; 
96   
97    const KernelLookup* kernel = dynamic_cast<const KernelLookup*>(&data);   
98    const MatrixLookupWeighted* mw = 
99      dynamic_cast<const MatrixLookupWeighted*>(&data);   
100    const MatrixLookup* m = dynamic_cast<const MatrixLookup*>(&data);
101
102    if(kernel) {
103      for(u_long k=0;k<subset_->size();k++) {       
104        KernelLookup sub_kernel(*kernel,subset_->training_index(k),true);
105        classifier(k).predict(sub_kernel,prediction);
106      }
107    }
108    else if(mw) {
109      for(u_long k=0;k<subset_->size();k++) {
110        MatrixLookupWeighted sub_matrix(*mw,subset_->training_features(k),true);
111        classifier(k).predict(sub_matrix,prediction);
112      }
113    }
114    else if(m) {
115      for(u_long k=0;k<subset_->size();k++) {       
116        MatrixLookup sub_matrix(*m,subset_->training_features(k),true);
117        classifier(k).predict(sub_matrix,prediction);       
118      }
119    }
120    else {
121      std::string str;
122      str = "Error in NCC::predict: DataLookup2D of unexpected class.";
123      throw std::runtime_error(str);
124    }
125
126    for(size_t i=0; i<prediction.rows();i++) 
127      for(size_t j=0; j<prediction.columns();j++) 
128        result[i][j].add(prediction(i,j));   
129  }
130
131 
132  const std::vector<std::vector<statistics::Averager> >& 
133  EnsembleBuilder::validate(void)
134  {
135    validation_result_.clear();
136
137    validation_result_.reserve(subset_->target().nof_classes());   
138    for(size_t i=0; i<subset_->target().nof_classes();i++)
139      validation_result_.push_back(std::vector<statistics::Averager>(subset_->target().size()));
140   
141    utility::matrix prediction; 
142    for(u_long k=0;k<subset_->size();k++) {
143      classifier(k).predict(subset_->validation_data(k),prediction);
144     
145      // map results to indices of samples in training + validation data set
146      for(size_t i=0; i<prediction.rows();i++) 
147        for(size_t j=0; j<prediction.columns();j++) {
148          validation_result_[i][subset_->validation_index(k)[j]].
149            add(prediction(i,j));
150        }           
151    }
152    return validation_result_;
153  }
154
155}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.