source: trunk/c++_tools/classifier/NBC.cc @ 662

Last change on this file since 662 was 662, checked in by Peter, 17 years ago

refs #57 added class for Naive Baysian Classification. Predict function not yet implemented though.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 1.9 KB
Line 
1// $Id: NBC.cc 662 2006-09-27 10:16:17Z peter $
2
3#include <c++_tools/classifier/NBC.h>
4
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/MatrixLookup.h>
7#include <c++_tools/classifier/MatrixLookupWeighted.h>
8#include <c++_tools/classifier/Target.h>
9#include <c++_tools/statistics/AveragerWeighted.h>
10#include <c++_tools/utility/matrix.h>
11
12#include <vector>
13
14namespace theplu {
15namespace classifier {
16
17  NBC::NBC(const MatrixLookup& data, const Target& target) 
18    : SupervisedClassifier(target), data_(data)
19  {
20  }
21
22  NBC::NBC(const MatrixLookupWeighted& data, const Target& target) 
23    : SupervisedClassifier(target), data_(data)
24  {
25  }
26
27  NBC::~NBC()   
28  {
29  }
30
31
32  SupervisedClassifier* 
33  NBC::make_classifier(const DataLookup2D& data, const Target& target) const 
34  {     
35    NBC* ncc=0;
36    if(data.weighted()) {
37      ncc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target);
38    }
39    else {
40      ncc=new NBC(dynamic_cast<const MatrixLookup&>(data),target);
41    }
42    return ncc;
43  }
44
45
46  bool NBC::train()
47  {   
48    sigma_=centroids_=utility::matrix(data_.rows(), target_.nof_classes());
49    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
50   
51   
52    for(size_t i=0; i<data_.rows(); ++i) {
53      std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
54      for(size_t j=0; j<data_.columns(); ++j) {
55        if (data_.weighted()){
56          const MatrixLookupWeighted& data = 
57            dynamic_cast<const MatrixLookupWeighted&>(data_);
58            aver[target_(j)].add(data.data(i,j), data.weight(i,j));
59        }
60        else
61          aver[target_(j)].add(data_(i,j),1.0);
62      }
63      for (size_t j=0; target_.nof_classes(); ++j){
64        centroids_(i,j) = aver[j].mean();
65        sigma_(i,j) = aver[j].variance();
66      }
67    }   
68    trained_=true;
69    return trained_;
70  }
71
72
73  void NBC::predict(const DataLookup2D& input,                   
74                    utility::matrix& prediction) const
75  {   
76    std::cerr << "NBC::predict not implemented\n";
77    exit(1);
78  }
79
80
81}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.