source: branches/peters_vector/lib/classifier/NCC.cc @ 469

Last change on this file since 469 was 469, checked in by Peter, 16 years ago

non compiling checking before revision after design meeting

File size: 2.2 KB
Line 
1// $Id$
2
3#include <c++_tools/classifier/NCC.h>
4
5#include <c++_tools/classifier/DataView.h>
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/gslapi/vector.h>
8#include <c++_tools/statistics/Averager.h>
9#include <c++_tools/statistics/AveragerPairWeighted.h>
10#include <c++_tools/utility/stl_utility.h>
11
12#include<iostream>
13#include<iterator>
14#include <map>
15#include <cmath>
16
17namespace theplu {
18namespace classifier {
19
20  NCC::NCC(const DataView& data, const Target& target) 
21  {
22    Target sorted_target(target);
23    sorted_target.sort(); // if targets contain NaN => infinite loop
24   
25    // Find the classes of targets
26    u_int nof_classes=0;
27    for (size_t i=0; i<sorted_target.size(); i++) {
28      std::pair<const double, u_int> p(sorted_target(i),nof_classes);
29      std::pair<std::map<double,u_int>::iterator,bool> 
30        status=classes_.insert(p);
31      if(status.second==true) {
32        nof_classes++;
33      }
34    }
35   
36    // Calculate the centroids for each class
37    centroids_=gslapi::matrix(data.rows(),classes_.size());
38    gslapi::matrix nof_in_class(data.rows(),classes_.size());
39    for(size_t i=0; i<data.rows(); i++) {
40      for(size_t j=0; j<data.columns(); j++) {
41        if(!std::isnan(data(i,j))) {
42          centroids_(i,classes_[target(j)]) += data(i,j);
43          nof_in_class(i,classes_[target(j)])++;
44        }
45      }
46    }
47    centroids_.div_elements(nof_in_class);
48  }
49
50  void NCC::predict(const gslapi::vector& input, 
51                    statistics::Score& score, 
52                    gslapi::vector& prediction) 
53  {
54    prediction=gslapi::vector(classes_.size());   
55    gslapi::vector w(input.size(),0);
56    for(size_t i=0; i<input.size(); i++)  // take care of missing values
57      if(!std::isnan(input(i)))
58        w(i)=1.0;
59    statistics::AveragerPairWeighted ap;
60    for(size_t j=0; j<centroids_.columns(); j++) {
61      ap.add(input,gslapi::vector(centroids_,j,false),w, w);
62      prediction(j)=ap.correlation();
63      ap.reset();
64    }
65  }
66
67 
68  // additional operators
69
70  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
71    std::copy(ncc.classes().begin(), ncc.classes().end(), 
72              std::ostream_iterator<std::map<double, u_int>::value_type>
73              (s, "\n"));
74    s << "\n" << ncc.centroids() << "\n";
75    return s;
76  }
77
78}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.