source: trunk/yat/classifier/EnsembleBuilder.cc @ 865

Last change on this file since 865 was 865, checked in by Peter, 15 years ago

changing URL to http://trac.thep.lu.se/trac/yat

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 3.9 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005 Markus Ringnér
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "EnsembleBuilder.h"
27#include "DataLookup2D.h"
28#include "FeatureSelector.h"
29#include "KernelLookup.h"
30#include "Sampler.h"
31#include "SubsetGenerator.h"
32#include "SupervisedClassifier.h"
33#include "Target.h"
34#include "yat/utility/matrix.h"
35
36namespace theplu {
37namespace yat {
38namespace classifier {
39
40  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
41                                   const Sampler& sampler) 
42    : mother_(sc),subset_(new SubsetGenerator(sampler,sc.data()))
43  {
44  }
45
46  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
47                                   const Sampler& sampler,
48                                   FeatureSelector& fs) 
49    : mother_(sc),subset_(new SubsetGenerator(sampler,sc.data(),fs))
50  {
51  }
52
53  EnsembleBuilder::~EnsembleBuilder(void) 
54  {
55    for(size_t i=0; i<classifier_.size(); i++)
56      delete classifier_[i];
57    delete subset_;
58  }
59
60  void EnsembleBuilder::build(void) 
61  {
62    for(u_long i=0; i<subset_->size();++i) {
63      SupervisedClassifier* classifier=
64        mother_.make_classifier(subset_->training_data(i), 
65                                subset_->training_target(i));
66      classifier->train();
67      classifier_.push_back(classifier);
68    }   
69  }
70
71
72  const SupervisedClassifier& EnsembleBuilder::classifier(size_t i) const
73  {
74    return *(classifier_[i]);
75  }
76
77
78  u_long EnsembleBuilder::size(void) const
79  {
80    return classifier_.size();
81  }
82
83
84  void  EnsembleBuilder::predict
85  (const DataLookup2D& data, 
86   std::vector<std::vector<statistics::Averager> >& result)
87  {
88    result.clear();
89    result.reserve(subset_->target().nof_classes());   
90    for(size_t i=0; i<subset_->target().nof_classes();i++)
91      result.push_back(std::vector<statistics::Averager>(data.columns()));
92   
93    utility::matrix prediction; 
94    const KernelLookup* kernel = dynamic_cast<const KernelLookup*>(&data);
95    if (kernel) {
96      for(u_long k=0;k<subset_->size();k++) {
97        KernelLookup sub_kernel(*kernel,subset_->training_index(k),true);
98        classifier(k).predict(sub_kernel,prediction);
99
100        for(size_t i=0; i<prediction.rows();i++) 
101          for(size_t j=0; j<prediction.columns();j++) 
102            result[i][j].add(prediction(i,j));
103      }
104    }
105    else {
106      for(u_long k=0;k<subset_->size();k++) {
107        classifier(k).predict(data,prediction);
108        for(size_t i=0; i<prediction.rows();i++) 
109          for(size_t j=0; j<prediction.columns();j++) 
110            result[i][j].add(prediction(i,j));
111       
112      }
113    }
114  }
115
116
117  const std::vector<std::vector<statistics::Averager> >& 
118  EnsembleBuilder::validate(void)
119  {
120    validation_result_.clear();
121
122    validation_result_.reserve(subset_->target().nof_classes());   
123    for(size_t i=0; i<subset_->target().nof_classes();i++)
124      validation_result_.push_back(std::vector<statistics::Averager>(subset_->target().size()));
125   
126    utility::matrix prediction; 
127    for(u_long k=0;k<subset_->size();k++) {
128      classifier(k).predict(subset_->validation_data(k),prediction);
129
130      for(size_t i=0; i<prediction.rows();i++) 
131        for(size_t j=0; j<prediction.columns();j++) {
132          validation_result_[i][subset_->validation_index(k)[j]].
133            add(prediction(i,j));
134        }           
135    }
136    return validation_result_;
137  }
138
139}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.