source: trunk/yat/classifier/KNN.h @ 1160

Last change on this file since 1160 was 1160, checked in by Markus Ringnér, 14 years ago

Fixes #333

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 9.9 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1160 2008-02-26 15:29:50Z markus $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>
40
41namespace theplu {
42namespace yat {
43namespace classifier {
44
  ///
  /// @brief Class for Nearest Neighbor Classification.
  ///
  /// The template argument Distance should be a class modelling
  /// the concept \ref concept_distance.
  /// The template argument NeighborWeighting should be a class modelling
  /// the concept \ref concept_neighbor_weighting.
  ///
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {
   
  public:
    ///
    /// @brief Constructor
    ///
    /// Creates an untrained classifier with a default-constructed
    /// Distance and k set to 3.
    ///
    KNN(void);


    ///
    /// @brief Constructor
    ///
    /// Creates an untrained classifier using the given Distance
    /// object; k is set to 3.
    ///
    KNN(const Distance&);


    ///
    /// @brief Destructor
    ///
    /// Does not delete the training data; the classifier never owns it.
    ///
    virtual ~KNN();
   

    ///
    /// Default number of neighbors (k) is set to 3.
    ///
    /// @return the number of neighbors
    ///
    u_int k() const;

    ///
    /// @brief sets the number of neighbors, k.
    ///
    void k(u_int k_in);


    /// @brief Create an untrained copy of this classifier.
    ///
    /// The returned classifier shares the distance measure and k of
    /// this one but holds no training data. Caller owns the returned
    /// pointer.
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
   
    ///
    /// Train the classifier using training data and target.
    ///
    /// If the number of training samples set is smaller than \a k_in,
    /// k is set to the number of training samples.
    ///
    void train(const MatrixLookup&, const Target&);

    ///
    /// Train the classifier using weighted training data and target.
    ///
    void train(const MatrixLookupWeighted&, const Target&);

   
    ///
    /// For each sample, calculate the number of neighbors for each
    /// class.
    ///
    void predict(const MatrixLookup&, utility::Matrix&) const;

    ///
    /// For each sample, calculate the number of neighbors for each
    /// class.
    ///
    void predict(const MatrixLookupWeighted&, utility::Matrix&) const;


  private:

    // Non-owning pointers to the training data set by train(); exactly
    // one of data_ml_/data_mlw_ is non-null after a successful train().
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    // Non-owning pointer to the training targets set by train().
    const Target* target_;

    // The number of neighbors
    u_int k_;

    // Distance functor used to compare a training and a test sample.
    Distance distance_;

    // Policy turning the k nearest neighbors into per-class votes.
    NeighborWeighting weighting_;

    // Fill distances(i,j) with the distance between training column i
    // and test column j (unweighted data).
    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    // Same as calculate_unweighted but for weighted data; NaN
    // distances are mapped to +infinity.
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    // Shared back-end of both predict() overloads: converts the
    // distance matrix into per-class predictions.
    void predict_common(const utility::Matrix& distances, 
                        utility::Matrix& prediction) const;

  };
142 
143 
144  // templates
145 
146  template <typename Distance, typename NeighborWeighting>
147  KNN<Distance, NeighborWeighting>::KNN() 
148    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
149  {
150  }
151
152  template <typename Distance, typename NeighborWeighting>
153  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
154    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
155  {
156  }
157
158 
159  template <typename Distance, typename NeighborWeighting>
160  KNN<Distance, NeighborWeighting>::~KNN()   
161  {
162  }
163 
164
165  template <typename Distance, typename NeighborWeighting>
166  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
167  (const MatrixLookup& training, const MatrixLookup& test,
168   utility::Matrix* distances) const
169  {
170    for(size_t i=0; i<training.columns(); i++) {
171      for(size_t j=0; j<test.columns(); j++) {
172        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
173                                      test.begin_column(j));
174        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
175      }
176    }
177  }
178
179 
180  template <typename Distance, typename NeighborWeighting>
181  void 
182  KNN<Distance, NeighborWeighting>::calculate_weighted
183  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
184   utility::Matrix* distances) const
185  {
186    for(size_t i=0; i<training.columns(); i++) { 
187      for(size_t j=0; j<test.columns(); j++) {
188        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
189                                      test.begin_column(j));
190        // If the distance is NaN (no common variables with non-zero weights),
191        // the distance is set to infinity to be sorted as a neighbor at the end
192        if(std::isnan((*distances)(i,j))) 
193          (*distances)(i,j)=std::numeric_limits<double>::infinity();
194      }
195    }
196  }
197 
198 
199  template <typename Distance, typename NeighborWeighting>
200  u_int KNN<Distance, NeighborWeighting>::k() const
201  {
202    return k_;
203  }
204
205  template <typename Distance, typename NeighborWeighting>
206  void KNN<Distance, NeighborWeighting>::k(u_int k)
207  {
208    k_=k;
209  }
210
211
212  template <typename Distance, typename NeighborWeighting>
213  KNN<Distance, NeighborWeighting>* 
214  KNN<Distance, NeighborWeighting>::make_classifier() const 
215  {     
216    KNN* knn=new KNN<Distance, NeighborWeighting>();
217    knn->k(this->k());
218    return knn;
219  }
220 
221 
222  template <typename Distance, typename NeighborWeighting>
223  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
224                                               const Target& target)
225  {   
226    utility::yat_assert<std::runtime_error>
227      (data.columns()==target.size(),
228       "KNN::train called with different sizes of target and data");
229    // k has to be at most the number of training samples.
230    if(data.columns()<k_) 
231      k_=data.columns();
232    data_ml_=&data;
233    data_mlw_=0;
234    target_=&target;
235    trained_=true;
236  }
237
238  template <typename Distance, typename NeighborWeighting>
239  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
240                                               const Target& target)
241  {   
242    utility::yat_assert<std::runtime_error>
243      (data.columns()==target.size(),
244       "KNN::train called with different sizes of target and data");
245    // k has to be at most the number of training samples.
246    if(data.columns()<k_) 
247      k_=data.columns();
248    data_ml_=0;
249    data_mlw_=&data;
250    target_=&target;
251    trained_=true;
252  }
253
254
255  template <typename Distance, typename NeighborWeighting>
256  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
257                                                 utility::Matrix& prediction) const
258  {   
259    // matrix with training samples as rows and test samples as columns
260    utility::Matrix* distances = 0;
261    // unweighted training data
262    if(data_ml_ && !data_mlw_) {
263      utility::yat_assert<std::runtime_error>
264        (data_ml_->rows()==test.rows(),
265         "KNN::predict different number of rows in training and test data");     
266      distances=new utility::Matrix(data_ml_->columns(),test.columns());
267      calculate_unweighted(*data_ml_,test,distances);
268    }
269    else if (data_mlw_ && !data_ml_) {
270      // weighted training data
271      utility::yat_assert<std::runtime_error>
272        (data_mlw_->rows()==test.rows(),
273         "KNN::predict different number of rows in training and test data");           
274      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
275      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
276                         distances);             
277    }
278    else {
279      std::runtime_error("KNN::predict no training data");
280    }
281
282    prediction.resize(target_->nof_classes(),test.columns(),0.0);
283    predict_common(*distances,prediction);
284    if(distances)
285      delete distances;
286  }
287
288  template <typename Distance, typename NeighborWeighting>
289  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
290                                                 utility::Matrix& prediction) const
291  {   
292    // matrix with training samples as rows and test samples as columns
293    utility::Matrix* distances=0; 
294    // unweighted training data
295    if(data_ml_ && !data_mlw_) { 
296      utility::yat_assert<std::runtime_error>
297        (data_ml_->rows()==test.rows(),
298         "KNN::predict different number of rows in training and test data");   
299      distances=new utility::Matrix(data_ml_->columns(),test.columns());
300      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
301    }
302    // weighted training data
303    else if (data_mlw_ && !data_ml_) {
304      utility::yat_assert<std::runtime_error>
305        (data_mlw_->rows()==test.rows(),
306         "KNN::predict different number of rows in training and test data");   
307      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
308      calculate_weighted(*data_mlw_,test,distances);             
309    }
310    else {
311      std::runtime_error("KNN::predict no training data");
312    }
313
314    prediction.resize(target_->nof_classes(),test.columns(),0.0);
315    predict_common(*distances,prediction);
316   
317    if(distances)
318      delete distances;
319  }
320 
321  template <typename Distance, typename NeighborWeighting>
322  void KNN<Distance, NeighborWeighting>::predict_common
323  (const utility::Matrix& distances, utility::Matrix& prediction) const
324  {   
325    for(size_t sample=0;sample<distances.columns();sample++) {
326      std::vector<size_t> k_index;
327      utility::VectorConstView dist=distances.column_const_view(sample);
328      utility::sort_smallest_index(k_index,k_,dist);
329      utility::VectorView pred=prediction.column_view(sample);
330      weighting_(dist,k_index,*target_,pred);
331    }
332   
333    // classes for which there are no training samples should be set
334    // to nan in the predictions
335    for(size_t c=0;c<target_->nof_classes(); c++) 
336      if(!target_->size(c)) 
337        for(size_t j=0;j<prediction.columns();j++)
338          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
339  }
340
341
342}}} // of namespace classifier, yat, and theplu
343
344#endif
345
Note: See TracBrowser for help on using the repository browser.