source: trunk/yat/classifier/KNN.h @ 1164

Last change on this file since 1164 was 1164, checked in by Markus Ringnér, 14 years ago

Refs. #318

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 10.0 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1164 2008-02-26 18:36:52Z markus $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "DataLookup1D.h"
28#include "DataLookupWeighted1D.h"
29#include "KNN_Uniform.h"
30#include "MatrixLookup.h"
31#include "MatrixLookupWeighted.h"
32#include "SupervisedClassifier.h"
33#include "Target.h"
34#include "yat/utility/Matrix.h"
35#include "yat/utility/yat_assert.h"
36
#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>
40
41namespace theplu {
42namespace yat {
43namespace classifier {
44
45  ///
46  /// @brief Class for Nearest Neighbor Classification.
47  ///
48  /// The template argument Distance should be a class modelling
49  /// the concept \ref concept_distance.
50  /// The template argument NeighborWeighting should be a class modelling
51  /// the concept \ref concept_neighbor_weighting.
52
53  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
54  class KNN : public SupervisedClassifier
55  {
56   
57  public:
58    ///
59    /// @brief Constructor
60    ///
61    KNN(void);
62
63
64    ///
65    /// @brief Constructor
66    ///
67    KNN(const Distance&);
68
69
70    ///
71    /// @brief Destructor
72    ///
73    virtual ~KNN();
74   
75
76    ///
77    /// Default number of neighbors (k) is set to 3.
78    ///
79    /// @return the number of neighbors
80    ///
81    u_int k() const;
82
83    ///
84    /// @brief sets the number of neighbors, k.
85    ///
86    void k(u_int k_in);
87
88
89    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
90   
91    ///
92    /// Train the classifier using training data and target.
93    ///
94    /// If the number of training samples set is smaller than \a k_in,
95    /// k is set to the number of training samples.
96    ///
97    void train(const MatrixLookup&, const Target&);
98
99    ///
100    /// Train the classifier using weighted training data and target.
101    ///
102    void train(const MatrixLookupWeighted&, const Target&);
103
104   
105    ///
106    /// For each sample, calculate the number of neighbors for each
107    /// class.
108    ///
109    void predict(const MatrixLookup&, utility::Matrix&) const;
110
111    ///
112    /// For each sample, calculate the number of neighbors for each
113    /// class.
114    ///
115    void predict(const MatrixLookupWeighted&, utility::Matrix&) const;
116
117
118  private:
119
120    const MatrixLookup* data_ml_;
121    const MatrixLookupWeighted* data_mlw_;
122    const Target* target_;
123
124    // The number of neighbors
125    u_int k_;
126
127    Distance distance_;
128    NeighborWeighting weighting_;
129
130    void calculate_unweighted(const MatrixLookup&,
131                              const MatrixLookup&,
132                              utility::Matrix*) const;
133    void calculate_weighted(const MatrixLookupWeighted&,
134                            const MatrixLookupWeighted&,
135                            utility::Matrix*) const;
136
137    void predict_common(const utility::Matrix& distances, 
138                        utility::Matrix& prediction) const;
139
140  };
141 
142 
143  // templates
144 
145  template <typename Distance, typename NeighborWeighting>
146  KNN<Distance, NeighborWeighting>::KNN() 
147    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
148  {
149  }
150
151  template <typename Distance, typename NeighborWeighting>
152  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
153    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
154  {
155  }
156
157 
158  template <typename Distance, typename NeighborWeighting>
159  KNN<Distance, NeighborWeighting>::~KNN()   
160  {
161  }
162 
163
164  template <typename Distance, typename NeighborWeighting>
165  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
166  (const MatrixLookup& training, const MatrixLookup& test,
167   utility::Matrix* distances) const
168  {
169    for(size_t i=0; i<training.columns(); i++) {
170      for(size_t j=0; j<test.columns(); j++) {
171        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
172                                      test.begin_column(j));
173        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
174      }
175    }
176  }
177
178 
179  template <typename Distance, typename NeighborWeighting>
180  void 
181  KNN<Distance, NeighborWeighting>::calculate_weighted
182  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
183   utility::Matrix* distances) const
184  {
185    for(size_t i=0; i<training.columns(); i++) { 
186      for(size_t j=0; j<test.columns(); j++) {
187        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i), 
188                                      test.begin_column(j));
189        // If the distance is NaN (no common variables with non-zero weights),
190        // the distance is set to infinity to be sorted as a neighbor at the end
191        if(std::isnan((*distances)(i,j))) 
192          (*distances)(i,j)=std::numeric_limits<double>::infinity();
193      }
194    }
195  }
196 
197 
198  template <typename Distance, typename NeighborWeighting>
199  u_int KNN<Distance, NeighborWeighting>::k() const
200  {
201    return k_;
202  }
203
204  template <typename Distance, typename NeighborWeighting>
205  void KNN<Distance, NeighborWeighting>::k(u_int k)
206  {
207    k_=k;
208  }
209
210
211  template <typename Distance, typename NeighborWeighting>
212  KNN<Distance, NeighborWeighting>* 
213  KNN<Distance, NeighborWeighting>::make_classifier() const 
214  {     
215    // All private members should be copied here to generate an
216    // identical but untrained classifier
217    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
218    knn->weighting_=this->weighting_;
219    knn->k(this->k());
220    return knn;
221  }
222 
223 
224  template <typename Distance, typename NeighborWeighting>
225  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
226                                               const Target& target)
227  {   
228    utility::yat_assert<std::runtime_error>
229      (data.columns()==target.size(),
230       "KNN::train called with different sizes of target and data");
231    // k has to be at most the number of training samples.
232    if(data.columns()<k_) 
233      k_=data.columns();
234    data_ml_=&data;
235    data_mlw_=0;
236    target_=&target;
237  }
238
239  template <typename Distance, typename NeighborWeighting>
240  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
241                                               const Target& target)
242  {   
243    utility::yat_assert<std::runtime_error>
244      (data.columns()==target.size(),
245       "KNN::train called with different sizes of target and data");
246    // k has to be at most the number of training samples.
247    if(data.columns()<k_) 
248      k_=data.columns();
249    data_ml_=0;
250    data_mlw_=&data;
251    target_=&target;
252  }
253
254
255  template <typename Distance, typename NeighborWeighting>
256  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
257                                                 utility::Matrix& prediction) const
258  {   
259    // matrix with training samples as rows and test samples as columns
260    utility::Matrix* distances = 0;
261    // unweighted training data
262    if(data_ml_ && !data_mlw_) {
263      utility::yat_assert<std::runtime_error>
264        (data_ml_->rows()==test.rows(),
265         "KNN::predict different number of rows in training and test data");     
266      distances=new utility::Matrix(data_ml_->columns(),test.columns());
267      calculate_unweighted(*data_ml_,test,distances);
268    }
269    else if (data_mlw_ && !data_ml_) {
270      // weighted training data
271      utility::yat_assert<std::runtime_error>
272        (data_mlw_->rows()==test.rows(),
273         "KNN::predict different number of rows in training and test data");           
274      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
275      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
276                         distances);             
277    }
278    else {
279      std::runtime_error("KNN::predict no training data");
280    }
281
282    prediction.resize(target_->nof_classes(),test.columns(),0.0);
283    predict_common(*distances,prediction);
284    if(distances)
285      delete distances;
286  }
287
288  template <typename Distance, typename NeighborWeighting>
289  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
290                                                 utility::Matrix& prediction) const
291  {   
292    // matrix with training samples as rows and test samples as columns
293    utility::Matrix* distances=0; 
294    // unweighted training data
295    if(data_ml_ && !data_mlw_) { 
296      utility::yat_assert<std::runtime_error>
297        (data_ml_->rows()==test.rows(),
298         "KNN::predict different number of rows in training and test data");   
299      distances=new utility::Matrix(data_ml_->columns(),test.columns());
300      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
301    }
302    // weighted training data
303    else if (data_mlw_ && !data_ml_) {
304      utility::yat_assert<std::runtime_error>
305        (data_mlw_->rows()==test.rows(),
306         "KNN::predict different number of rows in training and test data");   
307      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
308      calculate_weighted(*data_mlw_,test,distances);             
309    }
310    else {
311      std::runtime_error("KNN::predict no training data");
312    }
313
314    prediction.resize(target_->nof_classes(),test.columns(),0.0);
315    predict_common(*distances,prediction);
316   
317    if(distances)
318      delete distances;
319  }
320 
321  template <typename Distance, typename NeighborWeighting>
322  void KNN<Distance, NeighborWeighting>::predict_common
323  (const utility::Matrix& distances, utility::Matrix& prediction) const
324  {   
325    for(size_t sample=0;sample<distances.columns();sample++) {
326      std::vector<size_t> k_index;
327      utility::VectorConstView dist=distances.column_const_view(sample);
328      utility::sort_smallest_index(k_index,k_,dist);
329      utility::VectorView pred=prediction.column_view(sample);
330      weighting_(dist,k_index,*target_,pred);
331    }
332   
333    // classes for which there are no training samples should be set
334    // to nan in the predictions
335    for(size_t c=0;c<target_->nof_classes(); c++) 
336      if(!target_->size(c)) 
337        for(size_t j=0;j<prediction.columns();j++)
338          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
339  }
340
341
342}}} // of namespace classifier, yat, and theplu
343
344#endif
345
Note: See TracBrowser for help on using the repository browser.