source: trunk/c++_tools/classifier/EnsembleBuilder.cc @ 608

Last change on this file since 608 was 608, checked in by Peter, 15 years ago

set properties

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 2.9 KB
Line 
1// $Id$
2
3#include <c++_tools/classifier/EnsembleBuilder.h>
4
5#include <c++_tools/classifier/CrossSplitter.h>
6#include <c++_tools/classifier/DataLookup2D.h>
7#include <c++_tools/classifier/KernelLookup.h>
8#include <c++_tools/classifier/SupervisedClassifier.h>
9#include <c++_tools/classifier/Target.h>
10
11#include <c++_tools/gslapi/matrix.h>
12
13namespace theplu {
14namespace classifier {
15
16  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
17                                   CrossSplitter& cs) 
18    : mother_(sc), cross_splitter_(cs)
19  {
20  }
21
22  EnsembleBuilder::~EnsembleBuilder(void) 
23  {
24    for(size_t i=0; i<classifier_.size(); i++)
25      delete classifier_[i];
26  }
27
28  void EnsembleBuilder::build(void) 
29  {
30    cross_splitter_.reset();
31    while(cross_splitter_.more()) {
32      SupervisedClassifier* classifier=
33        mother_.make_classifier(cross_splitter_);
34      classifier->train();
35      classifier_.push_back(classifier);
36      cross_splitter_.next();
37    }   
38  }
39
40  void  EnsembleBuilder::predict
41  (const DataLookup2D& data, 
42   std::vector<std::vector<statistics::Averager> >& result)
43  {
44    cross_splitter_.reset();
45
46    result.clear();
47    result.reserve(cross_splitter_.target().nof_classes());   
48    for(size_t i=0; i<cross_splitter_.target().nof_classes();i++)
49      result.push_back(std::vector<statistics::Averager>(data.columns()));
50   
51    size_t k=0;
52    gslapi::matrix prediction;   
53   
54
55    try {
56      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
57      while(cross_splitter_.more()) {
58        classifier(k++).predict(KernelLookup(kernel,
59                                             cross_splitter_.training_index(),
60                                             true),
61                                prediction);
62        for(size_t i=0; i<prediction.rows();i++) 
63          for(size_t j=0; j<prediction.columns();j++) 
64            result[i][j].add(prediction(i,j));
65        cross_splitter_.next();
66      }
67    }
68    catch (std::bad_cast) {
69      while(cross_splitter_.more()) {
70        classifier(k++).predict(data,prediction);
71        for(size_t i=0; i<prediction.rows();i++) 
72          for(size_t j=0; j<prediction.columns();j++) 
73            result[i][j].add(prediction(i,j));
74       
75        cross_splitter_.next();
76      }
77    }
78  }
79
80
81
82  const std::vector<std::vector<statistics::Averager> >& 
83  EnsembleBuilder::validate(void)
84  {
85    cross_splitter_.reset();
86    validation_result_.clear();
87
88    validation_result_.reserve(cross_splitter_.target().nof_classes());   
89    for(size_t i=0; i<cross_splitter_.target().nof_classes();i++)
90      validation_result_.push_back(std::vector<statistics::Averager>(cross_splitter_.target().size()));
91   
92    size_t k=0;
93    gslapi::matrix prediction;   
94    while(cross_splitter_.more()) {
95      classifier(k++).predict(cross_splitter_.validation_data(),prediction);
96
97      for(size_t i=0; i<prediction.rows();i++) 
98        for(size_t j=0; j<prediction.columns();j++) {
99          validation_result_[i][cross_splitter_.validation_index()[j]].
100            add(prediction(i,j));
101        } 
102         
103      cross_splitter_.next();
104    }
105    return validation_result_;
106  }
107
108}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.