source: trunk/yat/classifier/KNN.h @ 1188

Last change on this file since 1188 was 1188, checked in by Markus Ringnér, 14 years ago

Fixed KNN and SupervisedClassifier for #75

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 12.2 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1188 2008-02-29 10:14:04Z markus $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>
40
41namespace theplu {
42namespace yat {
43namespace classifier {
44
45  /**
46     @brief Nearest Neighbor Classifier
47     
48     A sample is predicted based on the classes of its k nearest
49     neighbors among the training data samples. KNN supports using
50     different measures, for example, Euclidean distance, to define
51     distance between samples. KNN also supports using different ways to
52     weight the votes of the k nearest neighbors. For example, using a
53     uniform vote a test sample gets a vote for each class which is the
54     number of nearest neighbors belonging to the class.
55     
56     The template argument Distance should be a class modelling the
57     concept \ref concept_distance. The template argument
58     NeighborWeighting should be a class modelling the concept \ref
59     concept_neighbor_weighting.
60  */
61  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
62  class KNN : public SupervisedClassifier
63  {
64   
65  public:
66    /**
67       @brief Default constructor.
68       
69       The number of nearest neighbors (k) is set to 3. Distance and
70       NeighborWeighting are initialized using their default
71       constructuors.
72    */
73    KNN(void);
74
75
76    /**
77       @brief Constructor using an intialized distance measure.
78       
79       The number of nearest neighbors (k) is set to 3. This constructor
80       should be used if Distance has parameters and the user wants
81       to specify the parameters by initializing Distance prior to
82       constructing the KNN.
83    */ 
84    KNN(const Distance&);
85
86
87    /**
88       Destructor
89    */
90    virtual ~KNN();
91   
92   
93    /**
94       \brief Get the number of nearest neighbors.
95       \return The number of neighbors.
96    */
97    u_int k() const;
98
99    /**
100       \brief Set the number of nearest neighbors.
101       
102       Sets the number of neighbors to \a k_in.
103    */
104    void k(u_int k_in);
105
106
107    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
108   
109    /**
110       @brief Make predictions for unweighted test data.
111       
112       Predictions are calculated and returned in \a results.  For
113       each sample in \a data, \a results contains the weighted number
114       of nearest neighbors which belong to each class. Numbers of
115       nearest neighbors are weighted according to
116       NeighborWeighting. If a class has no training samples NaN's are
117       returned for this class in \a results.
118    */
119    void predict(const MatrixLookup& data , utility::Matrix& results) const;
120
121    /**   
122        @brief Make predictions for weighted test data.
123       
124        Predictions are calculated and returned in \a results. For
125        each sample in \a data, \a results contains the weighted
126        number of nearest neighbors which belong to each class as in
127        predict(const MatrixLookup& data, utility::Matrix& results).
128        If a test and training sample pair has no variables with
129        non-zero weights in common, there are no variables which can
130        be used to calculate the distance between the two samples. In
131        this case the distance between the two is set to infinity.
132    */
133    void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;
134
135
136    /**
137       @brief Train the KNN using unweighted training data with known
138       targets.
139       
140       For KNN there is no actual training; the entire training data
141       set is stored with targets. KNN only stores references to \a data
142       and \a targets as copying these would make the %classifier
143       slow. If the number of training samples set is smaller than k,
144       k is set to the number of training samples.
145       
146       \note If \a data or \a targets go out of scope ore are
147       deleted, the KNN becomes invalid and further use is undefined
148       unless it is trained again.
149    */
150    void train(const MatrixLookup& data, const Target& targets);
151   
152    /**   
153       \brief Train the KNN using weighted training data with known targets.
154   
155       See train(const MatrixLookup& data, const Target& targets) for
156       additional information.
157    */
158    void train(const MatrixLookupWeighted& data, const Target& targets);
159   
160  private:
161   
162    const MatrixLookup* data_ml_;
163    const MatrixLookupWeighted* data_mlw_;
164    const Target* target_;
165
166    // The number of neighbors
167    u_int k_;
168
169    Distance distance_;
170    NeighborWeighting weighting_;
171
172    void calculate_unweighted(const MatrixLookup&,
173                              const MatrixLookup&,
174                              utility::Matrix*) const;
175    void calculate_weighted(const MatrixLookupWeighted&,
176                            const MatrixLookupWeighted&,
177                            utility::Matrix*) const;
178
179    void predict_common(const utility::Matrix& distances, 
180                        utility::Matrix& prediction) const;
181
182  };
183 
184 
185  // templates
186 
187  template <typename Distance, typename NeighborWeighting>
188  KNN<Distance, NeighborWeighting>::KNN() 
189    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
190  {
191  }
192
193  template <typename Distance, typename NeighborWeighting>
194  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
195    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
196  {
197  }
198
199 
200  template <typename Distance, typename NeighborWeighting>
201  KNN<Distance, NeighborWeighting>::~KNN()   
202  {
203  }
204 
205
206  template <typename Distance, typename NeighborWeighting>
207  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
208  (const MatrixLookup& training, const MatrixLookup& test,
209   utility::Matrix* distances) const
210  {
211    for(size_t i=0; i<training.columns(); i++) {
212      for(size_t j=0; j<test.columns(); j++) {
213        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
214                                      test.begin_column(j));
215        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
216      }
217    }
218  }
219
220 
221  template <typename Distance, typename NeighborWeighting>
222  void 
223  KNN<Distance, NeighborWeighting>::calculate_weighted
224  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
225   utility::Matrix* distances) const
226  {
227    for(size_t i=0; i<training.columns(); i++) { 
228      for(size_t j=0; j<test.columns(); j++) {
229        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
230                                      test.begin_column(j));
231        // If the distance is NaN (no common variables with non-zero weights),
232        // the distance is set to infinity to be sorted as a neighbor at the end
233        if(std::isnan((*distances)(i,j))) 
234          (*distances)(i,j)=std::numeric_limits<double>::infinity();
235      }
236    }
237  }
238 
239 
240  template <typename Distance, typename NeighborWeighting>
241  u_int KNN<Distance, NeighborWeighting>::k() const
242  {
243    return k_;
244  }
245
246  template <typename Distance, typename NeighborWeighting>
247  void KNN<Distance, NeighborWeighting>::k(u_int k)
248  {
249    k_=k;
250  }
251
252
253  template <typename Distance, typename NeighborWeighting>
254  KNN<Distance, NeighborWeighting>* 
255  KNN<Distance, NeighborWeighting>::make_classifier() const 
256  {     
257    // All private members should be copied here to generate an
258    // identical but untrained classifier
259    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
260    knn->weighting_=this->weighting_;
261    knn->k(this->k());
262    return knn;
263  }
264 
265 
266  template <typename Distance, typename NeighborWeighting>
267  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
268                                               const Target& target)
269  {   
270    utility::yat_assert<std::runtime_error>
271      (data.columns()==target.size(),
272       "KNN::train called with different sizes of target and data");
273    // k has to be at most the number of training samples.
274    if(data.columns()<k_) 
275      k_=data.columns();
276    data_ml_=&data;
277    data_mlw_=0;
278    target_=&target;
279  }
280
281  template <typename Distance, typename NeighborWeighting>
282  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
283                                               const Target& target)
284  {   
285    utility::yat_assert<std::runtime_error>
286      (data.columns()==target.size(),
287       "KNN::train called with different sizes of target and data");
288    // k has to be at most the number of training samples.
289    if(data.columns()<k_) 
290      k_=data.columns();
291    data_ml_=0;
292    data_mlw_=&data;
293    target_=&target;
294  }
295
296
297  template <typename Distance, typename NeighborWeighting>
298  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
299                                                 utility::Matrix& prediction) const
300  {   
301    // matrix with training samples as rows and test samples as columns
302    utility::Matrix* distances = 0;
303    // unweighted training data
304    if(data_ml_ && !data_mlw_) {
305      utility::yat_assert<std::runtime_error>
306        (data_ml_->rows()==test.rows(),
307         "KNN::predict different number of rows in training and test data");     
308      distances=new utility::Matrix(data_ml_->columns(),test.columns());
309      calculate_unweighted(*data_ml_,test,distances);
310    }
311    else if (data_mlw_ && !data_ml_) {
312      // weighted training data
313      utility::yat_assert<std::runtime_error>
314        (data_mlw_->rows()==test.rows(),
315         "KNN::predict different number of rows in training and test data");           
316      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
317      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
318                         distances);             
319    }
320    else {
321      std::runtime_error("KNN::predict no training data");
322    }
323
324    prediction.resize(target_->nof_classes(),test.columns(),0.0);
325    predict_common(*distances,prediction);
326    if(distances)
327      delete distances;
328  }
329
330  template <typename Distance, typename NeighborWeighting>
331  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
332                                                 utility::Matrix& prediction) const
333  {   
334    // matrix with training samples as rows and test samples as columns
335    utility::Matrix* distances=0; 
336    // unweighted training data
337    if(data_ml_ && !data_mlw_) { 
338      utility::yat_assert<std::runtime_error>
339        (data_ml_->rows()==test.rows(),
340         "KNN::predict different number of rows in training and test data");   
341      distances=new utility::Matrix(data_ml_->columns(),test.columns());
342      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
343    }
344    // weighted training data
345    else if (data_mlw_ && !data_ml_) {
346      utility::yat_assert<std::runtime_error>
347        (data_mlw_->rows()==test.rows(),
348         "KNN::predict different number of rows in training and test data");   
349      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
350      calculate_weighted(*data_mlw_,test,distances);             
351    }
352    else {
353      std::runtime_error("KNN::predict no training data");
354    }
355
356    prediction.resize(target_->nof_classes(),test.columns(),0.0);
357    predict_common(*distances,prediction);
358   
359    if(distances)
360      delete distances;
361  }
362 
363  template <typename Distance, typename NeighborWeighting>
364  void KNN<Distance, NeighborWeighting>::predict_common
365  (const utility::Matrix& distances, utility::Matrix& prediction) const
366  {   
367    for(size_t sample=0;sample<distances.columns();sample++) {
368      std::vector<size_t> k_index;
369      utility::VectorConstView dist=distances.column_const_view(sample);
370      utility::sort_smallest_index(k_index,k_,dist);
371      utility::VectorView pred=prediction.column_view(sample);
372      weighting_(dist,k_index,*target_,pred);
373    }
374   
375    // classes for which there are no training samples should be set
376    // to nan in the predictions
377    for(size_t c=0;c<target_->nof_classes(); c++) 
378      if(!target_->size(c)) 
379        for(size_t j=0;j<prediction.columns();j++)
380          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
381  }
382
383
384}}} // of namespace classifier, yat, and theplu
385
386#endif
387
Note: See TracBrowser for help on using the repository browser.