source: trunk/yat/classifier/KNN.h @ 2336

Last change on this file since 2336 was 2336, checked in by Peter, 11 years ago

prefer lines shorter than 80 characters (again)

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 12.6 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 2336 2010-10-15 12:26:44Z peter $
5
6/*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8  Copyright (C) 2009 Peter Johansson
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24*/
25
26#include "DataLookup1D.h"
27#include "DataLookupWeighted1D.h"
28#include "KNN_Uniform.h"
29#include "MatrixLookup.h"
30#include "MatrixLookupWeighted.h"
31#include "SupervisedClassifier.h"
32#include "Target.h"
33#include "yat/utility/concept_check.h"
34#include "yat/utility/Exception.h"
35#include "yat/utility/Matrix.h"
36#include "yat/utility/yat_assert.h"
37
38#include <boost/concept_check.hpp>
39
40#include <cmath>
41#include <limits>
42#include <map>
43#include <stdexcept>
44
45namespace theplu {
46namespace yat {
47namespace classifier {
48
49  /**
50     \brief Nearest Neighbor Classifier
51     
52     A sample is predicted based on the classes of its k nearest
53     neighbors among the training data samples. KNN supports using
54     different measures, for example, Euclidean distance, to define
55     distance between samples. KNN also supports using different ways to
56     weight the votes of the k nearest neighbors. For example, using a
57     uniform vote a test sample gets a vote for each class which is the
58     number of nearest neighbors belonging to the class.
59     
60     The template argument Distance should be a class modelling the
61     concept \ref concept_distance. The template argument
62     NeighborWeighting should be a class modelling the concept \ref
63     concept_neighbor_weighting.
64  */
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {
   
  public:
    /**
       \brief Default constructor.
       
       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
       constructors.
    */
    KNN(void);


    /**
       \brief Constructor using an initialized distance measure.
       
       The number of nearest neighbors (k) is set to
       3. NeighborWeighting is initialized using its default
       constructor. This constructor should be used if Distance has
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */ 
    KNN(const Distance&);


    /**
       Destructor
    */
    virtual ~KNN();
   
   
    /**
       \brief Get the number of nearest neighbors.
       \return The number of neighbors.
    */
    unsigned int k() const;

    /**
       \brief Set the number of nearest neighbors.
       
       Sets the number of neighbors to \a k_in.
    */
    void k(unsigned int k_in);


    /**
       \brief Create an untrained copy of this classifier.

       The returned object has the same k, Distance, and
       NeighborWeighting as this one but no training data. Caller
       takes ownership of the returned pointer.
    */
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
   
    /**
       \brief Make predictions for unweighted test data.
       
       Predictions are calculated and returned in \a results.  For
       each sample in \a data, \a results contains the weighted number
       of nearest neighbors which belong to each class. Numbers of
       nearest neighbors are weighted according to
       NeighborWeighting. If a class has no training samples NaN's are
       returned for this class in \a results.
    */
    void predict(const MatrixLookup& data , utility::Matrix& results) const;

    /**   
        \brief Make predictions for weighted test data.
       
        Predictions are calculated and returned in \a results. For
        each sample in \a data, \a results contains the weighted
        number of nearest neighbors which belong to each class as in
        predict(const MatrixLookup& data, utility::Matrix& results).
        If a test and training sample pair has no variables with
        non-zero weights in common, there are no variables which can
        be used to calculate the distance between the two samples. In
        this case the distance between the two is set to infinity.
    */
    void predict(const MatrixLookupWeighted& data, 
                 utility::Matrix& results) const;


    /**
       \brief Train the KNN using unweighted training data with known
       targets.
       
       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
       and \a targets as copying these would make the %classifier
       slow. If the number of training samples set is smaller than k,
       k is set to the number of training samples.
       
       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
       unless it is trained again.
    */
    void train(const MatrixLookup& data, const Target& targets);
   
    /**   
       \brief Train the KNN using weighted training data with known targets.
   
       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);
   
  private:
   
    // Non-owning references to training data; exactly one of data_ml_
    // and data_mlw_ is non-NULL after train() has been called.
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    // Non-owning reference to the training targets.
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    // Functor measuring distance between two samples.
    Distance distance_;
    // Functor translating neighbor distances into class votes.
    NeighborWeighting weighting_;

    // Fill *distances with pairwise distances; rows are training
    // samples, columns are test samples.
    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    // Shared voting step used by both predict() overloads.
    void predict_common(const utility::Matrix& distances, 
                        utility::Matrix& prediction) const;

  };
189 
190 
191  // templates
192 
193  template <typename Distance, typename NeighborWeighting>
194  KNN<Distance, NeighborWeighting>::KNN() 
195    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
196  {
197    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
198  }
199
200  template <typename Distance, typename NeighborWeighting>
201  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
202    : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3), 
203      distance_(dist)
204  {
205    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
206  }
207
208 
209  template <typename Distance, typename NeighborWeighting>
210  KNN<Distance, NeighborWeighting>::~KNN()   
211  {
212  }
213 
214
215  template <typename Distance, typename NeighborWeighting>
216  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
217  (const MatrixLookup& training, const MatrixLookup& test,
218   utility::Matrix* distances) const
219  {
220    for(size_t i=0; i<training.columns(); i++) {
221      for(size_t j=0; j<test.columns(); j++) {
222        (*distances)(i,j) = distance_(training.begin_column(i), 
223                                      training.end_column(i), 
224                                      test.begin_column(j));
225        YAT_ASSERT(!std::isnan((*distances)(i,j)));
226      }
227    }
228  }
229
230 
231  template <typename Distance, typename NeighborWeighting>
232  void 
233  KNN<Distance, NeighborWeighting>::calculate_weighted
234  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
235   utility::Matrix* distances) const
236  {
237    for(size_t i=0; i<training.columns(); i++) { 
238      for(size_t j=0; j<test.columns(); j++) {
239        (*distances)(i,j) = distance_(training.begin_column(i), 
240                                      training.end_column(i), 
241                                      test.begin_column(j));
242        // If the distance is NaN (no common variables with non-zero weights),
243        // the distance is set to infinity to be sorted as a neighbor at the end
244        if(std::isnan((*distances)(i,j))) 
245          (*distances)(i,j)=std::numeric_limits<double>::infinity();
246      }
247    }
248  }
249 
250 
251  template <typename Distance, typename NeighborWeighting>
252  unsigned int KNN<Distance, NeighborWeighting>::k() const
253  {
254    return k_;
255  }
256
257  template <typename Distance, typename NeighborWeighting>
258  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
259  {
260    k_=k;
261  }
262
263
264  template <typename Distance, typename NeighborWeighting>
265  KNN<Distance, NeighborWeighting>* 
266  KNN<Distance, NeighborWeighting>::make_classifier() const 
267  {     
268    // All private members should be copied here to generate an
269    // identical but untrained classifier
270    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
271    knn->weighting_=this->weighting_;
272    knn->k(this->k());
273    return knn;
274  }
275 
276 
277  template <typename Distance, typename NeighborWeighting>
278  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
279                                               const Target& target)
280  {   
281    utility::yat_assert<utility::runtime_error>
282      (data.columns()==target.size(),
283       "KNN::train called with different sizes of target and data");
284    // k has to be at most the number of training samples.
285    if(data.columns()<k_) 
286      k_=data.columns();
287    data_ml_=&data;
288    data_mlw_=0;
289    target_=&target;
290  }
291
292  template <typename Distance, typename NeighborWeighting>
293  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
294                                               const Target& target)
295  {   
296    utility::yat_assert<utility::runtime_error>
297      (data.columns()==target.size(),
298       "KNN::train called with different sizes of target and data");
299    // k has to be at most the number of training samples.
300    if(data.columns()<k_) 
301      k_=data.columns();
302    data_ml_=0;
303    data_mlw_=&data;
304    target_=&target;
305  }
306
307
308  template <typename Distance, typename NeighborWeighting>
309  void 
310  KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
311                                            utility::Matrix& prediction) const
312  {   
313    // matrix with training samples as rows and test samples as columns
314    utility::Matrix* distances = 0;
315    // unweighted training data
316    if(data_ml_ && !data_mlw_) {
317      utility::yat_assert<utility::runtime_error>
318        (data_ml_->rows()==test.rows(),
319         "KNN::predict different number of rows in training and test data");
320      distances=new utility::Matrix(data_ml_->columns(),test.columns());
321      calculate_unweighted(*data_ml_,test,distances);
322    }
323    else if (data_mlw_ && !data_ml_) {
324      // weighted training data
325      utility::yat_assert<utility::runtime_error>
326        (data_mlw_->rows()==test.rows(),
327         "KNN::predict different number of rows in training and test data");
328      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
329      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
330                         distances);             
331    }
332    else {
333      throw utility::runtime_error("KNN::predict no training data");
334    }
335
336    prediction.resize(target_->nof_classes(),test.columns(),0.0);
337    predict_common(*distances,prediction);
338    if(distances)
339      delete distances;
340  }
341
342  template <typename Distance, typename NeighborWeighting>
343  void 
344  KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
345                                            utility::Matrix& prediction) const
346  {   
347    // matrix with training samples as rows and test samples as columns
348    utility::Matrix* distances=0; 
349    // unweighted training data
350    if(data_ml_ && !data_mlw_) { 
351      utility::yat_assert<utility::runtime_error>
352        (data_ml_->rows()==test.rows(),
353         "KNN::predict different number of rows in training and test data");   
354      distances=new utility::Matrix(data_ml_->columns(),test.columns());
355      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
356    }
357    // weighted training data
358    else if (data_mlw_ && !data_ml_) {
359      utility::yat_assert<utility::runtime_error>
360        (data_mlw_->rows()==test.rows(),
361         "KNN::predict different number of rows in training and test data");   
362      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
363      calculate_weighted(*data_mlw_,test,distances);             
364    }
365    else {
366      throw utility::runtime_error("KNN::predict no training data");
367    }
368
369    prediction.resize(target_->nof_classes(),test.columns(),0.0);
370    predict_common(*distances,prediction);
371   
372    if(distances)
373      delete distances;
374  }
375 
376  template <typename Distance, typename NeighborWeighting>
377  void KNN<Distance, NeighborWeighting>::predict_common
378  (const utility::Matrix& distances, utility::Matrix& prediction) const
379  {   
380    for(size_t sample=0;sample<distances.columns();sample++) {
381      std::vector<size_t> k_index;
382      utility::VectorConstView dist=distances.column_const_view(sample);
383      utility::sort_smallest_index(k_index,k_,dist);
384      utility::VectorView pred=prediction.column_view(sample);
385      weighting_(dist,k_index,*target_,pred);
386    }
387   
388    // classes for which there are no training samples should be set
389    // to nan in the predictions
390    for(size_t c=0;c<target_->nof_classes(); c++) 
391      if(!target_->size(c)) 
392        for(size_t j=0;j<prediction.columns();j++)
393          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
394  }
395}}} // of namespace classifier, yat, and theplu
396
397#endif
Note: See TracBrowser for help on using the repository browser.