source: trunk/yat/classifier/KNN.h @ 1875

Last change on this file since 1875 was 1875, checked in by Peter, 13 years ago

fixes #504. Also added pp macro YAT_ASSERT that calls yat_assert with an appropriate msg

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 12.2 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1875 2009-03-19 12:35:47Z peter $
5
6/*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8  Copyright (C) 2009 Peter Johansson
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24*/
25
26#include "DataLookup1D.h"
27#include "DataLookupWeighted1D.h"
28#include "KNN_Uniform.h"
29#include "MatrixLookup.h"
30#include "MatrixLookupWeighted.h"
31#include "SupervisedClassifier.h"
32#include "Target.h"
33#include "yat/utility/Matrix.h"
34#include "yat/utility/yat_assert.h"
35
36#include <cmath>
37#include <map>
38#include <stdexcept>
39
40namespace theplu {
41namespace yat {
42namespace classifier {
43
44  /**
45     \brief Nearest Neighbor Classifier
46     
47     A sample is predicted based on the classes of its k nearest
48     neighbors among the training data samples. KNN supports using
49     different measures, for example, Euclidean distance, to define
50     distance between samples. KNN also supports using different ways to
51     weight the votes of the k nearest neighbors. For example, using a
52     uniform vote a test sample gets a vote for each class which is the
53     number of nearest neighbors belonging to the class.
54     
55     The template argument Distance should be a class modelling the
56     concept \ref concept_distance. The template argument
57     NeighborWeighting should be a class modelling the concept \ref
58     concept_neighbor_weighting.
59  */
60  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
61  class KNN : public SupervisedClassifier
62  {
63   
64  public:
65    /**
66       \brief Default constructor.
67       
68       The number of nearest neighbors (k) is set to 3. Distance and
69       NeighborWeighting are initialized using their default
70       constructuors.
71    */
72    KNN(void);
73
74
75    /**
76       \brief Constructor using an intialized distance measure.
77       
78       The number of nearest neighbors (k) is set to
79       3. NeighborWeighting is initialized using its default
80       constructor. This constructor should be used if Distance has
81       parameters and the user wants to specify the parameters by
82       initializing Distance prior to constructing the KNN.
83    */ 
84    KNN(const Distance&);
85
86
87    /**
88       Destructor
89    */
90    virtual ~KNN();
91   
92   
93    /**
94       \brief Get the number of nearest neighbors.
95       \return The number of neighbors.
96    */
97    unsigned int k() const;
98
99    /**
100       \brief Set the number of nearest neighbors.
101       
102       Sets the number of neighbors to \a k_in.
103    */
104    void k(unsigned int k_in);
105
106
107    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
108   
109    /**
110       \brief Make predictions for unweighted test data.
111       
112       Predictions are calculated and returned in \a results.  For
113       each sample in \a data, \a results contains the weighted number
114       of nearest neighbors which belong to each class. Numbers of
115       nearest neighbors are weighted according to
116       NeighborWeighting. If a class has no training samples NaN's are
117       returned for this class in \a results.
118    */
119    void predict(const MatrixLookup& data , utility::Matrix& results) const;
120
121    /**   
122        \brief Make predictions for weighted test data.
123       
124        Predictions are calculated and returned in \a results. For
125        each sample in \a data, \a results contains the weighted
126        number of nearest neighbors which belong to each class as in
127        predict(const MatrixLookup& data, utility::Matrix& results).
128        If a test and training sample pair has no variables with
129        non-zero weights in common, there are no variables which can
130        be used to calculate the distance between the two samples. In
131        this case the distance between the two is set to infinity.
132    */
133    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;
134
135
136    /**
137       \brief Train the KNN using unweighted training data with known
138       targets.
139       
140       For KNN there is no actual training; the entire training data
141       set is stored with targets. KNN only stores references to \a data
142       and \a targets as copying these would make the %classifier
143       slow. If the number of training samples set is smaller than k,
144       k is set to the number of training samples.
145       
146       \note If \a data or \a targets go out of scope ore are
147       deleted, the KNN becomes invalid and further use is undefined
148       unless it is trained again.
149    */
150    void train(const MatrixLookup& data, const Target& targets);
151   
152    /**   
153       \brief Train the KNN using weighted training data with known targets.
154   
155       See train(const MatrixLookup& data, const Target& targets) for
156       additional information.
157    */
158    void train(const MatrixLookupWeighted& data, const Target& targets);
159   
160  private:
161   
162    const MatrixLookup* data_ml_;
163    const MatrixLookupWeighted* data_mlw_;
164    const Target* target_;
165
166    // The number of neighbors
167    unsigned int k_;
168
169    Distance distance_;
170    NeighborWeighting weighting_;
171
172    void calculate_unweighted(const MatrixLookup&,
173                              const MatrixLookup&,
174                              utility::Matrix*) const;
175    void calculate_weighted(const MatrixLookupWeighted&,
176                            const MatrixLookupWeighted&,
177                            utility::Matrix*) const;
178
179    void predict_common(const utility::Matrix& distances, 
180                        utility::Matrix& prediction) const;
181
182  };
183 
184 
185  // templates
186 
187  template <typename Distance, typename NeighborWeighting>
188  KNN<Distance, NeighborWeighting>::KNN() 
189    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
190  {
191  }
192
193  template <typename Distance, typename NeighborWeighting>
194  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
195    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
196  {
197  }
198
199 
200  template <typename Distance, typename NeighborWeighting>
201  KNN<Distance, NeighborWeighting>::~KNN()   
202  {
203  }
204 
205
206  template <typename Distance, typename NeighborWeighting>
207  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
208  (const MatrixLookup& training, const MatrixLookup& test,
209   utility::Matrix* distances) const
210  {
211    for(size_t i=0; i<training.columns(); i++) {
212      for(size_t j=0; j<test.columns(); j++) {
213        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
214                                      test.begin_column(j));
215        YAT_ASSERT(!std::isnan((*distances)(i,j)));
216      }
217    }
218  }
219
220 
221  template <typename Distance, typename NeighborWeighting>
222  void 
223  KNN<Distance, NeighborWeighting>::calculate_weighted
224  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
225   utility::Matrix* distances) const
226  {
227    for(size_t i=0; i<training.columns(); i++) { 
228      for(size_t j=0; j<test.columns(); j++) {
229        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
230                                      test.begin_column(j));
231        // If the distance is NaN (no common variables with non-zero weights),
232        // the distance is set to infinity to be sorted as a neighbor at the end
233        if(std::isnan((*distances)(i,j))) 
234          (*distances)(i,j)=std::numeric_limits<double>::infinity();
235      }
236    }
237  }
238 
239 
240  template <typename Distance, typename NeighborWeighting>
241  unsigned int KNN<Distance, NeighborWeighting>::k() const
242  {
243    return k_;
244  }
245
246  template <typename Distance, typename NeighborWeighting>
247  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
248  {
249    k_=k;
250  }
251
252
253  template <typename Distance, typename NeighborWeighting>
254  KNN<Distance, NeighborWeighting>* 
255  KNN<Distance, NeighborWeighting>::make_classifier() const 
256  {     
257    // All private members should be copied here to generate an
258    // identical but untrained classifier
259    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
260    knn->weighting_=this->weighting_;
261    knn->k(this->k());
262    return knn;
263  }
264 
265 
266  template <typename Distance, typename NeighborWeighting>
267  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
268                                               const Target& target)
269  {   
270    utility::yat_assert<std::runtime_error>
271      (data.columns()==target.size(),
272       "KNN::train called with different sizes of target and data");
273    // k has to be at most the number of training samples.
274    if(data.columns()<k_) 
275      k_=data.columns();
276    data_ml_=&data;
277    data_mlw_=0;
278    target_=&target;
279  }
280
281  template <typename Distance, typename NeighborWeighting>
282  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
283                                               const Target& target)
284  {   
285    utility::yat_assert<std::runtime_error>
286      (data.columns()==target.size(),
287       "KNN::train called with different sizes of target and data");
288    // k has to be at most the number of training samples.
289    if(data.columns()<k_) 
290      k_=data.columns();
291    data_ml_=0;
292    data_mlw_=&data;
293    target_=&target;
294  }
295
296
297  template <typename Distance, typename NeighborWeighting>
298  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
299                                                 utility::Matrix& prediction) const
300  {   
301    // matrix with training samples as rows and test samples as columns
302    utility::Matrix* distances = 0;
303    // unweighted training data
304    if(data_ml_ && !data_mlw_) {
305      utility::yat_assert<std::runtime_error>
306        (data_ml_->rows()==test.rows(),
307         "KNN::predict different number of rows in training and test data");     
308      distances=new utility::Matrix(data_ml_->columns(),test.columns());
309      calculate_unweighted(*data_ml_,test,distances);
310    }
311    else if (data_mlw_ && !data_ml_) {
312      // weighted training data
313      utility::yat_assert<std::runtime_error>
314        (data_mlw_->rows()==test.rows(),
315         "KNN::predict different number of rows in training and test data");           
316      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
317      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
318                         distances);             
319    }
320    else {
321      std::runtime_error("KNN::predict no training data");
322    }
323
324    prediction.resize(target_->nof_classes(),test.columns(),0.0);
325    predict_common(*distances,prediction);
326    if(distances)
327      delete distances;
328  }
329
330  template <typename Distance, typename NeighborWeighting>
331  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
332                                                 utility::Matrix& prediction) const
333  {   
334    // matrix with training samples as rows and test samples as columns
335    utility::Matrix* distances=0; 
336    // unweighted training data
337    if(data_ml_ && !data_mlw_) { 
338      utility::yat_assert<std::runtime_error>
339        (data_ml_->rows()==test.rows(),
340         "KNN::predict different number of rows in training and test data");   
341      distances=new utility::Matrix(data_ml_->columns(),test.columns());
342      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
343    }
344    // weighted training data
345    else if (data_mlw_ && !data_ml_) {
346      utility::yat_assert<std::runtime_error>
347        (data_mlw_->rows()==test.rows(),
348         "KNN::predict different number of rows in training and test data");   
349      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
350      calculate_weighted(*data_mlw_,test,distances);             
351    }
352    else {
353      std::runtime_error("KNN::predict no training data");
354    }
355
356    prediction.resize(target_->nof_classes(),test.columns(),0.0);
357    predict_common(*distances,prediction);
358   
359    if(distances)
360      delete distances;
361  }
362 
363  template <typename Distance, typename NeighborWeighting>
364  void KNN<Distance, NeighborWeighting>::predict_common
365  (const utility::Matrix& distances, utility::Matrix& prediction) const
366  {   
367    for(size_t sample=0;sample<distances.columns();sample++) {
368      std::vector<size_t> k_index;
369      utility::VectorConstView dist=distances.column_const_view(sample);
370      utility::sort_smallest_index(k_index,k_,dist);
371      utility::VectorView pred=prediction.column_view(sample);
372      weighting_(dist,k_index,*target_,pred);
373    }
374   
375    // classes for which there are no training samples should be set
376    // to nan in the predictions
377    for(size_t c=0;c<target_->nof_classes(); c++) 
378      if(!target_->size(c)) 
379        for(size_t j=0;j<prediction.columns();j++)
380          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
381  }
382
383
384}}} // of namespace classifier, yat, and theplu
385
386#endif
387
Note: See TracBrowser for help on using the repository browser.