source: trunk/yat/classifier/KNN.h @ 2334

Last change on this file since 2334 was 2334, checked in by Peter, 11 years ago

fixes #625

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 12.5 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 2334 2010-10-15 02:35:09Z peter $
5
6/*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8  Copyright (C) 2009 Peter Johansson
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24*/
25
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/concept_check.h"
#include "yat/utility/Exception.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <boost/concept_check.hpp>

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>
44
45namespace theplu {
46namespace yat {
47namespace classifier {
48
  /**
     \brief Nearest Neighbor Classifier
     
     A sample is predicted based on the classes of its k nearest
     neighbors among the training data samples. KNN supports using
     different measures, for example, Euclidean distance, to define
     distance between samples. KNN also supports using different ways to
     weight the votes of the k nearest neighbors. For example, using a
     uniform vote a test sample gets a vote for each class which is the
     number of nearest neighbors belonging to the class.
     
     The template argument Distance should be a class modelling the
     concept \ref concept_distance. The template argument
     NeighborWeighting should be a class modelling the concept \ref
     concept_neighbor_weighting.
  */
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {
   
  public:
    /**
       \brief Default constructor.
       
       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
       constructors.
    */
    KNN(void);


    /**
       \brief Constructor using an initialized distance measure.
       
       The number of nearest neighbors (k) is set to 3.
       NeighborWeighting is initialized using its default
       constructor. This constructor should be used if Distance has
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */ 
    KNN(const Distance&);


    /**
       Destructor. The KNN does not own the training data or targets
       passed to train(), so nothing is released here.
    */
    virtual ~KNN();
   
   
    /**
       \brief Get the number of nearest neighbors.
       \return The number of neighbors.
    */
    unsigned int k() const;

    /**
       \brief Set the number of nearest neighbors.
       
       Sets the number of neighbors to \a k_in. Note that a subsequent
       call to train() lowers this value to the number of training
       samples if that number is smaller.
    */
    void k(unsigned int k_in);


    /**
       \brief Create an untrained classifier with the same settings.

       The returned KNN copies this classifier's distance measure,
       neighbor weighting and current number of neighbors, but holds
       no training data. It is allocated with new, and nothing else
       keeps the pointer, so the caller is responsible for deleting it.
    */
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
   
    /**
       \brief Make predictions for unweighted test data.
       
       Predictions are calculated and returned in \a results. The
       matrix is resized to (number of classes) x (number of test
       samples). For each sample in \a data, \a results contains the
       weighted number of nearest neighbors which belong to each
       class. Numbers of nearest neighbors are weighted according to
       NeighborWeighting. If a class has no training samples, NaNs are
       returned for this class in \a results.

       \throw utility::runtime_error if the KNN has not been trained.
    */
    void predict(const MatrixLookup& data , utility::Matrix& results) const;

    /**   
        \brief Make predictions for weighted test data.
       
        Predictions are calculated and returned in \a results. For
        each sample in \a data, \a results contains the weighted
        number of nearest neighbors which belong to each class as in
        predict(const MatrixLookup& data, utility::Matrix& results).
        If a test and training sample pair has no variables with
        non-zero weights in common, there are no variables which can
        be used to calculate the distance between the two samples. In
        this case the distance between the two is set to infinity.

        \throw utility::runtime_error if the KNN has not been trained.
    */
    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;


    /**
       \brief Train the KNN using unweighted training data with known
       targets.
       
       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
       and \a targets as copying these would make the %classifier
       slow. If the number of training samples is smaller than k,
       k is set to the number of training samples. This change is
       permanent: k() returns the reduced value afterwards.

       The sizes of \a data (columns) and \a targets are compared using
       utility::yat_assert. NOTE(review): this check may only be active
       in debug builds; confirm against yat_assert.h.
       
       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
       unless it is trained again.
    */
    void train(const MatrixLookup& data, const Target& targets);
   
    /**   
       \brief Train the KNN using weighted training data with known targets.
   
       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);
   
  private:
   
    // Training data from the most recent train() call. Exactly one of
    // the two pointers is non-null after training; both are non-owning.
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    // Targets from the most recent train() call (non-owning).
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    // Distance measure between two samples (columns).
    Distance distance_;
    // Turns the k nearest neighbors of a test sample into class votes.
    NeighborWeighting weighting_;

    // Fill the distance matrix (training samples as rows, test samples
    // as columns); the caller allocates it with the right dimensions.
    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    // As calculate_unweighted, but NaN distances become infinity.
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    // Vote with the k_ nearest training samples of each test sample and
    // set predictions for classes without training samples to NaN.
    void predict_common(const utility::Matrix& distances, 
                        utility::Matrix& prediction) const;

  };
188 
189 
190  // templates
191 
192  template <typename Distance, typename NeighborWeighting>
193  KNN<Distance, NeighborWeighting>::KNN() 
194    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
195  {
196    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
197  }
198
199  template <typename Distance, typename NeighborWeighting>
200  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
201    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
202  {
203    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
204  }
205
206 
207  template <typename Distance, typename NeighborWeighting>
208  KNN<Distance, NeighborWeighting>::~KNN()   
209  {
210  }
211 
212
213  template <typename Distance, typename NeighborWeighting>
214  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
215  (const MatrixLookup& training, const MatrixLookup& test,
216   utility::Matrix* distances) const
217  {
218    for(size_t i=0; i<training.columns(); i++) {
219      for(size_t j=0; j<test.columns(); j++) {
220        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
221                                      test.begin_column(j));
222        YAT_ASSERT(!std::isnan((*distances)(i,j)));
223      }
224    }
225  }
226
227 
228  template <typename Distance, typename NeighborWeighting>
229  void 
230  KNN<Distance, NeighborWeighting>::calculate_weighted
231  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
232   utility::Matrix* distances) const
233  {
234    for(size_t i=0; i<training.columns(); i++) { 
235      for(size_t j=0; j<test.columns(); j++) {
236        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
237                                      test.begin_column(j));
238        // If the distance is NaN (no common variables with non-zero weights),
239        // the distance is set to infinity to be sorted as a neighbor at the end
240        if(std::isnan((*distances)(i,j))) 
241          (*distances)(i,j)=std::numeric_limits<double>::infinity();
242      }
243    }
244  }
245 
246 
247  template <typename Distance, typename NeighborWeighting>
248  unsigned int KNN<Distance, NeighborWeighting>::k() const
249  {
250    return k_;
251  }
252
253  template <typename Distance, typename NeighborWeighting>
254  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
255  {
256    k_=k;
257  }
258
259
260  template <typename Distance, typename NeighborWeighting>
261  KNN<Distance, NeighborWeighting>* 
262  KNN<Distance, NeighborWeighting>::make_classifier() const 
263  {     
264    // All private members should be copied here to generate an
265    // identical but untrained classifier
266    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
267    knn->weighting_=this->weighting_;
268    knn->k(this->k());
269    return knn;
270  }
271 
272 
273  template <typename Distance, typename NeighborWeighting>
274  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
275                                               const Target& target)
276  {   
277    utility::yat_assert<utility::runtime_error>
278      (data.columns()==target.size(),
279       "KNN::train called with different sizes of target and data");
280    // k has to be at most the number of training samples.
281    if(data.columns()<k_) 
282      k_=data.columns();
283    data_ml_=&data;
284    data_mlw_=0;
285    target_=&target;
286  }
287
288  template <typename Distance, typename NeighborWeighting>
289  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
290                                               const Target& target)
291  {   
292    utility::yat_assert<utility::runtime_error>
293      (data.columns()==target.size(),
294       "KNN::train called with different sizes of target and data");
295    // k has to be at most the number of training samples.
296    if(data.columns()<k_) 
297      k_=data.columns();
298    data_ml_=0;
299    data_mlw_=&data;
300    target_=&target;
301  }
302
303
304  template <typename Distance, typename NeighborWeighting>
305  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
306                                                 utility::Matrix& prediction) const
307  {   
308    // matrix with training samples as rows and test samples as columns
309    utility::Matrix* distances = 0;
310    // unweighted training data
311    if(data_ml_ && !data_mlw_) {
312      utility::yat_assert<utility::runtime_error>
313        (data_ml_->rows()==test.rows(),
314         "KNN::predict different number of rows in training and test data");     
315      distances=new utility::Matrix(data_ml_->columns(),test.columns());
316      calculate_unweighted(*data_ml_,test,distances);
317    }
318    else if (data_mlw_ && !data_ml_) {
319      // weighted training data
320      utility::yat_assert<utility::runtime_error>
321        (data_mlw_->rows()==test.rows(),
322         "KNN::predict different number of rows in training and test data");           
323      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
324      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
325                         distances);             
326    }
327    else {
328      throw utility::runtime_error("KNN::predict no training data");
329    }
330
331    prediction.resize(target_->nof_classes(),test.columns(),0.0);
332    predict_common(*distances,prediction);
333    if(distances)
334      delete distances;
335  }
336
337  template <typename Distance, typename NeighborWeighting>
338  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
339                                                 utility::Matrix& prediction) const
340  {   
341    // matrix with training samples as rows and test samples as columns
342    utility::Matrix* distances=0; 
343    // unweighted training data
344    if(data_ml_ && !data_mlw_) { 
345      utility::yat_assert<utility::runtime_error>
346        (data_ml_->rows()==test.rows(),
347         "KNN::predict different number of rows in training and test data");   
348      distances=new utility::Matrix(data_ml_->columns(),test.columns());
349      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
350    }
351    // weighted training data
352    else if (data_mlw_ && !data_ml_) {
353      utility::yat_assert<utility::runtime_error>
354        (data_mlw_->rows()==test.rows(),
355         "KNN::predict different number of rows in training and test data");   
356      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
357      calculate_weighted(*data_mlw_,test,distances);             
358    }
359    else {
360      throw utility::runtime_error("KNN::predict no training data");
361    }
362
363    prediction.resize(target_->nof_classes(),test.columns(),0.0);
364    predict_common(*distances,prediction);
365   
366    if(distances)
367      delete distances;
368  }
369 
370  template <typename Distance, typename NeighborWeighting>
371  void KNN<Distance, NeighborWeighting>::predict_common
372  (const utility::Matrix& distances, utility::Matrix& prediction) const
373  {   
374    for(size_t sample=0;sample<distances.columns();sample++) {
375      std::vector<size_t> k_index;
376      utility::VectorConstView dist=distances.column_const_view(sample);
377      utility::sort_smallest_index(k_index,k_,dist);
378      utility::VectorView pred=prediction.column_view(sample);
379      weighting_(dist,k_index,*target_,pred);
380    }
381   
382    // classes for which there are no training samples should be set
383    // to nan in the predictions
384    for(size_t c=0;c<target_->nof_classes(); c++) 
385      if(!target_->size(c)) 
386        for(size_t j=0;j<prediction.columns();j++)
387          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
388  }
389
390
391}}} // of namespace classifier, yat, and theplu
392
393#endif
394
Note: See TracBrowser for help on using the repository browser.