source: trunk/yat/classifier/NCC.h @ 2334

Last change on this file since 2334 was 2334, checked in by Peter, 11 years ago

fixes #625

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.5 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Peter Johansson, Markus Ringnér
8  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
9  Copyright (C) 2009 Peter Johansson
10
11  This file is part of the yat library, http://dev.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 3 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with yat. If not, see <http://www.gnu.org/licenses/>.
25*/
26
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"

#include "yat/statistics/Averager.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/concept_check.h"
#include "yat/utility/Exception.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/MatrixWeighted.h"
#include "yat/utility/Vector.h"
#include "yat/utility/stl_utility.h"
#include "yat/utility/yat_assert.h"

#include <boost/concept_check.hpp>

#include <cmath>
#include <iterator>
#include <map>
#include <vector>
47
48namespace theplu {
49namespace yat {
50namespace classifier { 
51
52
53  /**
54     \brief Nearest Centroid Classifier
55     
56     A sample is predicted based on its distance to centroids for each
57     class. The centroids are generated using training data. NCC
58     supports using different measures, for example, Euclidean
59     distance, to define distance between samples and centroids.     
60
61     The template argument Distance should be a class modelling
62     the concept \ref concept_distance.
63  */
64  template <typename Distance>
65  class NCC : public SupervisedClassifier
66  {
67 
68  public:
69    /**
70       \brief Constructor
71       
72       Distance is initialized using its default constructor.   
73    */
74    NCC(void);
75   
76    /**
77       \brief Constructor using an initialized distance measure
78
79       This constructor should be used if Distance has parameters and
80       the user wants to specify the parameters by initializing
81       Distance prior to constructing the NCC.
82    */
83    NCC(const Distance&);
84
85
86    /**
87       Destructor
88    */
89    virtual ~NCC(void);
90
91    /**
92       \brief Get the centroids for all classes.
93
94       \return The centroids for each class as columns in a matrix.
95    */
96    const utility::Matrix& centroids(void) const;
97
98    NCC<Distance>* make_classifier(void) const;
99       
100    /**
101       \brief Make predictions for unweighted test data.
102
103       Predictions are calculated and returned in \a results.  For
104       each sample in \a data, \a results contains the distances to
105       the centroids for each class. If a class has no training
106       samples NaN's are returned for this class in \a
107       results. Weighted distance calculations, in which NaN's have
108       zero weights, are used if the centroids contain NaN's.
109       
110       \note NCC returns distances to centroids as the
111       prediction. This means that the best class for a sample has the
112       smallest value in \a results. This is in contrast to, for
113       example, KNN for which the best class for a sample in \a
114       results has the largest number (the largest number of nearest
115       neighbors).
116    */
117    void predict(const MatrixLookup& data, utility::Matrix& results) const;
118   
119    /**
120       \brief Make predictions for weighted test data.
121
122       Predictions are calculated and returned in \a results.  For
123       each sample in \a data, \a results contains the distances to
124       the centroids for each class as in predict(const MatrixLookup&
125       data, utility::Matrix& results). Weighted distance calculations
126       are used, and zero weights are used for NaN's in centroids.  If
127       for a test sample and centroid pair, all variables have either
128       zero weight for the test sample or NaN for the centroid, the
129       centroid and the sample have no variables with values in
130       common. In this case the prediction for the sample is set to
131       NaN for the class in \a results.
132       
133       \note NCC returns distances to centroids as the
134       prediction. This means that the best class for a sample has the
135       smallest value in \a results. This is in contrast to, for
136       example, KNN for which the best class for a sample in \a
137       results has the largest number (the largest number of nearest
138       neighbors).
139    */
140    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;
141
142    /**
143       \brief Train the NCC using unweighted training data with known
144       targets.
145
146       A centroid is calculated for each class. For each variable in
147       \a data, a centroid for a class contains the average value of
148       the variable across all training samples in the class.
149    */
150    void train(const MatrixLookup& data, const Target& targets);
151
152
153    /**
154       \brief Train the NCC using weighted training data with known
155       targets.
156
157       A centroid is calculated for each class as in
158       train(const MatrixLookup&, const Target&). 
159       The weights of the data are used when calculating the centroids
160       and the centroids should be interpreted as unweighted
161       (i.e. centroid values have unity weights). If a variable has
162       zero weights for all samples in a class, the centroid is set to
163       NaN for that variable.
164    */
165    void train(const MatrixLookupWeighted& data, const Target& targets);
166
167   
168  private:
169
170    void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
171    void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;   
172
173    utility::Matrix centroids_;
174    bool centroids_nan_;
175    Distance distance_;
176  }; 
177
178  // templates
179
180  template <typename Distance>
181  NCC<Distance>::NCC() 
182    : SupervisedClassifier(), centroids_nan_(false)
183  {
184    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
185  }
186
187  template <typename Distance>
188  NCC<Distance>::NCC(const Distance& dist) 
189    : SupervisedClassifier(), centroids_nan_(false), distance_(dist)
190  {
191    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
192  }
193
194
195  template <typename Distance>
196  NCC<Distance>::~NCC()   
197  {
198  }
199
200
201  template <typename Distance>
202  const utility::Matrix& NCC<Distance>::centroids(void) const
203  {
204    return centroids_;
205  }
206 
207
208  template <typename Distance>
209  NCC<Distance>* 
210  NCC<Distance>::make_classifier() const 
211  {     
212    // All private members should be copied here to generate an
213    // identical but untrained classifier
214    return new NCC<Distance>(distance_);
215  }
216
217  template <typename Distance>
218  void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
219  {   
220    centroids_.resize(data.rows(), target.nof_classes());
221    for(size_t i=0; i<data.rows(); i++) {
222      std::vector<statistics::Averager> class_averager;
223      class_averager.resize(target.nof_classes());
224      for(size_t j=0; j<data.columns(); j++) {
225        class_averager[target(j)].add(data(i,j));
226      }
227      for(size_t c=0;c<target.nof_classes();c++) {
228        centroids_(i,c) = class_averager[c].mean();
229      }
230    }
231  }
232
233
234  template <typename Distance>
235  void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
236  {   
237    centroids_.resize(data.rows(), target.nof_classes());
238    for(size_t i=0; i<data.rows(); i++) {
239      std::vector<statistics::AveragerWeighted> class_averager;
240      class_averager.resize(target.nof_classes());
241      for(size_t j=0; j<data.columns(); j++) 
242        class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
243      for(size_t c=0;c<target.nof_classes();c++) {
244        if(class_averager[c].sum_w()==0) {
245          centroids_nan_=true;
246        }
247        centroids_(i,c) = class_averager[c].mean();
248      }
249    }
250  }
251
252
253  template <typename Distance>
254  void NCC<Distance>::predict(const MatrixLookup& test,                     
255                              utility::Matrix& prediction) const
256  {   
257    utility::yat_assert<utility::runtime_error>
258      (centroids_.rows()==test.rows(),
259       "NCC::predict test data with incorrect number of rows");
260   
261    prediction.resize(centroids_.columns(), test.columns());
262
263    // If weighted training data has resulted in NaN in centroids: weighted calculations
264    if(centroids_nan_) { 
265      predict_weighted(MatrixLookupWeighted(test),prediction);
266    }
267    // If unweighted training data: unweighted calculations
268    else {
269      predict_unweighted(test,prediction);
270    }
271  }
272
273  template <typename Distance>
274  void NCC<Distance>::predict(const MatrixLookupWeighted& test,                     
275                              utility::Matrix& prediction) const
276  {   
277    utility::yat_assert<utility::runtime_error>
278      (centroids_.rows()==test.rows(),
279       "NCC::predict test data with incorrect number of rows");
280   
281    prediction.resize(centroids_.columns(), test.columns());
282    predict_weighted(test,prediction);
283  }
284
285 
286  template <typename Distance>
287  void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 
288                                         utility::Matrix& prediction) const
289  {
290    for(size_t j=0; j<test.columns();j++)
291      for(size_t k=0; k<centroids_.columns();k++) 
292        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
293                                    centroids_.begin_column(k));
294  }
295 
296  template <typename Distance>
297  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 
298                                          utility::Matrix& prediction) const
299  {
300    utility::MatrixWeighted weighted_centroids(centroids_);
301    for(size_t j=0; j<test.columns();j++) 
302      for(size_t k=0; k<centroids_.columns();k++)
303        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
304                                    weighted_centroids.begin_column(k));
305  }
306
307     
308}}} // of namespace classifier, yat, and theplu
309
310#endif
Note: See TracBrowser for help on using the repository browser.