source: trunk/c++_tools/classifier/EnsembleBuilder.cc @ 615

Last change on this file since 615 was 615, checked in by Peter, 15 years ago

ref #60 NOTE: there is most likely a bug around. I have removed the ensemble.build() test in the ensemble_test to get the test go through. I will try to find and remove this bug asap.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 2.7 KB
Line 
1// $Id$
2
3#include <c++_tools/classifier/EnsembleBuilder.h>
4
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/KernelLookup.h>
7#include <c++_tools/classifier/SubsetGenerator.h>
8#include <c++_tools/classifier/SupervisedClassifier.h>
9#include <c++_tools/classifier/Target.h>
10
11#include <c++_tools/gslapi/matrix.h>
12
13namespace theplu {
14namespace classifier {
15
16  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
17                                   SubsetGenerator& subset) 
18    : mother_(sc), subset_(subset_)
19  {
20  }
21
22  EnsembleBuilder::~EnsembleBuilder(void) 
23  {
24    for(size_t i=0; i<classifier_.size(); i++)
25      delete classifier_[i];
26  }
27
28  void EnsembleBuilder::build(void) 
29  {
30    subset_.reset();
31    while(subset_.more()) {
32      SupervisedClassifier* classifier=
33        mother_.make_classifier(subset_);
34      classifier->train();
35      classifier_.push_back(classifier);
36      subset_.next();
37    }   
38  }
39
40  void  EnsembleBuilder::predict
41  (const DataLookup2D& data, 
42   std::vector<std::vector<statistics::Averager> >& result)
43  {
44    subset_.reset();
45
46    result.clear();
47    result.reserve(subset_.target().nof_classes());   
48    for(size_t i=0; i<subset_.target().nof_classes();i++)
49      result.push_back(std::vector<statistics::Averager>(data.columns()));
50   
51    size_t k=0;
52    gslapi::matrix prediction;   
53   
54
55    try {
56      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
57      while(subset_.more()) {
58        classifier(k++).predict(KernelLookup(kernel,
59                                             subset_.training_index(),
60                                             true),
61                                prediction);
62        for(size_t i=0; i<prediction.rows();i++) 
63          for(size_t j=0; j<prediction.columns();j++) 
64            result[i][j].add(prediction(i,j));
65        subset_.next();
66      }
67    }
68    catch (std::bad_cast) {
69      while(subset_.more()) {
70        classifier(k++).predict(data,prediction);
71        for(size_t i=0; i<prediction.rows();i++) 
72          for(size_t j=0; j<prediction.columns();j++) 
73            result[i][j].add(prediction(i,j));
74       
75        subset_.next();
76      }
77    }
78  }
79
80
81
82  const std::vector<std::vector<statistics::Averager> >& 
83  EnsembleBuilder::validate(void)
84  {
85    subset_.reset();
86    validation_result_.clear();
87
88    validation_result_.reserve(subset_.target().nof_classes());   
89    for(size_t i=0; i<subset_.target().nof_classes();i++)
90      validation_result_.push_back(std::vector<statistics::Averager>(subset_.target().size()));
91   
92    size_t k=0;
93    gslapi::matrix prediction;   
94    while(subset_.more()) {
95      classifier(k++).predict(subset_.validation_data(),prediction);
96
97      for(size_t i=0; i<prediction.rows();i++) 
98        for(size_t j=0; j<prediction.columns();j++) {
99          validation_result_[i][subset_.validation_index()[j]].
100            add(prediction(i,j));
101        } 
102         
103      subset_.next();
104    }
105    return validation_result_;
106  }
107
108}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.