source: trunk/yat/classifier/EnsembleBuilder.cc @ 1072

Last change on this file since 1072 was 1072, checked in by Peter, 16 years ago

fixes #309

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 3.8 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005 Markus Ringnér
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "EnsembleBuilder.h"
27#include "DataLookup2D.h"
28#include "FeatureSelector.h"
29#include "KernelLookup.h"
30#include "MatrixLookup.h"
31#include "MatrixLookupWeighted.h"
32#include "Sampler.h"
33#include "SubsetGenerator.h"
34#include "SupervisedClassifier.h"
35#include "Target.h"
36#include "yat/utility/matrix.h"
37
38#include <cassert>
39
40namespace theplu {
41namespace yat {
42namespace classifier {
43
44  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
45                                   const Sampler& sampler) 
46    : mother_(sc),subset_(new SubsetGenerator<DataLookup2D>(sampler,sc.data()))
47  {
48  }
49
50  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
51                                   const Sampler& sampler,
52                                   FeatureSelector& fs) 
53    : mother_(sc),
54      subset_(new SubsetGenerator<DataLookup2D>(sampler,sc.data(),fs))
55  {
56  }
57
58  EnsembleBuilder::~EnsembleBuilder(void) 
59  {
60    for(size_t i=0; i<classifier_.size(); i++)
61      delete classifier_[i];
62    delete subset_;
63  }
64
65  void EnsembleBuilder::build(void) 
66  {
67    for(u_long i=0; i<subset_->size();++i) {
68      SupervisedClassifier* classifier=
69        mother_.make_classifier(subset_->training_data(i), 
70                                subset_->training_target(i));
71      classifier->train();
72      classifier_.push_back(classifier);
73    }   
74  }
75
76
77  const SupervisedClassifier& EnsembleBuilder::classifier(size_t i) const
78  {
79    return *(classifier_[i]);
80  }
81
82
83  u_long EnsembleBuilder::size(void) const
84  {
85    return classifier_.size();
86  }
87
88
89  void EnsembleBuilder::predict
90  (const DataLookup2D& data, 
91   std::vector<std::vector<statistics::Averager> >& result)
92  {
93    result.clear();
94    result.reserve(subset_->target().nof_classes());   
95    for(size_t i=0; i<subset_->target().nof_classes();i++)
96      result.push_back(std::vector<statistics::Averager>(data.columns()));
97   
98    utility::matrix prediction; 
99
100    for(u_long k=0;k<subset_->size();++k) {       
101      const DataLookup2D* sub_data =
102        data.selected(subset_->training_features(k));
103      assert(sub_data);
104      classifier(k).predict(*sub_data,prediction);
105      delete sub_data;
106    }
107
108    for(size_t i=0; i<prediction.rows();i++) 
109      for(size_t j=0; j<prediction.columns();j++) 
110        result[i][j].add(prediction(i,j));   
111  }
112
113 
114  const std::vector<std::vector<statistics::Averager> >& 
115  EnsembleBuilder::validate(void)
116  {
117    validation_result_.clear();
118
119    validation_result_.reserve(subset_->target().nof_classes());   
120    for(size_t i=0; i<subset_->target().nof_classes();i++)
121      validation_result_.push_back(std::vector<statistics::Averager>(subset_->target().size()));
122   
123    utility::matrix prediction; 
124    for(u_long k=0;k<subset_->size();k++) {
125      classifier(k).predict(subset_->validation_data(k),prediction);
126     
127      // map results to indices of samples in training + validation data set
128      for(size_t i=0; i<prediction.rows();i++) 
129        for(size_t j=0; j<prediction.columns();j++) {
130          validation_result_[i][subset_->validation_index(k)[j]].
131            add(prediction(i,j));
132        }           
133    }
134    return validation_result_;
135  }
136
137}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.