source: trunk/yat/classifier/KNN.h @ 2337

Last change on this file since 2337 was 2337, checked in by Peter, 12 years ago

missing includes

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 12.6 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 2337 2010-10-15 21:30:53Z peter $
5
6/*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8  Copyright (C) 2009 Peter Johansson
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24*/
25
26#include "DataLookup1D.h"
27#include "DataLookupWeighted1D.h"
28#include "KNN_Uniform.h"
29#include "MatrixLookup.h"
30#include "MatrixLookupWeighted.h"
31#include "SupervisedClassifier.h"
32#include "Target.h"
33#include "yat/utility/concept_check.h"
34#include "yat/utility/Exception.h"
35#include "yat/utility/Matrix.h"
36#include "yat/utility/VectorConstView.h"
37#include "yat/utility/VectorView.h"
38#include "yat/utility/yat_assert.h"
39
40#include <boost/concept_check.hpp>
41
42#include <cmath>
43#include <limits>
44#include <map>
45#include <stdexcept>
46#include <vector>
47
48namespace theplu {
49namespace yat {
50namespace classifier {
51
  /**
     \brief Nearest Neighbor Classifier

     A sample is predicted based on the classes of its k nearest
     neighbors among the training data samples. KNN supports using
     different measures, for example, Euclidean distance, to define
     distance between samples. KNN also supports using different ways to
     weight the votes of the k nearest neighbors. For example, using a
     uniform vote a test sample gets a vote for each class which is the
     number of nearest neighbors belonging to the class.

     The template argument Distance should be a class modelling the
     concept \ref concept_distance. The template argument
     NeighborWeighting should be a class modelling the concept \ref
     concept_neighbor_weighting.
  */
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {

  public:
    /**
       \brief Default constructor.

       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
       constructors.
    */
    KNN(void);


    /**
       \brief Constructor using an initialized distance measure.

       The number of nearest neighbors (k) is set to
       3. NeighborWeighting is initialized using its default
       constructor. This constructor should be used if Distance has
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */
    KNN(const Distance&);


    /**
       Destructor
    */
    virtual ~KNN();


    /**
       \brief Get the number of nearest neighbors.
       \return The number of neighbors.
    */
    unsigned int k() const;

    /**
       \brief Set the number of nearest neighbors.

       Sets the number of neighbors to \a k_in.
    */
    void k(unsigned int k_in);


    /**
       \brief Create an untrained copy of this classifier.

       The returned KNN has the same Distance, NeighborWeighting, and
       number of neighbors (k) as this one but holds no training
       data. The object is allocated with new; the caller is
       responsible for deleting it.
    */
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;

    /**
       \brief Make predictions for unweighted test data.

       Predictions are calculated and returned in \a results.  For
       each sample in \a data, \a results contains the weighted number
       of nearest neighbors which belong to each class. Numbers of
       nearest neighbors are weighted according to
       NeighborWeighting. If a class has no training samples NaN's are
       returned for this class in \a results.
    */
    void predict(const MatrixLookup& data , utility::Matrix& results) const;

    /**
        \brief Make predictions for weighted test data.

        Predictions are calculated and returned in \a results. For
        each sample in \a data, \a results contains the weighted
        number of nearest neighbors which belong to each class as in
        predict(const MatrixLookup& data, utility::Matrix& results).
        If a test and training sample pair has no variables with
        non-zero weights in common, there are no variables which can
        be used to calculate the distance between the two samples. In
        this case the distance between the two is set to infinity.
    */
    void predict(const MatrixLookupWeighted& data, 
                 utility::Matrix& results) const;


    /**
       \brief Train the KNN using unweighted training data with known
       targets.

       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
       and \a targets as copying these would make the %classifier
       slow. If the number of training samples set is smaller than k,
       k is set to the number of training samples.

       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
       unless it is trained again.
    */
    void train(const MatrixLookup& data, const Target& targets);

    /**
       \brief Train the KNN using weighted training data with known targets.

       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);

  private:

    // Non-owning pointers to the training data set by train(); exactly
    // one of data_ml_/data_mlw_ is non-null after a successful train().
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    // Non-owning pointer to the training targets set by train().
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    Distance distance_;
    NeighborWeighting weighting_;

    // Fills *distances (training samples as rows, test samples as
    // columns) with pairwise distances between columns of the two
    // lookups.
    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    // Shared voting step used by both predict() overloads.
    void predict_common(const utility::Matrix& distances, 
                        utility::Matrix& prediction) const;

  };
192 
193 
194  // templates
195 
  // Default constructor: no training data, k defaults to 3.  The
  // concept assert verifies at compile time that Distance models the
  // distance concept.
  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN() 
    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
  {
    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
  }
202
203  template <typename Distance, typename NeighborWeighting>
204  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
205    : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3), 
206      distance_(dist)
207  {
208    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
209  }
210
211 
  // Destructor: nothing to release — training data and targets are
  // only referenced, never owned.
  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::~KNN()   
  {
  }
216 
217
218  template <typename Distance, typename NeighborWeighting>
219  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
220  (const MatrixLookup& training, const MatrixLookup& test,
221   utility::Matrix* distances) const
222  {
223    for(size_t i=0; i<training.columns(); i++) {
224      for(size_t j=0; j<test.columns(); j++) {
225        (*distances)(i,j) = distance_(training.begin_column(i), 
226                                      training.end_column(i), 
227                                      test.begin_column(j));
228        YAT_ASSERT(!std::isnan((*distances)(i,j)));
229      }
230    }
231  }
232
233 
234  template <typename Distance, typename NeighborWeighting>
235  void 
236  KNN<Distance, NeighborWeighting>::calculate_weighted
237  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
238   utility::Matrix* distances) const
239  {
240    for(size_t i=0; i<training.columns(); i++) { 
241      for(size_t j=0; j<test.columns(); j++) {
242        (*distances)(i,j) = distance_(training.begin_column(i), 
243                                      training.end_column(i), 
244                                      test.begin_column(j));
245        // If the distance is NaN (no common variables with non-zero weights),
246        // the distance is set to infinity to be sorted as a neighbor at the end
247        if(std::isnan((*distances)(i,j))) 
248          (*distances)(i,j)=std::numeric_limits<double>::infinity();
249      }
250    }
251  }
252 
253 
  // Accessor: the number of nearest neighbors used for voting.
  template <typename Distance, typename NeighborWeighting>
  unsigned int KNN<Distance, NeighborWeighting>::k() const
  {
    return k_;
  }
259
260  template <typename Distance, typename NeighborWeighting>
261  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
262  {
263    k_=k;
264  }
265
266
267  template <typename Distance, typename NeighborWeighting>
268  KNN<Distance, NeighborWeighting>* 
269  KNN<Distance, NeighborWeighting>::make_classifier() const 
270  {     
271    // All private members should be copied here to generate an
272    // identical but untrained classifier
273    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
274    knn->weighting_=this->weighting_;
275    knn->k(this->k());
276    return knn;
277  }
278 
279 
280  template <typename Distance, typename NeighborWeighting>
281  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
282                                               const Target& target)
283  {   
284    utility::yat_assert<utility::runtime_error>
285      (data.columns()==target.size(),
286       "KNN::train called with different sizes of target and data");
287    // k has to be at most the number of training samples.
288    if(data.columns()<k_) 
289      k_=data.columns();
290    data_ml_=&data;
291    data_mlw_=0;
292    target_=&target;
293  }
294
295  template <typename Distance, typename NeighborWeighting>
296  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
297                                               const Target& target)
298  {   
299    utility::yat_assert<utility::runtime_error>
300      (data.columns()==target.size(),
301       "KNN::train called with different sizes of target and data");
302    // k has to be at most the number of training samples.
303    if(data.columns()<k_) 
304      k_=data.columns();
305    data_ml_=0;
306    data_mlw_=&data;
307    target_=&target;
308  }
309
310
311  template <typename Distance, typename NeighborWeighting>
312  void 
313  KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
314                                            utility::Matrix& prediction) const
315  {   
316    // matrix with training samples as rows and test samples as columns
317    utility::Matrix* distances = 0;
318    // unweighted training data
319    if(data_ml_ && !data_mlw_) {
320      utility::yat_assert<utility::runtime_error>
321        (data_ml_->rows()==test.rows(),
322         "KNN::predict different number of rows in training and test data");
323      distances=new utility::Matrix(data_ml_->columns(),test.columns());
324      calculate_unweighted(*data_ml_,test,distances);
325    }
326    else if (data_mlw_ && !data_ml_) {
327      // weighted training data
328      utility::yat_assert<utility::runtime_error>
329        (data_mlw_->rows()==test.rows(),
330         "KNN::predict different number of rows in training and test data");
331      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
332      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
333                         distances);             
334    }
335    else {
336      throw utility::runtime_error("KNN::predict no training data");
337    }
338
339    prediction.resize(target_->nof_classes(),test.columns(),0.0);
340    predict_common(*distances,prediction);
341    if(distances)
342      delete distances;
343  }
344
345  template <typename Distance, typename NeighborWeighting>
346  void 
347  KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
348                                            utility::Matrix& prediction) const
349  {   
350    // matrix with training samples as rows and test samples as columns
351    utility::Matrix* distances=0; 
352    // unweighted training data
353    if(data_ml_ && !data_mlw_) { 
354      utility::yat_assert<utility::runtime_error>
355        (data_ml_->rows()==test.rows(),
356         "KNN::predict different number of rows in training and test data");   
357      distances=new utility::Matrix(data_ml_->columns(),test.columns());
358      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
359    }
360    // weighted training data
361    else if (data_mlw_ && !data_ml_) {
362      utility::yat_assert<utility::runtime_error>
363        (data_mlw_->rows()==test.rows(),
364         "KNN::predict different number of rows in training and test data");   
365      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
366      calculate_weighted(*data_mlw_,test,distances);             
367    }
368    else {
369      throw utility::runtime_error("KNN::predict no training data");
370    }
371
372    prediction.resize(target_->nof_classes(),test.columns(),0.0);
373    predict_common(*distances,prediction);
374   
375    if(distances)
376      delete distances;
377  }
378 
379  template <typename Distance, typename NeighborWeighting>
380  void KNN<Distance, NeighborWeighting>::predict_common
381  (const utility::Matrix& distances, utility::Matrix& prediction) const
382  {   
383    for(size_t sample=0;sample<distances.columns();sample++) {
384      std::vector<size_t> k_index;
385      utility::VectorConstView dist=distances.column_const_view(sample);
386      utility::sort_smallest_index(k_index,k_,dist);
387      utility::VectorView pred=prediction.column_view(sample);
388      weighting_(dist,k_index,*target_,pred);
389    }
390   
391    // classes for which there are no training samples should be set
392    // to nan in the predictions
393    for(size_t c=0;c<target_->nof_classes(); c++) 
394      if(!target_->size(c)) 
395        for(size_t j=0;j<prediction.columns();j++)
396          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
397  }
398}}} // of namespace classifier, yat, and theplu
399
400#endif
Note: See TracBrowser for help on using the repository browser.