source: trunk/yat/classifier/KNN.h @ 2335

Last change on this file since 2335 was 2335, checked in by Peter, 11 years ago

prefer lines shorter than 80 characters

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 12.5 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 2335 2010-10-15 12:22:13Z peter $
5
6/*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8  Copyright (C) 2009 Peter Johansson
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24*/
25
26#include "DataLookup1D.h"
27#include "DataLookupWeighted1D.h"
28#include "KNN_Uniform.h"
29#include "MatrixLookup.h"
30#include "MatrixLookupWeighted.h"
31#include "SupervisedClassifier.h"
32#include "Target.h"
33#include "yat/utility/concept_check.h"
34#include "yat/utility/Exception.h"
35#include "yat/utility/Matrix.h"
36#include "yat/utility/yat_assert.h"
37
38#include <boost/concept_check.hpp>
39
40#include <cmath>
41#include <limits>
42#include <map>
43#include <stdexcept>
44
45namespace theplu {
46namespace yat {
47namespace classifier {
48
  /**
     \brief Nearest Neighbor Classifier

     A sample is predicted based on the classes of its k nearest
     neighbors among the training data samples. KNN supports using
     different measures, for example, Euclidean distance, to define
     distance between samples. KNN also supports using different ways to
     weight the votes of the k nearest neighbors. For example, using a
     uniform vote a test sample gets a vote for each class which is the
     number of nearest neighbors belonging to the class.

     The template argument Distance should be a class modelling the
     concept \ref concept_distance. The template argument
     NeighborWeighting should be a class modelling the concept \ref
     concept_neighbor_weighting.
  */
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {
   
  public:
    /**
       \brief Default constructor.
       
       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
       constructors.
    */
    KNN(void);


    /**
       \brief Constructor using an initialized distance measure.
       
       The number of nearest neighbors (k) is set to
       3. NeighborWeighting is initialized using its default
       constructor. This constructor should be used if Distance has
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */ 
    KNN(const Distance&);


    /**
       Destructor
    */
    virtual ~KNN();
   
   
    /**
       \brief Get the number of nearest neighbors.
       \return The number of neighbors.
    */
    unsigned int k() const;

    /**
       \brief Set the number of nearest neighbors.
       
       Sets the number of neighbors to \a k_in.
    */
    void k(unsigned int k_in);


    /**
       \brief Create an untrained copy of this classifier.

       The returned KNN has the same Distance, NeighborWeighting, and
       k as this one, but no training data. The caller owns the
       returned object and is responsible for deleting it.
    */
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
   
    /**
       \brief Make predictions for unweighted test data.
       
       Predictions are calculated and returned in \a results.  For
       each sample in \a data, \a results contains the weighted number
       of nearest neighbors which belong to each class. Numbers of
       nearest neighbors are weighted according to
       NeighborWeighting. If a class has no training samples NaN's are
       returned for this class in \a results.
    */
    void predict(const MatrixLookup& data , utility::Matrix& results) const;

    /**   
        \brief Make predictions for weighted test data.
       
        Predictions are calculated and returned in \a results. For
        each sample in \a data, \a results contains the weighted
        number of nearest neighbors which belong to each class as in
        predict(const MatrixLookup& data, utility::Matrix& results).
        If a test and training sample pair has no variables with
        non-zero weights in common, there are no variables which can
        be used to calculate the distance between the two samples. In
        this case the distance between the two is set to infinity.
    */
    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;


    /**
       \brief Train the KNN using unweighted training data with known
       targets.
       
       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
       and \a targets as copying these would make the %classifier
       slow. If the number of training samples set is smaller than k,
       k is set to the number of training samples.
       
       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
       unless it is trained again.
    */
    void train(const MatrixLookup& data, const Target& targets);
   
    /**   
       \brief Train the KNN using weighted training data with known targets.
   
       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);
   
  private:
   
    // Non-owning pointers to the training data; exactly one of
    // data_ml_ and data_mlw_ is non-null after train(), both are null
    // before training.
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    // Non-owning pointer to the training targets set by train().
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    Distance distance_;
    NeighborWeighting weighting_;

    // Fill distances(i,j) with the distance between training column i
    // and test column j (unweighted data).
    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    // Same as calculate_unweighted but for weighted data; a distance
    // that is NaN (no common non-zero weights) is set to infinity.
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    // Shared prediction step: pick the k_ nearest training samples per
    // test sample and let NeighborWeighting convert them into votes.
    void predict_common(const utility::Matrix& distances, 
                        utility::Matrix& prediction) const;

  };
188 
189 
190  // templates
191 
192  template <typename Distance, typename NeighborWeighting>
193  KNN<Distance, NeighborWeighting>::KNN() 
194    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
195  {
196    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
197  }
198
199  template <typename Distance, typename NeighborWeighting>
200  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
201    : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3), 
202      distance_(dist)
203  {
204    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
205  }
206
207 
  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::~KNN()   
  {
    // Nothing to release: KNN only observes the training data and
    // targets (raw non-owning pointers), it does not own them.
  }
212 
213
214  template <typename Distance, typename NeighborWeighting>
215  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
216  (const MatrixLookup& training, const MatrixLookup& test,
217   utility::Matrix* distances) const
218  {
219    for(size_t i=0; i<training.columns(); i++) {
220      for(size_t j=0; j<test.columns(); j++) {
221        (*distances)(i,j) = distance_(training.begin_column(i), 
222                                      training.end_column(i), 
223                                      test.begin_column(j));
224        YAT_ASSERT(!std::isnan((*distances)(i,j)));
225      }
226    }
227  }
228
229 
230  template <typename Distance, typename NeighborWeighting>
231  void 
232  KNN<Distance, NeighborWeighting>::calculate_weighted
233  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
234   utility::Matrix* distances) const
235  {
236    for(size_t i=0; i<training.columns(); i++) { 
237      for(size_t j=0; j<test.columns(); j++) {
238        (*distances)(i,j) = distance_(training.begin_column(i), 
239                                      training.end_column(i), 
240                                      test.begin_column(j));
241        // If the distance is NaN (no common variables with non-zero weights),
242        // the distance is set to infinity to be sorted as a neighbor at the end
243        if(std::isnan((*distances)(i,j))) 
244          (*distances)(i,j)=std::numeric_limits<double>::infinity();
245      }
246    }
247  }
248 
249 
  // Accessor: the number of nearest neighbors used in prediction.
  template <typename Distance, typename NeighborWeighting>
  unsigned int KNN<Distance, NeighborWeighting>::k() const
  {
    return k_;
  }
255
256  template <typename Distance, typename NeighborWeighting>
257  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
258  {
259    k_=k;
260  }
261
262
263  template <typename Distance, typename NeighborWeighting>
264  KNN<Distance, NeighborWeighting>* 
265  KNN<Distance, NeighborWeighting>::make_classifier() const 
266  {     
267    // All private members should be copied here to generate an
268    // identical but untrained classifier
269    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
270    knn->weighting_=this->weighting_;
271    knn->k(this->k());
272    return knn;
273  }
274 
275 
276  template <typename Distance, typename NeighborWeighting>
277  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
278                                               const Target& target)
279  {   
280    utility::yat_assert<utility::runtime_error>
281      (data.columns()==target.size(),
282       "KNN::train called with different sizes of target and data");
283    // k has to be at most the number of training samples.
284    if(data.columns()<k_) 
285      k_=data.columns();
286    data_ml_=&data;
287    data_mlw_=0;
288    target_=&target;
289  }
290
291  template <typename Distance, typename NeighborWeighting>
292  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
293                                               const Target& target)
294  {   
295    utility::yat_assert<utility::runtime_error>
296      (data.columns()==target.size(),
297       "KNN::train called with different sizes of target and data");
298    // k has to be at most the number of training samples.
299    if(data.columns()<k_) 
300      k_=data.columns();
301    data_ml_=0;
302    data_mlw_=&data;
303    target_=&target;
304  }
305
306
307  template <typename Distance, typename NeighborWeighting>
308  void 
309  KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
310                                            utility::Matrix& prediction) const
311  {   
312    // matrix with training samples as rows and test samples as columns
313    utility::Matrix* distances = 0;
314    // unweighted training data
315    if(data_ml_ && !data_mlw_) {
316      utility::yat_assert<utility::runtime_error>
317        (data_ml_->rows()==test.rows(),
318         "KNN::predict different number of rows in training and test data");     
319      distances=new utility::Matrix(data_ml_->columns(),test.columns());
320      calculate_unweighted(*data_ml_,test,distances);
321    }
322    else if (data_mlw_ && !data_ml_) {
323      // weighted training data
324      utility::yat_assert<utility::runtime_error>
325        (data_mlw_->rows()==test.rows(),
326         "KNN::predict different number of rows in training and test data");           
327      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
328      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
329                         distances);             
330    }
331    else {
332      throw utility::runtime_error("KNN::predict no training data");
333    }
334
335    prediction.resize(target_->nof_classes(),test.columns(),0.0);
336    predict_common(*distances,prediction);
337    if(distances)
338      delete distances;
339  }
340
341  template <typename Distance, typename NeighborWeighting>
342  void 
343  KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
344                                            utility::Matrix& prediction) const
345  {   
346    // matrix with training samples as rows and test samples as columns
347    utility::Matrix* distances=0; 
348    // unweighted training data
349    if(data_ml_ && !data_mlw_) { 
350      utility::yat_assert<utility::runtime_error>
351        (data_ml_->rows()==test.rows(),
352         "KNN::predict different number of rows in training and test data");   
353      distances=new utility::Matrix(data_ml_->columns(),test.columns());
354      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
355    }
356    // weighted training data
357    else if (data_mlw_ && !data_ml_) {
358      utility::yat_assert<utility::runtime_error>
359        (data_mlw_->rows()==test.rows(),
360         "KNN::predict different number of rows in training and test data");   
361      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
362      calculate_weighted(*data_mlw_,test,distances);             
363    }
364    else {
365      throw utility::runtime_error("KNN::predict no training data");
366    }
367
368    prediction.resize(target_->nof_classes(),test.columns(),0.0);
369    predict_common(*distances,prediction);
370   
371    if(distances)
372      delete distances;
373  }
374 
  // Shared final step of both predict() overloads.
  //
  // For each test sample (a column of `distances`), find the indices
  // of the k_ training samples with the smallest distances and let the
  // NeighborWeighting functor turn those neighbors' distances and
  // classes into per-class votes written to the corresponding column
  // of `prediction`. Finally, classes with no training samples get NaN
  // in every prediction column, as documented in the public predict()
  // methods.
  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::predict_common
  (const utility::Matrix& distances, utility::Matrix& prediction) const
  {   
    for(size_t sample=0;sample<distances.columns();sample++) {
      std::vector<size_t> k_index;
      utility::VectorConstView dist=distances.column_const_view(sample);
      utility::sort_smallest_index(k_index,k_,dist);
      utility::VectorView pred=prediction.column_view(sample);
      weighting_(dist,k_index,*target_,pred);
    }
   
    // classes for which there are no training samples should be set
    // to nan in the predictions
    for(size_t c=0;c<target_->nof_classes(); c++) 
      if(!target_->size(c)) 
        for(size_t j=0;j<prediction.columns();j++)
          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
  }
394}}} // of namespace classifier, yat, and theplu
395
396#endif
Note: See TracBrowser for help on using the repository browser.