source: trunk/lib/classifier/EnsembleBuilder.cc @ 509

Last change on this file since 509 was 509, checked in by Peter, 17 years ago

Added a test for Target.
Redesigned CrossSplitter.
Added a two-class function to Target.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 1.8 KB
Line 
1// $Id: EnsembleBuilder.cc 509 2006-02-18 13:47:32Z peter $
2
3#include <c++_tools/classifier/EnsembleBuilder.h>
4
5#include <c++_tools/classifier/CrossSplitter.h>
6#include <c++_tools/classifier/DataLookup2D.h>
7#include <c++_tools/classifier/SupervisedClassifier.h>
8#include <c++_tools/classifier/Target.h>
9
10#include <c++_tools/gslapi/matrix.h>
11
12namespace theplu {
13namespace classifier {
14
15  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
16                                   CrossSplitter& cs) 
17    : mother_(sc), cross_splitter_(cs)
18  {
19  }
20
21  EnsembleBuilder::~EnsembleBuilder(void) 
22  {
23    for(size_t i=0; i<classifier_.size(); i++)
24      delete classifier_[i];
25  }
26
27  void EnsembleBuilder::build(void) 
28  {
29    cross_splitter_.reset();
30    while(cross_splitter_.more()) {
31      const DataLookup2D& training=cross_splitter_.training_data();
32      const Target& targets=cross_splitter_.training_target();
33      SupervisedClassifier* classifier=
34        mother_.make_classifier(training,targets);
35      classifier->train();
36      classifier_.push_back(classifier);
37      cross_splitter_.next();
38    }   
39  }
40
41  const std::vector<std::vector<statistics::Averager> >& 
42  EnsembleBuilder::validate(void)
43  {
44    cross_splitter_.reset();
45    validation_result_.clear();
46
47    validation_result_.reserve(cross_splitter_.target().nof_classes());   
48    for(size_t i=0; i<cross_splitter_.target().nof_classes();i++)
49      validation_result_.push_back(std::vector<statistics::Averager>(cross_splitter_.target().size()));
50   
51    size_t k=0;
52    gslapi::matrix prediction;   
53    while(cross_splitter_.more()) {
54      classifier(k++).predict(cross_splitter_.validation_data(),prediction);
55
56      for(size_t i=0; i<prediction.rows();i++) 
57        for(size_t j=0; j<prediction.columns();j++) {
58          validation_result_[i][cross_splitter_.validation_index()[j]].
59            add(prediction(i,j));
60        } 
61         
62      cross_splitter_.next();
63    }
64    return validation_result_;
65  }
66
67}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.