source: trunk/yat/classifier/NCC.cc @ 874

Last change on this file since 874 was 874, checked in by Markus Ringnér, 14 years ago

Fixes ticket:237

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 4.9 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005 Markus Ringnér, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "NCC.h"
27#include "DataLookup1D.h"
28#include "DataLookup2D.h"
29#include "MatrixLookup.h"
30#include "MatrixLookupWeighted.h"
31#include "Target.h"
32#include "yat/utility/matrix.h"
33#include "yat/utility/vector.h"
34#include "yat/statistics/Distance.h"
35#include "yat/utility/stl_utility.h"
36
#include <cmath>
#include <iostream>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>
41
42namespace theplu {
43namespace yat {
44namespace classifier {
45
46  NCC::NCC(const MatrixLookup& data, const Target& target, 
47           const statistics::Distance& distance) 
48    : SupervisedClassifier(target), distance_(distance), data_(data)
49  {
50  }
51
52  NCC::NCC(const MatrixLookupWeighted& data, const Target& target, 
53           const statistics::Distance& distance) 
54    : SupervisedClassifier(target), distance_(distance), data_(data)
55  {
56  }
57
58  NCC::~NCC()   
59  {
60  }
61
62
63  const utility::matrix& NCC::centroids(void) const
64  {
65    return centroids_;
66  }
67
68    const DataLookup2D& NCC::data(void) const
69    {
70    return data_;
71    }
72
73  SupervisedClassifier* 
74  NCC::make_classifier(const DataLookup2D& data, const Target& target) const 
75  {     
76    NCC* ncc=0;
77    if(data.weighted()) {
78      ncc=new NCC(dynamic_cast<const MatrixLookupWeighted&>(data),
79                  target,this->distance_);
80    }
81    else {
82      ncc=new NCC(dynamic_cast<const MatrixLookup&>(data),
83                  target,this->distance_);
84    }
85    return ncc;
86  }
87
88
89  bool NCC::train()
90  {   
91    centroids_.clone(utility::matrix(data_.rows(), target_.nof_classes()));
92    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
93    const MatrixLookupWeighted* weighted_data = 
94      dynamic_cast<const MatrixLookupWeighted*>(&data_);
95    bool weighted = weighted_data;
96
97    for(size_t i=0; i<data_.rows(); i++) {
98      for(size_t j=0; j<data_.columns(); j++) {
99        centroids_(i,target_(j)) += data_(i,j);
100        if (weighted)
101          nof_in_class(i,target_(j))+= weighted_data->weight(i,j);
102        else
103          nof_in_class(i,target_(j))+=1.0;
104      }
105    }   
106    centroids_.div(nof_in_class);
107    trained_=true;
108    return trained_;
109  }
110
111
112  void NCC::predict(const utility::vector& input, const utility::vector& weights,
113                    utility::vector& prediction) const
114  {
115    prediction.clone(utility::vector(centroids_.columns()));
116   
117    // take care of nan's in centroids
118    for(size_t j=0; j<centroids_.columns(); j++) {
119      const utility::vector centroid(utility::vector(centroids_,j,false));
120      utility::vector wc(centroid.size(),0);
121      for(size_t i=0; i<centroid.size(); i++)  { 
122        if(!std::isnan(centroid(i)))
123          wc(i)=1.0;
124      }
125      prediction(j)=distance_(input,centroid,weights,wc);   
126    }
127  }
128
129
130  void NCC::predict(const DataLookup2D& input,                   
131                    utility::matrix& prediction) const
132  {   
133    prediction.clone(utility::matrix(centroids_.columns(), input.columns()));
134    // weighted case
135    const MatrixLookupWeighted* data =
136      dynamic_cast<const MatrixLookupWeighted*>(&input); 
137    if (data) {
138      for(size_t j=0; j<input.columns();j++) {     
139        utility::vector in(input.rows(),0);
140        for(size_t i=0; i<in.size();i++) 
141          in(i)=data->data(i,j);
142        utility::vector weights(in.size(),0);
143        for(size_t i=0; i<in.size();i++) 
144          weights(i)=data->weight(i,j);
145        utility::vector out;
146        predict(in,weights,out);
147        prediction.column(j,out);
148      }
149      return;
150    }
151    // non-weighted case
152    const MatrixLookup* x = dynamic_cast<const MatrixLookup*>(&input);
153    if (!x){
154      std::string str;
155      str = "Error in NCC::predict: DataLookup2D of unexpected class.";
156      throw std::runtime_error(str);
157    }
158    for(size_t j=0; j<input.columns();j++) {     
159      utility::vector in(input.rows(),0);
160      for(size_t i=0; i<in.size();i++) 
161        in(i)=(*data)(i,j);
162      utility::vector weights(in.size(),1.0);
163      utility::vector out;
164      predict(in,weights,out);
165      prediction.column(j,out);
166    }
167  }
168
169 
170  // additional operators
171
172//  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
173//    std::copy(ncc.classes().begin(), ncc.classes().end(),
174//              std::ostream_iterator<std::map<double, u_int>::value_type>
175//              (s, "\n"));
176//    s << "\n" << ncc.centroids() << "\n";
177//    return s;
178//  }
179
180}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.