source: trunk/c++_tools/classifier/NCC.cc @ 593

Last change on this file since 593 was 593, checked in by Markus Ringnér, 15 years ago

Fixed std includes to compile with g++ 4.1.

File size: 4.9 KB
Line 
1// $Id$
2
3#include <c++_tools/classifier/NCC.h>
4
5#include <c++_tools/classifier/CrossSplitter.h>
6#include <c++_tools/classifier/DataLookup1D.h>
7#include <c++_tools/classifier/DataLookup2D.h>
8#include <c++_tools/classifier/MatrixLookup.h>
9#include <c++_tools/classifier/InputRanker.h>
10#include <c++_tools/classifier/Target.h>
11#include <c++_tools/gslapi/vector.h>
12#include <c++_tools/statistics/Distance.h>
13#include <c++_tools/utility/stl_utility.h>
14
15#include<iostream>
16#include<iterator>
17#include <map>
18#include <cmath>
19
20namespace theplu {
21namespace classifier {
22
23  NCC::NCC(const MatrixLookup& data, const Target& target, 
24           const statistics::Distance& distance) 
25    : SupervisedClassifier(target), distance_(distance), matrix_(data),
26      weighted_(false),
27      weights_(new MatrixLookup(data.rows(),data.columns(),1.0))
28  {
29  }
30
31  NCC::NCC(const MatrixLookup& data, const Target& target, 
32           const statistics::Distance& distance, const MatrixLookup& weights) 
33    : SupervisedClassifier(target), distance_(distance), matrix_(data),
34      weighted_(true),weights_(&weights)
35  {
36  }
37
38
39  NCC::NCC(const MatrixLookup& data, const Target& target, 
40           const statistics::Distance& distance, 
41           statistics::Score& score, size_t nof_inputs) 
42    : SupervisedClassifier(target, &score, nof_inputs), 
43      distance_(distance), matrix_(data),weighted_(false),
44      weights_(new MatrixLookup(data.rows(),data.columns(),1.0))
45  {
46  }
47
48  NCC::NCC(const MatrixLookup& data, const Target& target, 
49           const statistics::Distance& distance,
50           const MatrixLookup& weights,
51           statistics::Score& score, size_t nof_inputs) 
52    : SupervisedClassifier(target, &score, nof_inputs), 
53      distance_(distance), matrix_(data),weighted_(true),
54      weights_(&weights)
55  {
56  }
57
58
59  NCC::~NCC()   
60  {
61    if(!weighted_) 
62      if(weights_)
63        delete weights_;
64      else 
65        std::cerr << "Error in NCC implementation: probably a constructor"
66                  << " should be debugged to make sure a unity weight matrix is"
67                  << " dynamically allocated for all unweighted cases" 
68                  << std::endl;
69  }
70
71
72  SupervisedClassifier* 
73  NCC::make_classifier(const CrossSplitter& cs) const 
74  {     
75    const MatrixLookup& training_data = 
76      dynamic_cast<const MatrixLookup&>(cs.training_data());
77    NCC* ncc=0;
78    if(cs.weighted()) {
79      ncc= new NCC(training_data,cs.training_target(),this->distance_,
80                        cs.training_weight());
81    }
82    else {
83      ncc= new NCC(training_data,cs.training_target(),this->distance_);
84    }
85    ncc->score_=this->score_;
86    ncc->nof_inputs_=this->nof_inputs_;
87    return ncc;
88  }
89
90
91  bool NCC::train()
92  {
93    // If score is set calculate centroids only for nof_inputs_ number
94    // of top ranked inputs. Otherwise calculate centroids based on
95    // all inputs ( = all rows in data matrix).
96    if(ranker_)
97      delete ranker_;
98    size_t rows=matrix_.rows();
99    if(score_) {
100      ranker_=new InputRanker(matrix_, target_, *score_, *weights_);
101      rows=nof_inputs_;
102    }
103    centroids_=gslapi::matrix(rows, target_.nof_classes());
104    gslapi::matrix nof_in_class(rows, target_.nof_classes());
105    for(size_t i=0; i<rows; i++) {
106      for(size_t j=0; j<matrix_.columns(); j++) {
107        double value=matrix_(i,j);
108        double weight=(*weights_)(i,j);
109        if(score_) {
110          value=matrix_(ranker_->id(i),j);
111          weight=(*weights_)(ranker_->id(i),j);
112        }
113        if(weight) {
114          centroids_(i,target_(j)) += value*weight;
115          nof_in_class(i,target_(j))+=weight;
116        }
117      }
118    }
119    centroids_.div_elements(nof_in_class);
120    trained_=true;
121    return trained_;
122  }
123
124
125  void NCC::predict(const DataLookup1D& input, 
126                    gslapi::vector& prediction) const
127  {
128    prediction=gslapi::vector(centroids_.columns());   
129    size_t size=input.size();
130    if(ranker_)
131      size=nof_inputs_;
132    gslapi::vector w(size,0);
133    gslapi::vector value(size,0);
134    for(size_t i=0; i<size; i++)  { // take care of missing values
135      value(i)=input(i);
136      if(ranker_)
137        value(i)=input(ranker_->id(i));
138      if(!std::isnan(value(i)))
139        w(i)=1.0;
140    }
141    for(size_t j=0; j<centroids_.columns(); j++) {
142      gslapi::vector centroid=gslapi::vector(centroids_,j,false);
143      gslapi::vector wc(centroid.size(),0);
144      for(size_t i=0; i<centroid.size(); i++)  { // take care of missing values
145        if(!std::isnan(centroid(i)))
146          wc(i)=1.0;
147      }
148      prediction(j)=distance_(value,centroid,w,wc);   
149    }
150  }
151
152
153  void NCC::predict(const DataLookup2D& input,                   
154                    gslapi::matrix& prediction) const
155  {
156    prediction=gslapi::matrix(centroids_.columns(), input.columns());   
157    for(size_t j=0; j<input.columns();j++) {     
158      DataLookup1D in(input,j,false);
159      gslapi::vector out;
160      predict(in,out);
161      prediction.set_column(j,out);
162    }
163  }
164
165 
166  // additional operators
167
168//  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
169//    std::copy(ncc.classes().begin(), ncc.classes().end(),
170//              std::ostream_iterator<std::map<double, u_int>::value_type>
171//              (s, "\n"));
172//    s << "\n" << ncc.centroids() << "\n";
173//    return s;
174//  }
175
176}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.