source: trunk/yat/classifier/NCC.cc @ 857

Last change on this file since 857 was 857, checked in by Peter, 16 years ago

refs #148 in NCC

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 4.9 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005 Markus Ringnér, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen
7
8  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
#include "NCC.h"
#include "DataLookup1D.h"
#include "DataLookup2D.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "Target.h"
#include "yat/utility/matrix.h"
#include "yat/utility/vector.h"
#include "yat/statistics/Distance.h"
#include "yat/utility/stl_utility.h"

#include <cmath>
#include <iostream>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>
41
42namespace theplu {
43namespace yat {
44namespace classifier {
45
46  NCC::NCC(const MatrixLookup& data, const Target& target, 
47           const statistics::Distance& distance) 
48    : SupervisedClassifier(target), distance_(distance), data_(data)
49  {
50  }
51
52  NCC::NCC(const MatrixLookupWeighted& data, const Target& target, 
53           const statistics::Distance& distance) 
54    : SupervisedClassifier(target), distance_(distance), data_(data)
55  {
56  }
57
58  NCC::~NCC()   
59  {
60  }
61
62
63  const utility::matrix& NCC::centroids(void) const
64  {
65    return centroids_;
66  }
67
68    const DataLookup2D& NCC::data(void) const
69    {
70    return data_;
71    }
72
73  SupervisedClassifier* 
74  NCC::make_classifier(const DataLookup2D& data, const Target& target) const 
75  {     
76    NCC* ncc=0;
77    if(data.weighted()) {
78      ncc=new NCC(dynamic_cast<const MatrixLookupWeighted&>(data),
79                  target,this->distance_);
80    }
81    else {
82      ncc=new NCC(dynamic_cast<const MatrixLookup&>(data),
83                  target,this->distance_);
84    }
85    return ncc;
86  }
87
88
89  bool NCC::train()
90  {   
91    centroids_.clone(utility::matrix(data_.rows(), target_.nof_classes()));
92    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
93    const MatrixLookupWeighted* weighted_data = 
94      dynamic_cast<const MatrixLookupWeighted*>(&data_);
95    bool weighted = weighted_data;
96
97    for(size_t i=0; i<data_.rows(); i++) {
98      for(size_t j=0; j<data_.columns(); j++) {
99        centroids_(i,target_(j)) += data_(i,j);
100        if (weighted)
101          nof_in_class(i,target_(j))+= weighted_data->weight(i,j);
102        else
103          nof_in_class(i,target_(j))+=1.0;
104      }
105    }   
106    centroids_.div(nof_in_class);
107    trained_=true;
108    return trained_;
109  }
110
111
112  void NCC::predict(const DataLookup1D& input, const utility::vector& weights,
113                    utility::vector& prediction) const
114  {
115    prediction.clone(utility::vector(centroids_.columns()));
116
117    utility::vector value(input.size(),0);
118    for(size_t i=0; i<input.size(); i++)
119      value(i)=input(i);
120   
121    // take care of nan's in centroids
122    for(size_t j=0; j<centroids_.columns(); j++) {
123      const utility::vector centroid(utility::vector(centroids_,j,false));
124      utility::vector wc(centroid.size(),0);
125      for(size_t i=0; i<centroid.size(); i++)  { 
126        if(!std::isnan(centroid(i)))
127          wc(i)=1.0;
128      }
129      prediction(j)=distance_(value,centroid,weights,wc);   
130    }
131  }
132
133
134  void NCC::predict(const DataLookup2D& input,                   
135                    utility::matrix& prediction) const
136  {   
137    prediction.clone(utility::matrix(centroids_.columns(), input.columns()));
138    // weighted case
139    const MatrixLookupWeighted* data =
140      dynamic_cast<const MatrixLookupWeighted*>(&input); 
141    if (data) {
142      for(size_t j=0; j<input.columns();j++) {     
143        DataLookup1D in(input,j,false);
144        utility::vector weights(in.size(),0);
145        for(size_t i=0; i<in.size();i++) 
146          weights(i)=data->weight(i,j);
147        utility::vector out;
148        predict(in,weights,out);
149        prediction.column(j,out);
150      }
151      return;
152    }
153    // non-weighted case
154    const MatrixLookup* x = dynamic_cast<const MatrixLookup*>(&input);
155    if (!x){
156      std::string str;
157      str = "Error in NCC::predict: DataLookup2D of unexpected class.";
158      throw std::runtime_error(str);
159    }
160    for(size_t j=0; j<input.columns();j++) {     
161      DataLookup1D in(input,j,false);
162      utility::vector weights(in.size(),1.0);
163      utility::vector out;
164      predict(in,weights,out);
165      prediction.column(j,out);
166    }
167  }
168
169 
170  // additional operators
171
172//  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
173//    std::copy(ncc.classes().begin(), ncc.classes().end(),
174//              std::ostream_iterator<std::map<double, u_int>::value_type>
175//              (s, "\n"));
176//    s << "\n" << ncc.centroids() << "\n";
177//    return s;
178//  }
179
180}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.