source: trunk/yat/classifier/KNN.h @ 1580

Last change on this file since 1580 was 1487, checked in by Jari Häkkinen, 13 years ago

Addresses #436. GPL license copy reference should also be updated.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 12.2 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1487 2008-09-10 08:41:36Z jari $
5
6/*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://dev.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 3 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with yat. If not, see <http://www.gnu.org/licenses/>.
23*/
24
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>
38
39namespace theplu {
40namespace yat {
41namespace classifier {
42
  /**
     \brief Nearest Neighbor Classifier

     A sample is predicted based on the classes of its k nearest
     neighbors among the training data samples. KNN supports using
     different measures, for example, Euclidean distance, to define
     distance between samples. KNN also supports using different ways to
     weight the votes of the k nearest neighbors. For example, using a
     uniform vote a test sample gets a vote for each class which is the
     number of nearest neighbors belonging to the class.

     The template argument Distance should be a class modelling the
     concept \ref concept_distance. The template argument
     NeighborWeighting should be a class modelling the concept \ref
     concept_neighbor_weighting.
  */
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {
   
  public:
    /**
       \brief Default constructor.
       
       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
       constructors.
    */
    KNN(void);


    /**
       \brief Constructor using an initialized distance measure.
       
       The number of nearest neighbors (k) is set to
       3. NeighborWeighting is initialized using its default
       constructor. This constructor should be used if Distance has
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */ 
    KNN(const Distance&);


    /**
       Destructor
    */
    virtual ~KNN();
   
   
    /**
       \brief Get the number of nearest neighbors.
       \return The number of neighbors.
    */
    unsigned int k() const;

    /**
       \brief Set the number of nearest neighbors.
       
       Sets the number of neighbors to \a k_in.
    */
    void k(unsigned int k_in);


    /**
       \brief Create an untrained copy of this classifier.

       Distance, NeighborWeighting, and the number of neighbors (k)
       are copied; training data are not.

       \return A dynamically allocated KNN; the caller is responsible
       for deallocating it.
    */
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
   
    /**
       \brief Make predictions for unweighted test data.
       
       Predictions are calculated and returned in \a results.  For
       each sample in \a data, \a results contains the weighted number
       of nearest neighbors which belong to each class. Numbers of
       nearest neighbors are weighted according to
       NeighborWeighting. If a class has no training samples NaN's are
       returned for this class in \a results.
    */
    void predict(const MatrixLookup& data , utility::Matrix& results) const;

    /**   
        \brief Make predictions for weighted test data.
       
        Predictions are calculated and returned in \a results. For
        each sample in \a data, \a results contains the weighted
        number of nearest neighbors which belong to each class as in
        predict(const MatrixLookup& data, utility::Matrix& results).
        If a test and training sample pair has no variables with
        non-zero weights in common, there are no variables which can
        be used to calculate the distance between the two samples. In
        this case the distance between the two is set to infinity.
    */
    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;


    /**
       \brief Train the KNN using unweighted training data with known
       targets.
       
       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
       and \a targets as copying these would make the %classifier
       slow. If the number of training samples set is smaller than k,
       k is set to the number of training samples.
       
       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
       unless it is trained again.
    */
    void train(const MatrixLookup& data, const Target& targets);
   
    /**   
       \brief Train the KNN using weighted training data with known targets.
   
       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);
   
  private:
   
    // Non-owning pointers to the training data; exactly one of these
    // is non-null after train() has been called (see the two train()
    // overloads).
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    // Non-owning pointer to the training targets set by train().
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    Distance distance_;
    NeighborWeighting weighting_;

    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    void predict_common(const utility::Matrix& distances, 
                        utility::Matrix& prediction) const;

  };
182 
183 
184  // templates
185 
186  template <typename Distance, typename NeighborWeighting>
187  KNN<Distance, NeighborWeighting>::KNN() 
188    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
189  {
190  }
191
192  template <typename Distance, typename NeighborWeighting>
193  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
194    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
195  {
196  }
197
198 
199  template <typename Distance, typename NeighborWeighting>
200  KNN<Distance, NeighborWeighting>::~KNN()   
201  {
202  }
203 
204
205  template <typename Distance, typename NeighborWeighting>
206  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
207  (const MatrixLookup& training, const MatrixLookup& test,
208   utility::Matrix* distances) const
209  {
210    for(size_t i=0; i<training.columns(); i++) {
211      for(size_t j=0; j<test.columns(); j++) {
212        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
213                                      test.begin_column(j));
214        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
215      }
216    }
217  }
218
219 
220  template <typename Distance, typename NeighborWeighting>
221  void 
222  KNN<Distance, NeighborWeighting>::calculate_weighted
223  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
224   utility::Matrix* distances) const
225  {
226    for(size_t i=0; i<training.columns(); i++) { 
227      for(size_t j=0; j<test.columns(); j++) {
228        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
229                                      test.begin_column(j));
230        // If the distance is NaN (no common variables with non-zero weights),
231        // the distance is set to infinity to be sorted as a neighbor at the end
232        if(std::isnan((*distances)(i,j))) 
233          (*distances)(i,j)=std::numeric_limits<double>::infinity();
234      }
235    }
236  }
237 
238 
239  template <typename Distance, typename NeighborWeighting>
240  unsigned int KNN<Distance, NeighborWeighting>::k() const
241  {
242    return k_;
243  }
244
245  template <typename Distance, typename NeighborWeighting>
246  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
247  {
248    k_=k;
249  }
250
251
252  template <typename Distance, typename NeighborWeighting>
253  KNN<Distance, NeighborWeighting>* 
254  KNN<Distance, NeighborWeighting>::make_classifier() const 
255  {     
256    // All private members should be copied here to generate an
257    // identical but untrained classifier
258    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
259    knn->weighting_=this->weighting_;
260    knn->k(this->k());
261    return knn;
262  }
263 
264 
265  template <typename Distance, typename NeighborWeighting>
266  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
267                                               const Target& target)
268  {   
269    utility::yat_assert<std::runtime_error>
270      (data.columns()==target.size(),
271       "KNN::train called with different sizes of target and data");
272    // k has to be at most the number of training samples.
273    if(data.columns()<k_) 
274      k_=data.columns();
275    data_ml_=&data;
276    data_mlw_=0;
277    target_=&target;
278  }
279
280  template <typename Distance, typename NeighborWeighting>
281  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
282                                               const Target& target)
283  {   
284    utility::yat_assert<std::runtime_error>
285      (data.columns()==target.size(),
286       "KNN::train called with different sizes of target and data");
287    // k has to be at most the number of training samples.
288    if(data.columns()<k_) 
289      k_=data.columns();
290    data_ml_=0;
291    data_mlw_=&data;
292    target_=&target;
293  }
294
295
296  template <typename Distance, typename NeighborWeighting>
297  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
298                                                 utility::Matrix& prediction) const
299  {   
300    // matrix with training samples as rows and test samples as columns
301    utility::Matrix* distances = 0;
302    // unweighted training data
303    if(data_ml_ && !data_mlw_) {
304      utility::yat_assert<std::runtime_error>
305        (data_ml_->rows()==test.rows(),
306         "KNN::predict different number of rows in training and test data");     
307      distances=new utility::Matrix(data_ml_->columns(),test.columns());
308      calculate_unweighted(*data_ml_,test,distances);
309    }
310    else if (data_mlw_ && !data_ml_) {
311      // weighted training data
312      utility::yat_assert<std::runtime_error>
313        (data_mlw_->rows()==test.rows(),
314         "KNN::predict different number of rows in training and test data");           
315      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
316      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
317                         distances);             
318    }
319    else {
320      std::runtime_error("KNN::predict no training data");
321    }
322
323    prediction.resize(target_->nof_classes(),test.columns(),0.0);
324    predict_common(*distances,prediction);
325    if(distances)
326      delete distances;
327  }
328
329  template <typename Distance, typename NeighborWeighting>
330  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
331                                                 utility::Matrix& prediction) const
332  {   
333    // matrix with training samples as rows and test samples as columns
334    utility::Matrix* distances=0; 
335    // unweighted training data
336    if(data_ml_ && !data_mlw_) { 
337      utility::yat_assert<std::runtime_error>
338        (data_ml_->rows()==test.rows(),
339         "KNN::predict different number of rows in training and test data");   
340      distances=new utility::Matrix(data_ml_->columns(),test.columns());
341      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
342    }
343    // weighted training data
344    else if (data_mlw_ && !data_ml_) {
345      utility::yat_assert<std::runtime_error>
346        (data_mlw_->rows()==test.rows(),
347         "KNN::predict different number of rows in training and test data");   
348      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
349      calculate_weighted(*data_mlw_,test,distances);             
350    }
351    else {
352      std::runtime_error("KNN::predict no training data");
353    }
354
355    prediction.resize(target_->nof_classes(),test.columns(),0.0);
356    predict_common(*distances,prediction);
357   
358    if(distances)
359      delete distances;
360  }
361 
362  template <typename Distance, typename NeighborWeighting>
363  void KNN<Distance, NeighborWeighting>::predict_common
364  (const utility::Matrix& distances, utility::Matrix& prediction) const
365  {   
366    for(size_t sample=0;sample<distances.columns();sample++) {
367      std::vector<size_t> k_index;
368      utility::VectorConstView dist=distances.column_const_view(sample);
369      utility::sort_smallest_index(k_index,k_,dist);
370      utility::VectorView pred=prediction.column_view(sample);
371      weighting_(dist,k_index,*target_,pred);
372    }
373   
374    // classes for which there are no training samples should be set
375    // to nan in the predictions
376    for(size_t c=0;c<target_->nof_classes(); c++) 
377      if(!target_->size(c)) 
378        for(size_t j=0;j<prediction.columns();j++)
379          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
380  }
381
382
383}}} // of namespace classifier, yat, and theplu
384
385#endif
386
Note: See TracBrowser for help on using the repository browser.