source: trunk/lib/classifier/EnsembleBuilder.cc @ 559

Last change on this file since 559 was 559, checked in by Peter, 17 years ago

some changes in EB

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 3.0 KB
Line 
1// $Id: EnsembleBuilder.cc 559 2006-03-11 22:21:27Z peter $
2
3#include <c++_tools/classifier/EnsembleBuilder.h>
4
5#include <c++_tools/classifier/CrossSplitter.h>
6#include <c++_tools/classifier/DataLookup2D.h>
7#include <c++_tools/classifier/KernelLookup.h>
8#include <c++_tools/classifier/SupervisedClassifier.h>
9#include <c++_tools/classifier/Target.h>
10
11#include <c++_tools/gslapi/matrix.h>
12
13namespace theplu {
14namespace classifier {
15
16  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc, 
17                                   CrossSplitter& cs) 
18    : mother_(sc), cross_splitter_(cs)
19  {
20  }
21
22  EnsembleBuilder::~EnsembleBuilder(void) 
23  {
24    for(size_t i=0; i<classifier_.size(); i++)
25      delete classifier_[i];
26  }
27
28  void EnsembleBuilder::build(void) 
29  {
30    cross_splitter_.reset();
31    while(cross_splitter_.more()) {
32      const DataLookup2D& training=cross_splitter_.training_data();
33      const Target& targets=cross_splitter_.training_target();
34
35      SupervisedClassifier* classifier=
36        mother_.make_classifier(training,targets);
37      classifier->train();
38      classifier_.push_back(classifier);
39      cross_splitter_.next();
40    }   
41  }
42
43  void  EnsembleBuilder::predict
44  (const DataLookup2D& data, 
45   std::vector<std::vector<statistics::Averager> >& result)
46  {
47    cross_splitter_.reset();
48
49    result.clear();
50    result.reserve(cross_splitter_.target().nof_classes());   
51    for(size_t i=0; i<cross_splitter_.target().nof_classes();i++)
52      result.push_back(std::vector<statistics::Averager>(data.columns()));
53   
54    size_t k=0;
55    gslapi::matrix prediction;   
56   
57
58    try {
59      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
60      while(cross_splitter_.more()) {
61        classifier(k++).predict(KernelLookup(kernel,
62                                             cross_splitter_.training_index(),
63                                             true),
64                                prediction);
65        for(size_t i=0; i<prediction.rows();i++) 
66          for(size_t j=0; j<prediction.columns();j++) 
67            result[i][j].add(prediction(i,j));
68        cross_splitter_.next();
69      }
70    }
71    catch (std::bad_cast) {
72      while(cross_splitter_.more()) {
73        classifier(k++).predict(data,prediction);
74        for(size_t i=0; i<prediction.rows();i++) 
75          for(size_t j=0; j<prediction.columns();j++) 
76            result[i][j].add(prediction(i,j));
77       
78        cross_splitter_.next();
79      }
80    }
81  }
82
83
84
85  const std::vector<std::vector<statistics::Averager> >& 
86  EnsembleBuilder::validate(void)
87  {
88    cross_splitter_.reset();
89    validation_result_.clear();
90
91    validation_result_.reserve(cross_splitter_.target().nof_classes());   
92    for(size_t i=0; i<cross_splitter_.target().nof_classes();i++)
93      validation_result_.push_back(std::vector<statistics::Averager>(cross_splitter_.target().size()));
94   
95    size_t k=0;
96    gslapi::matrix prediction;   
97    while(cross_splitter_.more()) {
98      classifier(k++).predict(cross_splitter_.validation_data(),prediction);
99
100      for(size_t i=0; i<prediction.rows();i++) 
101        for(size_t j=0; j<prediction.columns();j++) {
102          validation_result_[i][cross_splitter_.validation_index()[j]].
103            add(prediction(i,j));
104        } 
105         
106      cross_splitter_.next();
107    }
108    return validation_result_;
109  }
110
111}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.