source: branches/0.4-stable/yat/classifier/NCC.h @ 1392

Last change to this file: revision 1392, checked in by Peter 13 years ago.

trac has moved

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.3 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Peter Johansson, Markus Ringnér
8  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 2 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with this program; if not, write to the Free Software
24  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
25  02111-1307, USA.
26*/
27
28#include "MatrixLookup.h"
29#include "MatrixLookupWeighted.h"
30#include "SupervisedClassifier.h"
31#include "Target.h"
32
33#include "yat/statistics/Averager.h"
34#include "yat/statistics/AveragerWeighted.h"
35#include "yat/utility/Matrix.h"
36#include "yat/utility/Vector.h"
37#include "yat/utility/stl_utility.h"
38#include "yat/utility/yat_assert.h"
39
40#include<iostream>
41#include<iterator>
42#include <map>
43#include <cmath>
44#include <stdexcept>
45
46namespace theplu {
47namespace yat {
48namespace classifier { 
49
50
51  /**
52     \brief Nearest Centroid Classifier
53     
54     A sample is predicted based on its distance to centroids for each
55     class. The centroids are generated using training data. NCC
56     supports using different measures, for example, Euclidean
57     distance, to define distance between samples and centroids.     
58
59     The template argument Distance should be a class modelling
60     the concept \ref concept_distance.
61  */
62  template <typename Distance>
63  class NCC : public SupervisedClassifier
64  {
65 
66  public:
67    /**
68       \brief Constructor
69       
70       Distance is initialized using its default constructor.   
71    */
72    NCC(void);
73   
74    /**
75       \brief Constructor using an initialized distance measure
76
77       This constructor should be used if Distance has parameters and
78       the user wants to specify the parameters by initializing
79       Distance prior to constructing the NCC.
80    */
81    NCC(const Distance&);
82
83
84    /**
85       Destructor
86    */
87    virtual ~NCC(void);
88
89    /**
90       \brief Get the centroids for all classes.
91
92       \return The centroids for each class as columns in a matrix.
93    */
94    const utility::Matrix& centroids(void) const;
95
96    NCC<Distance>* make_classifier(void) const;
97       
98    /**
99       \brief Make predictions for unweighted test data.
100
101       Predictions are calculated and returned in \a results.  For
102       each sample in \a data, \a results contains the distances to
103       the centroids for each class. If a class has no training
104       samples NaN's are returned for this class in \a
105       results. Weighted distance calculations, in which NaN's have
106       zero weights, are used if the centroids contain NaN's.
107       
108       \note NCC returns distances to centroids as the
109       prediction. This means that the best class for a sample has the
110       smallest value in \a results. This is in contrast to, for
111       example, KNN for which the best class for a sample in \a
112       results has the largest number (the largest number of nearest
113       neighbors).
114    */
115    void predict(const MatrixLookup& data, utility::Matrix& results) const;
116   
117    /**
118       \brief Make predictions for weighted test data.
119
120       Predictions are calculated and returned in \a results.  For
121       each sample in \a data, \a results contains the distances to
122       the centroids for each class as in predict(const MatrixLookup&
123       data, utility::Matrix& results). Weighted distance calculations
124       are used, and zero weights are used for NaN's in centroids.  If
125       for a test sample and centroid pair, all variables have either
126       zero weight for the test sample or NaN for the centroid, the
127       centroid and the sample have no variables with values in
128       common. In this case the prediction for the sample is set to
129       NaN for the class in \a results.
130       
131       \note NCC returns distances to centroids as the
132       prediction. This means that the best class for a sample has the
133       smallest value in \a results. This is in contrast to, for
134       example, KNN for which the best class for a sample in \a
135       results has the largest number (the largest number of nearest
136       neighbors).
137    */
138    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;
139
140    /**
141       \brief Train the NCC using unweighted training data with known
142       targets.
143
144       A centroid is calculated for each class. For each variable in
145       \a data, a centroid for a class contains the average value of
146       the variable across all training samples in the class.
147    */
148    void train(const MatrixLookup& data, const Target& targets);
149
150
151    /**
152       \brief Train the NCC using weighted training data with known
153       targets.
154
155       A centroid is calculated for each class as in
156       train(const MatrixLookup&, const Target&). 
157       The weights of the data are used when calculating the centroids
158       and the centroids should be interpreted as unweighted
159       (i.e. centroid values have unity weights). If a variable has
160       zero weights for all samples in a class, the centroid is set to
161       NaN for that variable.
162    */
163    void train(const MatrixLookupWeighted& data, const Target& targets);
164
165   
166  private:
167
168    void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
169    void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;   
170
171    utility::Matrix centroids_;
172    bool centroids_nan_;
173    Distance distance_;
174  }; 
175
176  // templates
177
178  template <typename Distance>
179  NCC<Distance>::NCC() 
180    : SupervisedClassifier(), centroids_nan_(false)
181  {
182  }
183
184  template <typename Distance>
185  NCC<Distance>::NCC(const Distance& dist) 
186    : SupervisedClassifier(), centroids_nan_(false), distance_(dist)
187  {
188  }
189
190
191  template <typename Distance>
192  NCC<Distance>::~NCC()   
193  {
194  }
195
196
197  template <typename Distance>
198  const utility::Matrix& NCC<Distance>::centroids(void) const
199  {
200    return centroids_;
201  }
202 
203
204  template <typename Distance>
205  NCC<Distance>* 
206  NCC<Distance>::make_classifier() const 
207  {     
208    // All private members should be copied here to generate an
209    // identical but untrained classifier
210    return new NCC<Distance>(distance_);
211  }
212
213  template <typename Distance>
214  void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
215  {   
216    centroids_.resize(data.rows(), target.nof_classes());
217    for(size_t i=0; i<data.rows(); i++) {
218      std::vector<statistics::Averager> class_averager;
219      class_averager.resize(target.nof_classes());
220      for(size_t j=0; j<data.columns(); j++) {
221        class_averager[target(j)].add(data(i,j));
222      }
223      for(size_t c=0;c<target.nof_classes();c++) {
224        centroids_(i,c) = class_averager[c].mean();
225      }
226    }
227  }
228
229
230  template <typename Distance>
231  void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
232  {   
233    centroids_.resize(data.rows(), target.nof_classes());
234    for(size_t i=0; i<data.rows(); i++) {
235      std::vector<statistics::AveragerWeighted> class_averager;
236      class_averager.resize(target.nof_classes());
237      for(size_t j=0; j<data.columns(); j++) 
238        class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
239      for(size_t c=0;c<target.nof_classes();c++) {
240        if(class_averager[c].sum_w()==0) {
241          centroids_nan_=true;
242        }
243        centroids_(i,c) = class_averager[c].mean();
244      }
245    }
246  }
247
248
249  template <typename Distance>
250  void NCC<Distance>::predict(const MatrixLookup& test,                     
251                              utility::Matrix& prediction) const
252  {   
253    utility::yat_assert<std::runtime_error>
254      (centroids_.rows()==test.rows(),
255       "NCC::predict test data with incorrect number of rows");
256   
257    prediction.resize(centroids_.columns(), test.columns());
258
259    // If weighted training data has resulted in NaN in centroids: weighted calculations
260    if(centroids_nan_) { 
261      predict_weighted(MatrixLookupWeighted(test),prediction);
262    }
263    // If unweighted training data: unweighted calculations
264    else {
265      predict_unweighted(test,prediction);
266    }
267  }
268
269  template <typename Distance>
270  void NCC<Distance>::predict(const MatrixLookupWeighted& test,                     
271                              utility::Matrix& prediction) const
272  {   
273    utility::yat_assert<std::runtime_error>
274      (centroids_.rows()==test.rows(),
275       "NCC::predict test data with incorrect number of rows");
276   
277    prediction.resize(centroids_.columns(), test.columns());
278    predict_weighted(test,prediction);
279  }
280
281 
282  template <typename Distance>
283  void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 
284                                         utility::Matrix& prediction) const
285  {
286    for(size_t j=0; j<test.columns();j++)
287      for(size_t k=0; k<centroids_.columns();k++) 
288        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
289                                    centroids_.begin_column(k));
290  }
291 
292  template <typename Distance>
293  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 
294                                          utility::Matrix& prediction) const
295  {
296    MatrixLookupWeighted weighted_centroids(centroids_);
297    for(size_t j=0; j<test.columns();j++) 
298      for(size_t k=0; k<centroids_.columns();k++)
299        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
300                                    weighted_centroids.begin_column(k));
301  }
302
303     
304}}} // of namespace classifier, yat, and theplu
305
306#endif
Note: See TracBrowser for help on using the repository browser.