source: trunk/lib/classifier/NCC.cc @ 523

Last change on this file since 523 was 523, checked in by Peter, 17 years ago

add prediction functions to SVM

File size: 2.4 KB
Line 
1// $Id$
2
3#include <c++_tools/classifier/NCC.h>
4
5#include <c++_tools/classifier/DataLookup1D.h>
6#include <c++_tools/classifier/DataLookup2D.h>
7#include <c++_tools/classifier/Target.h>
8#include <c++_tools/gslapi/vector.h>
9#include <c++_tools/statistics/Distance.h>
10#include <c++_tools/utility/stl_utility.h>
11
12#include<iostream>
13#include<iterator>
14#include <map>
15#include <cmath>
16
17namespace theplu {
18namespace classifier {
19
20  NCC::NCC(const DataLookup2D& data, const Target& target, 
21           const statistics::Distance& distance) 
22    : SupervisedClassifier(target), distance_(distance), matrix_(data)
23  {   
24  }
25
26  SupervisedClassifier* 
27  NCC::make_classifier(const DataLookup2D& data, 
28                       const Target& target) const 
29  {     
30    NCC* sc= new NCC(data,target,this->distance_);
31    return sc;
32  }
33
34
35  bool NCC::train()
36  {
37    // Calculate the centroids for each class
38    centroids_=gslapi::matrix(matrix_.rows(),target_.nof_classes());
39    gslapi::matrix nof_in_class(matrix_.rows(),target_.nof_classes());
40    for(size_t i=0; i<matrix_.rows(); i++) {
41      for(size_t j=0; j<matrix_.columns(); j++) {
42        if(!std::isnan(matrix_(i,j))) {
43          centroids_(i,target_(j)) += matrix_(i,j);
44          nof_in_class(i,target_(j))++;
45        }
46      }
47    }
48    centroids_.div_elements(nof_in_class);
49    trained_=true;
50    return trained_;
51  }
52
53  void NCC::predict(const DataLookup1D& input, 
54                    gslapi::vector& prediction) const
55  {
56    prediction=gslapi::vector(centroids_.columns());   
57    gslapi::vector w(input.size(),0);
58    for(size_t i=0; i<input.size(); i++)  // take care of missing values
59      if(!std::isnan(input(i)))
60        w(i)=1.0;
61    for(size_t j=0; j<centroids_.columns(); j++) 
62      prediction(j)=distance_(gslapi::vector(input),
63                             gslapi::vector(centroids_,j,false),w, w);   
64  }
65
66
67  void NCC::predict(const DataLookup2D& input,                   
68                    gslapi::matrix& prediction) const
69  {
70    prediction=gslapi::matrix(centroids_.columns(), input.columns());   
71    for(size_t j=0; j<input.columns();j++) {     
72      DataLookup1D in(input,j,true);
73      gslapi::vector out;
74      predict(in,out);
75      prediction.set_column(j,out);
76    }
77  }
78
79 
80  // additional operators
81
82  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
83//    std::copy(ncc.classes().begin(), ncc.classes().end(),
84//              std::ostream_iterator<std::map<double, u_int>::value_type>
85//              (s, "\n"));
86    s << "\n" << ncc.centroids() << "\n";
87    return s;
88  }
89
90}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.