source: trunk/lib/classifier/NCC.cc @ 454

Last change on this file since 454 was 454, checked in by Markus Ringnér, 16 years ago

Added class for Nearest Centroid Classifier

File size: 2.0 KB
Line 
1// $Id$
2
3#include <c++_tools/classifier/NCC.h>
4
5#include <c++_tools/gslapi/matrix.h>
6#include <c++_tools/gslapi/vector.h>
7#include <c++_tools/statistics/Averager.h>
8#include <c++_tools/utility/stl_utility.h>
9
10#include<iostream>
11#include<iterator>
12#include <map>
13#include <cmath>
14
15namespace theplu {
16namespace classifier {
17
18  NCC::NCC(const gslapi::matrix& data, const gslapi::vector& target) 
19  {
20    gslapi::vector sorted_target(target);
21    sorted_target.sort(); // if targets contain NaN => infinite loop
22   
23    // Find the classes of targets
24    u_int nof_classes=0;
25    for (size_t i=0; i<sorted_target.size(); i++) {
26      std::pair<const double, u_int> p(sorted_target(i),nof_classes);
27      std::pair<std::map<double,u_int>::iterator,bool> 
28        status=classes_.insert(p);
29      if(status.second==true) {
30        nof_classes++;
31      }
32    }
33   
34    // Calculate the centroids for each class
35    centroids_=gslapi::matrix(data.rows(),classes_.size());
36    gslapi::matrix nof_in_class(data.rows(),classes_.size());
37    for(size_t i=0; i<data.rows(); i++) {
38      for(size_t j=0; j<data.columns(); j++) {
39        if(!std::isnan(data(i,j))) {
40          centroids_(i,classes_[target(j)]) += data(i,j);
41          nof_in_class(i,classes_[target(j)])++;
42        }
43      }
44    }
45    centroids_.div_elements(nof_in_class);
46  }
47
48  void NCC::predict(const gslapi::vector& input, 
49                    statistics::Score& score, 
50                    gslapi::vector& prediction) 
51  {
52    prediction=gslapi::vector(classes_.size());   
53    std::vector<size_t> use;
54    for(size_t i=0; i<input.size(); i++)  // take care of missing values
55      if(!std::isnan(input(i)))
56        use.push_back(i);
57    for(size_t j=0; j<centroids_.columns(); j++) 
58      prediction(j)=score.score(input,gslapi::vector(centroids_,j,false),use);
59  }
60
61 
62  // additional operators
63
64  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
65    std::copy(ncc.classes().begin(), ncc.classes().end(), 
66              std::ostream_iterator<std::map<double, u_int>::value_type>
67              (s, "\n"));
68    s << "\n" << ncc.centroids() << "\n";
69    return s;
70  }
71
72}} // of namespace svm and namespace theplu
Note: See TracBrowser for help on using the repository browser.