source: trunk/lib/classifier/EnsembleBuilder.cc @ 531

Last change on this file since 531 was 531, checked in by Markus Ringnér, 17 years ago

Note: there are some problems when creating MatrixLookups from MatrixLookups and index vectors. We are halfway to finding the bugs. This means that CrossSplitter, and layers of CrossSplitters, do not work properly.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 2.5 KB
Line 
1// $Id: EnsembleBuilder.cc 531 2006-03-02 17:35:12Z markus $
2
3#include <c++_tools/classifier/EnsembleBuilder.h>
4
5#include <c++_tools/classifier/CrossSplitter.h>
6#include <c++_tools/classifier/DataLookup2D.h>
7#include <c++_tools/classifier/SupervisedClassifier.h>
8#include <c++_tools/classifier/Target.h>
9
10#include <c++_tools/gslapi/matrix.h>
11
12namespace theplu {
13namespace classifier {
14
15  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
16                                   CrossSplitter& cs) 
17    : mother_(sc), cross_splitter_(cs)
18  {
19  }
20
21  EnsembleBuilder::~EnsembleBuilder(void) 
22  {
23    for(size_t i=0; i<classifier_.size(); i++)
24      delete classifier_[i];
25  }
26
27  void EnsembleBuilder::build(void) 
28  {
29    cross_splitter_.reset();
30    while(cross_splitter_.more()) {
31      const DataLookup2D& training=cross_splitter_.training_data();
32      const Target& targets=cross_splitter_.training_target();
33      SupervisedClassifier* classifier=
34        mother_.make_classifier(training,targets);
35      classifier->train();
36      classifier_.push_back(classifier);
37      cross_splitter_.next();
38    }   
39  }
40
41  void  EnsembleBuilder::predict
42  (const DataLookup2D& data, 
43   std::vector<std::vector<statistics::Averager> >& result)
44  {
45    cross_splitter_.reset();
46
47    result.clear();
48    result.reserve(cross_splitter_.target().nof_classes());   
49    for(size_t i=0; i<cross_splitter_.target().nof_classes();i++)
50      result.push_back(std::vector<statistics::Averager>(data.columns()));
51   
52    size_t k=0;
53    gslapi::matrix prediction;   
54    while(cross_splitter_.more()) {
55      classifier(k++).predict(data,prediction);
56
57      for(size_t i=0; i<prediction.rows();i++) 
58        for(size_t j=0; j<prediction.columns();j++) 
59          result[i][j].add(prediction(i,j));
60
61      cross_splitter_.next();
62    }
63
64  }
65
66
67
68  const std::vector<std::vector<statistics::Averager> >& 
69  EnsembleBuilder::validate(void)
70  {
71    cross_splitter_.reset();
72    validation_result_.clear();
73
74    validation_result_.reserve(cross_splitter_.target().nof_classes());   
75    for(size_t i=0; i<cross_splitter_.target().nof_classes();i++)
76      validation_result_.push_back(std::vector<statistics::Averager>(cross_splitter_.target().size()));
77   
78    size_t k=0;
79    gslapi::matrix prediction;   
80    while(cross_splitter_.more()) {
81      classifier(k++).predict(cross_splitter_.validation_data(),prediction);
82
83      for(size_t i=0; i<prediction.rows();i++) 
84        for(size_t j=0; j<prediction.columns();j++) {
85          validation_result_[i][cross_splitter_.validation_index()[j]].
86            add(prediction(i,j));
87        } 
88         
89      cross_splitter_.next();
90    }
91    return validation_result_;
92  }
93
94}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.