source: trunk/yat/classifier/NCC.h @ 1173

Last change on this file since 1173 was 1173, checked in by Markus Ringnér, 14 years ago

Minor fixes

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 6.8 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér, Peter Johansson
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"

#include "yat/statistics/Averager.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/Vector.h"
#include "yat/utility/stl_utility.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <iostream>
#include <iterator>
#include <map>
#include <stdexcept>
#include <vector>
46
47namespace theplu {
48namespace yat {
49namespace classifier { 
50
51
52  ///
53  /// @brief Class for Nearest Centroid Classification.
54  ///
55  /// The template argument Distance should be a class modelling
56  /// the concept \ref concept_distance.
57  ///
58  template <typename Distance>
59  class NCC : public SupervisedClassifier
60  {
61 
62  public:
63    ///
64    /// @brief Constructor
65    ///
66    NCC(void);
67   
68    ///
69    /// @brief Constructor
70    ///
71    NCC(const Distance&);
72
73
74    ///
75    /// @brief Destructor
76    ///
77    virtual ~NCC(void);
78
79    ///
80    /// @return the centroids for each class as columns in a matrix.
81    ///
82    const utility::Matrix& centroids(void) const;
83
84    NCC<Distance>* make_classifier(void) const;
85   
86    ///
87    /// Train the classifier with a training data set and
88    /// targets. Centroids are calculated for each class.
89    ///
90    void train(const MatrixLookup&, const Target&);
91
92
93    ///
94    /// Train the classifier with a weighted training data set and
95    /// targets. Centroids are calculated for each class.
96    ///
97    void train(const MatrixLookupWeighted&, const Target&);
98
99   
100    ///
101    /// Calculate the distance to each centroid for test samples
102    ///
103    void predict(const MatrixLookup&, utility::Matrix&) const;
104   
105    ///
106    /// Calculate the distance to each centroid for weighted test samples
107    ///
108    void predict(const MatrixLookupWeighted&, utility::Matrix&) const;
109
110   
111  private:
112
113    void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
114    void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;   
115
116    utility::Matrix centroids_;
117    bool centroids_nan_;
118    Distance distance_;
119  };
120
  // NOTE: no output operator is provided for NCC; the declaration
  // below is intentionally disabled.
  //  std::ostream& operator<< (std::ostream&, const NCC&);
125 
126
127  // templates
128
129  template <typename Distance>
130  NCC<Distance>::NCC() 
131    : SupervisedClassifier(), centroids_nan_(false)
132  {
133  }
134
135  template <typename Distance>
136  NCC<Distance>::NCC(const Distance& dist) 
137    : SupervisedClassifier(), centroids_nan_(false), distance_(dist)
138  {
139  }
140
141
142  template <typename Distance>
143  NCC<Distance>::~NCC()   
144  {
145  }
146
147
148  template <typename Distance>
149  const utility::Matrix& NCC<Distance>::centroids(void) const
150  {
151    return centroids_;
152  }
153 
154
155  template <typename Distance>
156  NCC<Distance>* 
157  NCC<Distance>::make_classifier() const 
158  {     
159    // All private members should be copied here to generate an
160    // identical but untrained classifier
161    return new NCC<Distance>(distance_);
162  }
163
164  template <typename Distance>
165  void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
166  {   
167    centroids_.resize(data.rows(), target.nof_classes());
168    for(size_t i=0; i<data.rows(); i++) {
169      std::vector<statistics::Averager> class_averager;
170      class_averager.resize(target.nof_classes());
171      for(size_t j=0; j<data.columns(); j++) {
172        class_averager[target(j)].add(data(i,j));
173      }
174      for(size_t c=0;c<target.nof_classes();c++) {
175        centroids_(i,c) = class_averager[c].mean();
176      }
177    }
178  }
179
180
181  template <typename Distance>
182  void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
183  {   
184    centroids_.resize(data.rows(), target.nof_classes());
185    for(size_t i=0; i<data.rows(); i++) {
186      std::vector<statistics::AveragerWeighted> class_averager;
187      class_averager.resize(target.nof_classes());
188      for(size_t j=0; j<data.columns(); j++) 
189        class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
190      for(size_t c=0;c<target.nof_classes();c++) {
191        if(class_averager[c].sum_w()==0) {
192          centroids_nan_=true;
193        }
194        centroids_(i,c) = class_averager[c].mean();
195      }
196    }
197  }
198
199
200  template <typename Distance>
201  void NCC<Distance>::predict(const MatrixLookup& test,                     
202                              utility::Matrix& prediction) const
203  {   
204    utility::yat_assert<std::runtime_error>
205      (centroids_.rows()==test.rows(),
206       "NCC::predict test data with incorrect number of rows");
207   
208    prediction.resize(centroids_.columns(), test.columns());
209
210    // If weighted training data has resulted in NaN in centroids: weighted calculations
211    if(centroids_nan_) { 
212      predict_weighted(MatrixLookupWeighted(test),prediction);
213    }
214    // If unweighted training data: unweighted calculations
215    else {
216      predict_unweighted(test,prediction);
217    }
218  }
219
220  template <typename Distance>
221  void NCC<Distance>::predict(const MatrixLookupWeighted& test,                     
222                              utility::Matrix& prediction) const
223  {   
224    utility::yat_assert<std::runtime_error>
225      (centroids_.rows()==test.rows(),
226       "NCC::predict test data with incorrect number of rows");
227   
228    prediction.resize(centroids_.columns(), test.columns());
229    predict_weighted(test,prediction);
230  }
231
232 
233  template <typename Distance>
234  void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 
235                                         utility::Matrix& prediction) const
236  {
237    MatrixLookup centroids(centroids_);
238    for(size_t j=0; j<test.columns();j++)
239      for(size_t k=0; k<centroids_.columns();k++) 
240        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
241                                    centroids.begin_column(k));
242  }
243 
244  template <typename Distance>
245  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 
246                                          utility::Matrix& prediction) const
247  {
248    MatrixLookupWeighted weighted_centroids(centroids_);
249    for(size_t j=0; j<test.columns();j++) 
250      for(size_t k=0; k<centroids_.columns();k++)
251        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
252                                    weighted_centroids.begin_column(k));
253  }
254
255     
256}}} // of namespace classifier, yat, and theplu
257
258#endif
Note: See TracBrowser for help on using the repository browser.