source: trunk/yat/classifier/NCC.h @ 1706

Last change on this file since 1706 was 1487, checked in by Jari Häkkinen, 13 years ago

Addresses #436. GPL license copy reference should also be updated.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.2 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Peter Johansson, Markus Ringnér
8  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24*/
25
26#include "MatrixLookup.h"
27#include "MatrixLookupWeighted.h"
28#include "SupervisedClassifier.h"
29#include "Target.h"
30
31#include "yat/statistics/Averager.h"
32#include "yat/statistics/AveragerWeighted.h"
33#include "yat/utility/Matrix.h"
34#include "yat/utility/MatrixWeighted.h"
35#include "yat/utility/Vector.h"
36#include "yat/utility/stl_utility.h"
37#include "yat/utility/yat_assert.h"
38
39#include<iostream>
40#include<iterator>
41#include <map>
42#include <cmath>
43#include <stdexcept>
44
45namespace theplu {
46namespace yat {
47namespace classifier { 
48
49
50  /**
51     \brief Nearest Centroid Classifier
52     
53     A sample is predicted based on its distance to centroids for each
54     class. The centroids are generated using training data. NCC
55     supports using different measures, for example, Euclidean
56     distance, to define distance between samples and centroids.     
57
58     The template argument Distance should be a class modelling
59     the concept \ref concept_distance.
60  */
61  template <typename Distance>
62  class NCC : public SupervisedClassifier
63  {
64 
65  public:
66    /**
67       \brief Constructor
68       
69       Distance is initialized using its default constructor.   
70    */
71    NCC(void);
72   
73    /**
74       \brief Constructor using an initialized distance measure
75
76       This constructor should be used if Distance has parameters and
77       the user wants to specify the parameters by initializing
78       Distance prior to constructing the NCC.
79    */
80    NCC(const Distance&);
81
82
83    /**
84       Destructor
85    */
86    virtual ~NCC(void);
87
88    /**
89       \brief Get the centroids for all classes.
90
91       \return The centroids for each class as columns in a matrix.
92    */
93    const utility::Matrix& centroids(void) const;
94
95    NCC<Distance>* make_classifier(void) const;
96       
97    /**
98       \brief Make predictions for unweighted test data.
99
100       Predictions are calculated and returned in \a results.  For
101       each sample in \a data, \a results contains the distances to
102       the centroids for each class. If a class has no training
103       samples NaN's are returned for this class in \a
104       results. Weighted distance calculations, in which NaN's have
105       zero weights, are used if the centroids contain NaN's.
106       
107       \note NCC returns distances to centroids as the
108       prediction. This means that the best class for a sample has the
109       smallest value in \a results. This is in contrast to, for
110       example, KNN for which the best class for a sample in \a
111       results has the largest number (the largest number of nearest
112       neighbors).
113    */
114    void predict(const MatrixLookup& data, utility::Matrix& results) const;
115   
116    /**
117       \brief Make predictions for weighted test data.
118
119       Predictions are calculated and returned in \a results.  For
120       each sample in \a data, \a results contains the distances to
121       the centroids for each class as in predict(const MatrixLookup&
122       data, utility::Matrix& results). Weighted distance calculations
123       are used, and zero weights are used for NaN's in centroids.  If
124       for a test sample and centroid pair, all variables have either
125       zero weight for the test sample or NaN for the centroid, the
126       centroid and the sample have no variables with values in
127       common. In this case the prediction for the sample is set to
128       NaN for the class in \a results.
129       
130       \note NCC returns distances to centroids as the
131       prediction. This means that the best class for a sample has the
132       smallest value in \a results. This is in contrast to, for
133       example, KNN for which the best class for a sample in \a
134       results has the largest number (the largest number of nearest
135       neighbors).
136    */
137    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;
138
139    /**
140       \brief Train the NCC using unweighted training data with known
141       targets.
142
143       A centroid is calculated for each class. For each variable in
144       \a data, a centroid for a class contains the average value of
145       the variable across all training samples in the class.
146    */
147    void train(const MatrixLookup& data, const Target& targets);
148
149
150    /**
151       \brief Train the NCC using weighted training data with known
152       targets.
153
154       A centroid is calculated for each class as in
155       train(const MatrixLookup&, const Target&). 
156       The weights of the data are used when calculating the centroids
157       and the centroids should be interpreted as unweighted
158       (i.e. centroid values have unity weights). If a variable has
159       zero weights for all samples in a class, the centroid is set to
160       NaN for that variable.
161    */
162    void train(const MatrixLookupWeighted& data, const Target& targets);
163
164   
165  private:
166
167    void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
168    void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;   
169
170    utility::Matrix centroids_;
171    bool centroids_nan_;
172    Distance distance_;
173  }; 
174
175  // templates
176
177  template <typename Distance>
178  NCC<Distance>::NCC() 
179    : SupervisedClassifier(), centroids_nan_(false)
180  {
181  }
182
183  template <typename Distance>
184  NCC<Distance>::NCC(const Distance& dist) 
185    : SupervisedClassifier(), centroids_nan_(false), distance_(dist)
186  {
187  }
188
189
190  template <typename Distance>
191  NCC<Distance>::~NCC()   
192  {
193  }
194
195
196  template <typename Distance>
197  const utility::Matrix& NCC<Distance>::centroids(void) const
198  {
199    return centroids_;
200  }
201 
202
203  template <typename Distance>
204  NCC<Distance>* 
205  NCC<Distance>::make_classifier() const 
206  {     
207    // All private members should be copied here to generate an
208    // identical but untrained classifier
209    return new NCC<Distance>(distance_);
210  }
211
212  template <typename Distance>
213  void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
214  {   
215    centroids_.resize(data.rows(), target.nof_classes());
216    for(size_t i=0; i<data.rows(); i++) {
217      std::vector<statistics::Averager> class_averager;
218      class_averager.resize(target.nof_classes());
219      for(size_t j=0; j<data.columns(); j++) {
220        class_averager[target(j)].add(data(i,j));
221      }
222      for(size_t c=0;c<target.nof_classes();c++) {
223        centroids_(i,c) = class_averager[c].mean();
224      }
225    }
226  }
227
228
229  template <typename Distance>
230  void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
231  {   
232    centroids_.resize(data.rows(), target.nof_classes());
233    for(size_t i=0; i<data.rows(); i++) {
234      std::vector<statistics::AveragerWeighted> class_averager;
235      class_averager.resize(target.nof_classes());
236      for(size_t j=0; j<data.columns(); j++) 
237        class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
238      for(size_t c=0;c<target.nof_classes();c++) {
239        if(class_averager[c].sum_w()==0) {
240          centroids_nan_=true;
241        }
242        centroids_(i,c) = class_averager[c].mean();
243      }
244    }
245  }
246
247
248  template <typename Distance>
249  void NCC<Distance>::predict(const MatrixLookup& test,                     
250                              utility::Matrix& prediction) const
251  {   
252    utility::yat_assert<std::runtime_error>
253      (centroids_.rows()==test.rows(),
254       "NCC::predict test data with incorrect number of rows");
255   
256    prediction.resize(centroids_.columns(), test.columns());
257
258    // If weighted training data has resulted in NaN in centroids: weighted calculations
259    if(centroids_nan_) { 
260      predict_weighted(MatrixLookupWeighted(test),prediction);
261    }
262    // If unweighted training data: unweighted calculations
263    else {
264      predict_unweighted(test,prediction);
265    }
266  }
267
268  template <typename Distance>
269  void NCC<Distance>::predict(const MatrixLookupWeighted& test,                     
270                              utility::Matrix& prediction) const
271  {   
272    utility::yat_assert<std::runtime_error>
273      (centroids_.rows()==test.rows(),
274       "NCC::predict test data with incorrect number of rows");
275   
276    prediction.resize(centroids_.columns(), test.columns());
277    predict_weighted(test,prediction);
278  }
279
280 
281  template <typename Distance>
282  void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 
283                                         utility::Matrix& prediction) const
284  {
285    for(size_t j=0; j<test.columns();j++)
286      for(size_t k=0; k<centroids_.columns();k++) 
287        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
288                                    centroids_.begin_column(k));
289  }
290 
291  template <typename Distance>
292  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 
293                                          utility::Matrix& prediction) const
294  {
295    utility::MatrixWeighted weighted_centroids(centroids_);
296    for(size_t j=0; j<test.columns();j++) 
297      for(size_t k=0; k<centroids_.columns();k++)
298        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
299                                    weighted_centroids.begin_column(k));
300  }
301
302     
303}}} // of namespace classifier, yat, and theplu
304
305#endif
Note: See TracBrowser for help on using the repository browser.