source: trunk/yat/classifier/NCC.h @ 3701

Last change on this file since 3701 was 3701, checked in by Peter, 6 years ago

avoid includes when not needed

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.3 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Peter Johansson, Markus Ringnér
8  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
9  Copyright (C) 2009, 2010, 2017 Peter Johansson
10
11  This file is part of the yat library, http://dev.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 3 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with yat. If not, see <http://www.gnu.org/licenses/>.
25*/
26
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"

#include "yat/statistics/Averager.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/concept_check.h"
#include "yat/utility/Exception.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/MatrixWeighted.h"
#include "yat/utility/Vector.h"
#include "yat/utility/yat_assert.h"

#include <boost/concept_check.hpp>

#include <cmath>
#include <iterator>
#include <limits>
#include <map>
#include <vector>
46
47namespace theplu {
48namespace yat {
49namespace classifier {
50
51
52  /**
53     \brief Nearest Centroid Classifier
54
55     A sample is predicted based on its distance to centroids for each
56     class. The centroids are generated using training data. NCC
57     supports using different measures, for example, Euclidean
58     distance, to define distance between samples and centroids.
59
60     The template argument Distance should be a class modelling
61     the concept \ref concept_distance.
62  */
63  template <typename Distance>
64  class NCC : public SupervisedClassifier
65  {
66
67  public:
68    /**
69       \brief Constructor
70
71       Distance is initialized using its default constructor.
72    */
73    NCC(void);
74
75    /**
76       \brief Constructor using an initialized distance measure
77
78       This constructor should be used if Distance has parameters and
79       the user wants to specify the parameters by initializing
80       Distance prior to constructing the NCC.
81    */
82    NCC(const Distance&);
83
84
85    /**
86       Destructor
87    */
88    virtual ~NCC(void);
89
90    /**
91       \brief Get the centroids for all classes.
92
93       \return The centroids for each class as columns in a matrix.
94    */
95    const utility::Matrix& centroids(void) const;
96
97    NCC<Distance>* make_classifier(void) const;
98
99    /**
100       \brief Make predictions for unweighted test data.
101
102       Predictions are calculated and returned in \a results.  For
103       each sample in \a data, \a results contains the distances to
104       the centroids for each class. If a class has no training
105       samples NaN's are returned for this class in \a
106       results. Weighted distance calculations, in which NaN's have
107       zero weights, are used if the centroids contain NaN's.
108
109       \note NCC returns distances to centroids as the
110       prediction. This means that the best class for a sample has the
111       smallest value in \a results. This is in contrast to, for
112       example, KNN for which the best class for a sample in \a
113       results has the largest number (the largest number of nearest
114       neighbors).
115    */
116    void predict(const MatrixLookup& data, utility::Matrix& results) const;
117
118    /**
119       \brief Make predictions for weighted test data.
120
121       Predictions are calculated and returned in \a results.  For
122       each sample in \a data, \a results contains the distances to
123       the centroids for each class as in predict(const MatrixLookup&
124       data, utility::Matrix& results). Weighted distance calculations
125       are used, and zero weights are used for NaN's in centroids.  If
126       for a test sample and centroid pair, all variables have either
127       zero weight for the test sample or NaN for the centroid, the
128       centroid and the sample have no variables with values in
129       common. In this case the prediction for the sample is set to
130       NaN for the class in \a results.
131
132       \note NCC returns distances to centroids as the
133       prediction. This means that the best class for a sample has the
134       smallest value in \a results. This is in contrast to, for
135       example, KNN for which the best class for a sample in \a
136       results has the largest number (the largest number of nearest
137       neighbors).
138    */
139    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;
140
141    /**
142       \brief Train the NCC using unweighted training data with known
143       targets.
144
145       A centroid is calculated for each class. For each variable in
146       \a data, a centroid for a class contains the average value of
147       the variable across all training samples in the class.
148    */
149    void train(const MatrixLookup& data, const Target& targets);
150
151
152    /**
153       \brief Train the NCC using weighted training data with known
154       targets.
155
156       A centroid is calculated for each class as in
157       train(const MatrixLookup&, const Target&).
158       The weights of the data are used when calculating the centroids
159       and the centroids should be interpreted as unweighted
160       (i.e. centroid values have unity weights). If a variable has
161       zero weights for all samples in a class, the centroid is set to
162       NaN for that variable.
163    */
164    void train(const MatrixLookupWeighted& data, const Target& targets);
165
166
167  private:
168
169    void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
170    void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;
171
172    utility::Matrix centroids_;
173    bool centroids_nan_;
174    Distance distance_;
175  };
176
177  // templates
178
179  template <typename Distance>
180  NCC<Distance>::NCC()
181    : SupervisedClassifier(), centroids_nan_(false)
182  {
183    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
184  }
185
186  template <typename Distance>
187  NCC<Distance>::NCC(const Distance& dist)
188    : SupervisedClassifier(), centroids_nan_(false), distance_(dist)
189  {
190    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
191  }
192
193
194  template <typename Distance>
195  NCC<Distance>::~NCC()
196  {
197  }
198
199
200  template <typename Distance>
201  const utility::Matrix& NCC<Distance>::centroids(void) const
202  {
203    return centroids_;
204  }
205
206
207  template <typename Distance>
208  NCC<Distance>*
209  NCC<Distance>::make_classifier() const
210  {
211    // All private members should be copied here to generate an
212    // identical but untrained classifier
213    return new NCC<Distance>(distance_);
214  }
215
216  template <typename Distance>
217  void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
218  {
219    centroids_.resize(data.rows(), target.nof_classes());
220    for(size_t i=0; i<data.rows(); i++) {
221      std::vector<statistics::Averager> class_averager;
222      class_averager.resize(target.nof_classes());
223      for(size_t j=0; j<data.columns(); j++) {
224        class_averager[target(j)].add(data(i,j));
225      }
226      for(size_t c=0;c<target.nof_classes();c++) {
227        centroids_(i,c) = class_averager[c].mean();
228      }
229    }
230  }
231
232
233  template <typename Distance>
234  void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
235  {
236    centroids_.resize(data.rows(), target.nof_classes());
237    for(size_t i=0; i<data.rows(); i++) {
238      std::vector<statistics::AveragerWeighted> class_averager;
239      class_averager.resize(target.nof_classes());
240      for(size_t j=0; j<data.columns(); j++)
241        class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
242      for(size_t c=0;c<target.nof_classes();c++) {
243        if(class_averager[c].sum_w()==0) {
244          centroids_nan_=true;
245        }
246        centroids_(i,c) = class_averager[c].mean();
247      }
248    }
249  }
250
251
252  template <typename Distance>
253  void NCC<Distance>::predict(const MatrixLookup& test,
254                              utility::Matrix& prediction) const
255  {
256    utility::yat_assert<utility::runtime_error>
257      (centroids_.rows()==test.rows(),
258       "NCC::predict test data with incorrect number of rows");
259
260    prediction.resize(centroids_.columns(), test.columns());
261
262    // If weighted training data has resulted in NaN in centroids:
263    // weighted calculations
264    if(centroids_nan_) {
265      predict_weighted(MatrixLookupWeighted(test),prediction);
266    }
267    // If unweighted training data: unweighted calculations
268    else {
269      predict_unweighted(test,prediction);
270    }
271  }
272
273  template <typename Distance>
274  void NCC<Distance>::predict(const MatrixLookupWeighted& test,
275                              utility::Matrix& prediction) const
276  {
277    utility::yat_assert<utility::runtime_error>
278      (centroids_.rows()==test.rows(),
279       "NCC::predict test data with incorrect number of rows");
280
281    prediction.resize(centroids_.columns(), test.columns());
282    predict_weighted(test,prediction);
283  }
284
285
286  template <typename Distance>
287  void NCC<Distance>::predict_unweighted(const MatrixLookup& test,
288                                         utility::Matrix& prediction) const
289  {
290    for(size_t j=0; j<test.columns();j++)
291      for(size_t k=0; k<centroids_.columns();k++)
292        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j),
293                                    centroids_.begin_column(k));
294  }
295
296  template <typename Distance>
297  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test,
298                                       utility::Matrix& prediction) const
299  {
300    utility::MatrixWeighted weighted_centroids(centroids_);
301    for(size_t j=0; j<test.columns();j++)
302      for(size_t k=0; k<centroids_.columns();k++)
303        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j),
304                                    weighted_centroids.begin_column(k));
305  }
306
307
308}}} // of namespace classifier, yat, and theplu
309
310#endif
Note: See TracBrowser for help on using the repository browser.