source: trunk/yat/classifier/KNN.h @ 2210

Last change on this file since 2210 was 2210, checked in by Peter, 12 years ago

fixes #281. Change all throws of std::runtime_error to theplu::yat::utility::runtime_error to clarify that the error comes from yat. Also removed some throw declarations.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 12.3 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 2210 2010-03-05 22:59:01Z peter $
5
6/*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8  Copyright (C) 2009 Peter Johansson
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24*/
25
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Exception.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>
41
42namespace theplu {
43namespace yat {
44namespace classifier {
45
  /**
     \brief Nearest Neighbor Classifier

     A sample is predicted based on the classes of its k nearest
     neighbors among the training data samples. KNN supports using
     different measures, for example, Euclidean distance, to define
     distance between samples. KNN also supports using different ways to
     weight the votes of the k nearest neighbors. For example, using a
     uniform vote a test sample gets a vote for each class which is the
     number of nearest neighbors belonging to the class.

     The template argument Distance should be a class modelling the
     concept \ref concept_distance. The template argument
     NeighborWeighting should be a class modelling the concept \ref
     concept_neighbor_weighting.
  */
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {
   
  public:
    /**
       \brief Default constructor.
       
       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
       constructors.
    */
    KNN(void);


    /**
       \brief Constructor using an initialized distance measure.
       
       The number of nearest neighbors (k) is set to
       3. NeighborWeighting is initialized using its default
       constructor. This constructor should be used if Distance has
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */ 
    KNN(const Distance&);


    /**
       Destructor
    */
    virtual ~KNN();
   
   
    /**
       \brief Get the number of nearest neighbors.
       \return The number of neighbors.
    */
    unsigned int k() const;

    /**
       \brief Set the number of nearest neighbors.
       
       Sets the number of neighbors to \a k_in.
    */
    void k(unsigned int k_in);


    /**
       \brief Create an untrained copy of this classifier.

       All configuration (distance measure, neighbor weighting, and k)
       is copied; training data is not. The caller is responsible for
       deleting the returned object.
    */
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
   
    /**
       \brief Make predictions for unweighted test data.
       
       Predictions are calculated and returned in \a results.  For
       each sample in \a data, \a results contains the weighted number
       of nearest neighbors which belong to each class. Numbers of
       nearest neighbors are weighted according to
       NeighborWeighting. If a class has no training samples NaN's are
       returned for this class in \a results.
    */
    void predict(const MatrixLookup& data , utility::Matrix& results) const;

    /**   
        \brief Make predictions for weighted test data.
       
        Predictions are calculated and returned in \a results. For
        each sample in \a data, \a results contains the weighted
        number of nearest neighbors which belong to each class as in
        predict(const MatrixLookup& data, utility::Matrix& results).
        If a test and training sample pair has no variables with
        non-zero weights in common, there are no variables which can
        be used to calculate the distance between the two samples. In
        this case the distance between the two is set to infinity.
    */
    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;


    /**
       \brief Train the KNN using unweighted training data with known
       targets.
       
       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
       and \a targets as copying these would make the %classifier
       slow. If the number of training samples set is smaller than k,
       k is set to the number of training samples.
       
       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
       unless it is trained again.
    */
    void train(const MatrixLookup& data, const Target& targets);
   
    /**   
       \brief Train the KNN using weighted training data with known targets.
   
       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);
   
  private:
   
    // Non-owning pointers to training data; exactly one of data_ml_
    // and data_mlw_ is non-null after train() has been called.
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    Distance distance_;
    NeighborWeighting weighting_;

    // Fill *distances (training samples x test samples) with pairwise
    // distances between columns of the two lookups.
    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    // Shared voting step used by both predict() overloads.
    void predict_common(const utility::Matrix& distances, 
                        utility::Matrix& prediction) const;

  };
185 
186 
187  // templates
188 
189  template <typename Distance, typename NeighborWeighting>
190  KNN<Distance, NeighborWeighting>::KNN() 
191    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
192  {
193  }
194
195  template <typename Distance, typename NeighborWeighting>
196  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
197    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
198  {
199  }
200
201 
202  template <typename Distance, typename NeighborWeighting>
203  KNN<Distance, NeighborWeighting>::~KNN()   
204  {
205  }
206 
207
208  template <typename Distance, typename NeighborWeighting>
209  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
210  (const MatrixLookup& training, const MatrixLookup& test,
211   utility::Matrix* distances) const
212  {
213    for(size_t i=0; i<training.columns(); i++) {
214      for(size_t j=0; j<test.columns(); j++) {
215        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
216                                      test.begin_column(j));
217        YAT_ASSERT(!std::isnan((*distances)(i,j)));
218      }
219    }
220  }
221
222 
223  template <typename Distance, typename NeighborWeighting>
224  void 
225  KNN<Distance, NeighborWeighting>::calculate_weighted
226  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
227   utility::Matrix* distances) const
228  {
229    for(size_t i=0; i<training.columns(); i++) { 
230      for(size_t j=0; j<test.columns(); j++) {
231        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
232                                      test.begin_column(j));
233        // If the distance is NaN (no common variables with non-zero weights),
234        // the distance is set to infinity to be sorted as a neighbor at the end
235        if(std::isnan((*distances)(i,j))) 
236          (*distances)(i,j)=std::numeric_limits<double>::infinity();
237      }
238    }
239  }
240 
241 
242  template <typename Distance, typename NeighborWeighting>
243  unsigned int KNN<Distance, NeighborWeighting>::k() const
244  {
245    return k_;
246  }
247
248  template <typename Distance, typename NeighborWeighting>
249  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
250  {
251    k_=k;
252  }
253
254
255  template <typename Distance, typename NeighborWeighting>
256  KNN<Distance, NeighborWeighting>* 
257  KNN<Distance, NeighborWeighting>::make_classifier() const 
258  {     
259    // All private members should be copied here to generate an
260    // identical but untrained classifier
261    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
262    knn->weighting_=this->weighting_;
263    knn->k(this->k());
264    return knn;
265  }
266 
267 
268  template <typename Distance, typename NeighborWeighting>
269  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
270                                               const Target& target)
271  {   
272    utility::yat_assert<utility::runtime_error>
273      (data.columns()==target.size(),
274       "KNN::train called with different sizes of target and data");
275    // k has to be at most the number of training samples.
276    if(data.columns()<k_) 
277      k_=data.columns();
278    data_ml_=&data;
279    data_mlw_=0;
280    target_=&target;
281  }
282
283  template <typename Distance, typename NeighborWeighting>
284  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
285                                               const Target& target)
286  {   
287    utility::yat_assert<utility::runtime_error>
288      (data.columns()==target.size(),
289       "KNN::train called with different sizes of target and data");
290    // k has to be at most the number of training samples.
291    if(data.columns()<k_) 
292      k_=data.columns();
293    data_ml_=0;
294    data_mlw_=&data;
295    target_=&target;
296  }
297
298
299  template <typename Distance, typename NeighborWeighting>
300  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
301                                                 utility::Matrix& prediction) const
302  {   
303    // matrix with training samples as rows and test samples as columns
304    utility::Matrix* distances = 0;
305    // unweighted training data
306    if(data_ml_ && !data_mlw_) {
307      utility::yat_assert<utility::runtime_error>
308        (data_ml_->rows()==test.rows(),
309         "KNN::predict different number of rows in training and test data");     
310      distances=new utility::Matrix(data_ml_->columns(),test.columns());
311      calculate_unweighted(*data_ml_,test,distances);
312    }
313    else if (data_mlw_ && !data_ml_) {
314      // weighted training data
315      utility::yat_assert<utility::runtime_error>
316        (data_mlw_->rows()==test.rows(),
317         "KNN::predict different number of rows in training and test data");           
318      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
319      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
320                         distances);             
321    }
322    else {
323      throw utility::runtime_error("KNN::predict no training data");
324    }
325
326    prediction.resize(target_->nof_classes(),test.columns(),0.0);
327    predict_common(*distances,prediction);
328    if(distances)
329      delete distances;
330  }
331
332  template <typename Distance, typename NeighborWeighting>
333  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
334                                                 utility::Matrix& prediction) const
335  {   
336    // matrix with training samples as rows and test samples as columns
337    utility::Matrix* distances=0; 
338    // unweighted training data
339    if(data_ml_ && !data_mlw_) { 
340      utility::yat_assert<utility::runtime_error>
341        (data_ml_->rows()==test.rows(),
342         "KNN::predict different number of rows in training and test data");   
343      distances=new utility::Matrix(data_ml_->columns(),test.columns());
344      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
345    }
346    // weighted training data
347    else if (data_mlw_ && !data_ml_) {
348      utility::yat_assert<utility::runtime_error>
349        (data_mlw_->rows()==test.rows(),
350         "KNN::predict different number of rows in training and test data");   
351      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
352      calculate_weighted(*data_mlw_,test,distances);             
353    }
354    else {
355      throw utility::runtime_error("KNN::predict no training data");
356    }
357
358    prediction.resize(target_->nof_classes(),test.columns(),0.0);
359    predict_common(*distances,prediction);
360   
361    if(distances)
362      delete distances;
363  }
364 
365  template <typename Distance, typename NeighborWeighting>
366  void KNN<Distance, NeighborWeighting>::predict_common
367  (const utility::Matrix& distances, utility::Matrix& prediction) const
368  {   
369    for(size_t sample=0;sample<distances.columns();sample++) {
370      std::vector<size_t> k_index;
371      utility::VectorConstView dist=distances.column_const_view(sample);
372      utility::sort_smallest_index(k_index,k_,dist);
373      utility::VectorView pred=prediction.column_view(sample);
374      weighting_(dist,k_index,*target_,pred);
375    }
376   
377    // classes for which there are no training samples should be set
378    // to nan in the predictions
379    for(size_t c=0;c<target_->nof_classes(); c++) 
380      if(!target_->size(c)) 
381        for(size_t j=0;j<prediction.columns();j++)
382          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
383  }
384
385
386}}} // of namespace classifier, yat, and theplu
387
388#endif
389
Note: See TracBrowser for help on using the repository browser.