source: branches/0.4-stable/yat/classifier/KNN.h @ 1392

Last change on this file since 1392 was 1392, checked in by Peter, 15 years ago

trac has moved

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 12.3 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1392 2008-07-28 19:35:30Z peter $
5
6/*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://dev.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>
40
41namespace theplu {
42namespace yat {
43namespace classifier {
44
  /**
     \brief Nearest Neighbor Classifier
     
     A sample is predicted based on the classes of its k nearest
     neighbors among the training data samples. KNN supports using
     different measures, for example, Euclidean distance, to define
     distance between samples. KNN also supports using different ways to
     weight the votes of the k nearest neighbors. For example, using a
     uniform vote a test sample gets a vote for each class which is the
     number of nearest neighbors belonging to the class.
     
     The template argument Distance should be a class modelling the
     concept \ref concept_distance. The template argument
     NeighborWeighting should be a class modelling the concept \ref
     concept_neighbor_weighting.
  */
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {
   
  public:
    /**
       \brief Default constructor.
       
       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
       constructors.
    */
    KNN(void);


    /**
       \brief Constructor using an initialized distance measure.
       
       The number of nearest neighbors (k) is set to
       3. NeighborWeighting is initialized using its default
       constructor. This constructor should be used if Distance has
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */ 
    KNN(const Distance&);


    /**
       Destructor
    */
    virtual ~KNN();
   
   
    /**
       \brief Get the number of nearest neighbors.
       \return The number of neighbors.
    */
    unsigned int k() const;

    /**
       \brief Set the number of nearest neighbors.
       
       Sets the number of neighbors to \a k_in.
    */
    void k(unsigned int k_in);


    /**
       \brief Create an untrained copy of this classifier.

       The returned KNN has the same Distance, NeighborWeighting, and
       k as this one, but no training data. The caller owns the
       returned object and is responsible for deleting it.
    */
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
   
    /**
       \brief Make predictions for unweighted test data.
       
       Predictions are calculated and returned in \a results.  For
       each sample in \a data, \a results contains the weighted number
       of nearest neighbors which belong to each class. Numbers of
       nearest neighbors are weighted according to
       NeighborWeighting. If a class has no training samples NaN's are
       returned for this class in \a results.
    */
    void predict(const MatrixLookup& data , utility::Matrix& results) const;

    /**   
        \brief Make predictions for weighted test data.
       
        Predictions are calculated and returned in \a results. For
        each sample in \a data, \a results contains the weighted
        number of nearest neighbors which belong to each class as in
        predict(const MatrixLookup& data, utility::Matrix& results).
        If a test and training sample pair has no variables with
        non-zero weights in common, there are no variables which can
        be used to calculate the distance between the two samples. In
        this case the distance between the two is set to infinity.
    */
    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;


    /**
       \brief Train the KNN using unweighted training data with known
       targets.
       
       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
       and \a targets as copying these would make the %classifier
       slow. If the number of training samples set is smaller than k,
       k is set to the number of training samples.
       
       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
       unless it is trained again.
    */
    void train(const MatrixLookup& data, const Target& targets);
   
    /**   
       \brief Train the KNN using weighted training data with known targets.
   
       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);
   
  private:
   
    // Non-owning pointers to the training data; exactly one of
    // data_ml_ / data_mlw_ is non-null after train(), both are null
    // before training.
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    // Non-owning pointer to training targets; set by train().
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    // Functor defining the distance between two samples.
    Distance distance_;
    // Functor translating the k nearest neighbors into class votes.
    NeighborWeighting weighting_;

    // Fill *distances (training columns x test columns) with pairwise
    // distances between unweighted training and test samples.
    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    // As calculate_unweighted but for weighted data; NaN distances
    // (no common non-zero-weight variables) are mapped to infinity.
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    // Shared back end of both predict() overloads: vote over the k
    // nearest neighbors per test sample and NaN-fill classes lacking
    // training samples.
    void predict_common(const utility::Matrix& distances, 
                        utility::Matrix& prediction) const;

  };
184 
185 
186  // templates
187 
188  template <typename Distance, typename NeighborWeighting>
189  KNN<Distance, NeighborWeighting>::KNN() 
190    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
191  {
192  }
193
194  template <typename Distance, typename NeighborWeighting>
195  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
196    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
197  {
198  }
199
200 
201  template <typename Distance, typename NeighborWeighting>
202  KNN<Distance, NeighborWeighting>::~KNN()   
203  {
204  }
205 
206
207  template <typename Distance, typename NeighborWeighting>
208  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
209  (const MatrixLookup& training, const MatrixLookup& test,
210   utility::Matrix* distances) const
211  {
212    for(size_t i=0; i<training.columns(); i++) {
213      for(size_t j=0; j<test.columns(); j++) {
214        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
215                                      test.begin_column(j));
216        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
217      }
218    }
219  }
220
221 
222  template <typename Distance, typename NeighborWeighting>
223  void 
224  KNN<Distance, NeighborWeighting>::calculate_weighted
225  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
226   utility::Matrix* distances) const
227  {
228    for(size_t i=0; i<training.columns(); i++) { 
229      for(size_t j=0; j<test.columns(); j++) {
230        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
231                                      test.begin_column(j));
232        // If the distance is NaN (no common variables with non-zero weights),
233        // the distance is set to infinity to be sorted as a neighbor at the end
234        if(std::isnan((*distances)(i,j))) 
235          (*distances)(i,j)=std::numeric_limits<double>::infinity();
236      }
237    }
238  }
239 
240 
241  template <typename Distance, typename NeighborWeighting>
242  unsigned int KNN<Distance, NeighborWeighting>::k() const
243  {
244    return k_;
245  }
246
247  template <typename Distance, typename NeighborWeighting>
248  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
249  {
250    k_=k;
251  }
252
253
254  template <typename Distance, typename NeighborWeighting>
255  KNN<Distance, NeighborWeighting>* 
256  KNN<Distance, NeighborWeighting>::make_classifier() const 
257  {     
258    // All private members should be copied here to generate an
259    // identical but untrained classifier
260    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
261    knn->weighting_=this->weighting_;
262    knn->k(this->k());
263    return knn;
264  }
265 
266 
267  template <typename Distance, typename NeighborWeighting>
268  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
269                                               const Target& target)
270  {   
271    utility::yat_assert<std::runtime_error>
272      (data.columns()==target.size(),
273       "KNN::train called with different sizes of target and data");
274    // k has to be at most the number of training samples.
275    if(data.columns()<k_) 
276      k_=data.columns();
277    data_ml_=&data;
278    data_mlw_=0;
279    target_=&target;
280  }
281
282  template <typename Distance, typename NeighborWeighting>
283  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
284                                               const Target& target)
285  {   
286    utility::yat_assert<std::runtime_error>
287      (data.columns()==target.size(),
288       "KNN::train called with different sizes of target and data");
289    // k has to be at most the number of training samples.
290    if(data.columns()<k_) 
291      k_=data.columns();
292    data_ml_=0;
293    data_mlw_=&data;
294    target_=&target;
295  }
296
297
298  template <typename Distance, typename NeighborWeighting>
299  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
300                                                 utility::Matrix& prediction) const
301  {   
302    // matrix with training samples as rows and test samples as columns
303    utility::Matrix* distances = 0;
304    // unweighted training data
305    if(data_ml_ && !data_mlw_) {
306      utility::yat_assert<std::runtime_error>
307        (data_ml_->rows()==test.rows(),
308         "KNN::predict different number of rows in training and test data");     
309      distances=new utility::Matrix(data_ml_->columns(),test.columns());
310      calculate_unweighted(*data_ml_,test,distances);
311    }
312    else if (data_mlw_ && !data_ml_) {
313      // weighted training data
314      utility::yat_assert<std::runtime_error>
315        (data_mlw_->rows()==test.rows(),
316         "KNN::predict different number of rows in training and test data");           
317      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
318      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
319                         distances);             
320    }
321    else {
322      std::runtime_error("KNN::predict no training data");
323    }
324
325    prediction.resize(target_->nof_classes(),test.columns(),0.0);
326    predict_common(*distances,prediction);
327    if(distances)
328      delete distances;
329  }
330
331  template <typename Distance, typename NeighborWeighting>
332  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
333                                                 utility::Matrix& prediction) const
334  {   
335    // matrix with training samples as rows and test samples as columns
336    utility::Matrix* distances=0; 
337    // unweighted training data
338    if(data_ml_ && !data_mlw_) { 
339      utility::yat_assert<std::runtime_error>
340        (data_ml_->rows()==test.rows(),
341         "KNN::predict different number of rows in training and test data");   
342      distances=new utility::Matrix(data_ml_->columns(),test.columns());
343      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
344    }
345    // weighted training data
346    else if (data_mlw_ && !data_ml_) {
347      utility::yat_assert<std::runtime_error>
348        (data_mlw_->rows()==test.rows(),
349         "KNN::predict different number of rows in training and test data");   
350      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
351      calculate_weighted(*data_mlw_,test,distances);             
352    }
353    else {
354      std::runtime_error("KNN::predict no training data");
355    }
356
357    prediction.resize(target_->nof_classes(),test.columns(),0.0);
358    predict_common(*distances,prediction);
359   
360    if(distances)
361      delete distances;
362  }
363 
364  template <typename Distance, typename NeighborWeighting>
365  void KNN<Distance, NeighborWeighting>::predict_common
366  (const utility::Matrix& distances, utility::Matrix& prediction) const
367  {   
368    for(size_t sample=0;sample<distances.columns();sample++) {
369      std::vector<size_t> k_index;
370      utility::VectorConstView dist=distances.column_const_view(sample);
371      utility::sort_smallest_index(k_index,k_,dist);
372      utility::VectorView pred=prediction.column_view(sample);
373      weighting_(dist,k_index,*target_,pred);
374    }
375   
376    // classes for which there are no training samples should be set
377    // to nan in the predictions
378    for(size_t c=0;c<target_->nof_classes(); c++) 
379      if(!target_->size(c)) 
380        for(size_t j=0;j<prediction.columns();j++)
381          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
382  }
383
384
385}}} // of namespace classifier, yat, and theplu
386
387#endif
388
Note: See TracBrowser for help on using the repository browser.