source: trunk/yat/classifier/KNN.h @ 1158

Last change on this file since 1158 was 1158, checked in by Markus Ringnér, 14 years ago

Fixed #322

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 9.7 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1158 2008-02-26 13:49:38Z markus $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "DataLookup1D.h"
28#include "DataLookupWeighted1D.h"
29#include "KNN_Uniform.h"
30#include "MatrixLookup.h"
31#include "MatrixLookupWeighted.h"
32#include "SupervisedClassifier.h"
33#include "Target.h"
34#include "yat/utility/Matrix.h"
35#include "yat/utility/yat_assert.h"
36
37#include <cmath>
38#include <map>
39#include <stdexcept>
40
41namespace theplu {
42namespace yat {
43namespace classifier {
44
45  ///
46  /// @brief Class for Nearest Neigbor Classification.
47  ///
48  /// The template argument Distance should be a class modelling
49  /// the concept \ref concept_distance.
50  /// The template argument NeigborWeighting should be a class modelling
51  /// the concept \ref concept_neighbor_weighting.
52
53  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
54  class KNN : public SupervisedClassifier
55  {
56   
57  public:
58    ///
59    /// @brief Constructor
60    ///
61    KNN(void);
62
63
64    ///
65    /// @brief Constructor
66    ///
67    KNN(const Distance&);
68
69
70    ///
71    /// @brief Destructor
72    ///
73    virtual ~KNN();
74   
75
76    ///
77    /// Default number of neighbors (k) is set to 3.
78    ///
79    /// @return the number of neighbors
80    ///
81    u_int k() const;
82
83    ///
84    /// @brief sets the number of neighbors, k.
85    ///
86    void k(u_int k_in);
87
88
89    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
90   
91    ///
92    /// Train the classifier using training data and target.
93    ///
94    /// If the number of training samples set is smaller than \a k_in,
95    /// k is set to the number of training samples.
96    ///
97    void train(const MatrixLookup&, const Target&);
98
99    ///
100    /// Train the classifier using weighted training data and target.
101    ///
102    void train(const MatrixLookupWeighted&, const Target&);
103
104   
105    ///
106    /// For each sample, calculate the number of neighbors for each
107    /// class.
108    ///
109    ///
110    void predict(const DataLookup2D&, utility::Matrix&) const;
111
112
113  private:
114
115    // data_ has to be of type DataLookup2D to accomodate both
116    // MatrixLookup and MatrixLookupWeighted
117    const DataLookup2D* data_;
118    const Target* target_;
119
120    // The number of neighbors
121    u_int k_;
122
123    Distance distance_;
124
125    NeighborWeighting weighting_;
126
127    ///
128    /// Calculates the distances between a data set and the training
129    /// data. The rows are training and the columns test samples,
130    /// respectively. The returned distance matrix is dynamically
131    /// generated and needs to be deleted by the caller.
132    ///
133    utility::Matrix* calculate_distances(const DataLookup2D&) const;
134
135    void calculate_unweighted(const MatrixLookup&,
136                              const MatrixLookup&,
137                              utility::Matrix*) const;
138    void calculate_weighted(const MatrixLookupWeighted&,
139                            const MatrixLookupWeighted&,
140                            utility::Matrix*) const;
141  };
142 
143 
144  // templates
145 
146  template <typename Distance, typename NeighborWeighting>
147  KNN<Distance, NeighborWeighting>::KNN() 
148    : SupervisedClassifier(),data_(0),target_(0),k_(3)
149  {
150  }
151
152  template <typename Distance, typename NeighborWeighting>
153  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 
154    : SupervisedClassifier(),data_(0),target_(0),k_(3), distance_(dist)
155  {
156  }
157
158 
159  template <typename Distance, typename NeighborWeighting>
160  KNN<Distance, NeighborWeighting>::~KNN()   
161  {
162  }
163 
164  template <typename Distance, typename NeighborWeighting>
165  utility::Matrix* KNN<Distance, NeighborWeighting>::calculate_distances
166  (const DataLookup2D& test) const
167  {
168    // matrix with training samples as rows and test samples as columns
169    utility::Matrix* distances = 
170      new utility::Matrix(data_->columns(),test.columns());
171   
172   
173    // unweighted test data
174    if(const MatrixLookup* test_unweighted = 
175       dynamic_cast<const MatrixLookup*>(&test)) {     
176      // unweighted training data
177      if(const MatrixLookup* training_unweighted = 
178         dynamic_cast<const MatrixLookup*>(data_)) 
179        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
180      // weighted training data
181      else if(const MatrixLookupWeighted* training_weighted = 
182              dynamic_cast<const MatrixLookupWeighted*>(data_)) 
183        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
184                           distances);             
185      // Training data can not be of incorrect type
186    }
187    // weighted test data
188    else if (const MatrixLookupWeighted* test_weighted = 
189             dynamic_cast<const MatrixLookupWeighted*>(&test)) {     
190      // unweighted training data
191      if(const MatrixLookup* training_unweighted = 
192         dynamic_cast<const MatrixLookup*>(data_)) {
193        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
194                           *test_weighted,distances);
195      }
196      // weighted training data
197      else if(const MatrixLookupWeighted* training_weighted = 
198              dynamic_cast<const MatrixLookupWeighted*>(data_)) 
199        calculate_weighted(*training_weighted,*test_weighted,distances);             
200      // Training data can not be of incorrect type
201    } 
202    else {
203      std::string str;
204      str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted";
205      throw std::runtime_error(str);
206    }
207    return distances;
208  }
209
210  template <typename Distance, typename NeighborWeighting>
211  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
212  (const MatrixLookup& training, const MatrixLookup& test,
213   utility::Matrix* distances) const
214  {
215    for(size_t i=0; i<training.columns(); i++) {
216      classifier::DataLookup1D training1(training,i,false);
217      for(size_t j=0; j<test.columns(); j++) {
218        classifier::DataLookup1D test1(test,j,false);
219        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
220        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
221      }
222    }
223  }
224 
225  template <typename Distance, typename NeighborWeighting>
226  void 
227  KNN<Distance, NeighborWeighting>::calculate_weighted
228  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
229   utility::Matrix* distances) const
230  {
231    for(size_t i=0; i<training.columns(); i++) {
232      classifier::DataLookupWeighted1D training1(training,i,false);
233      for(size_t j=0; j<test.columns(); j++) {
234        classifier::DataLookupWeighted1D test1(test,j,false);
235        (*distances)(i,j) = distance_(training1.begin(), training1.end(), 
236                                      test1.begin());
237        // If the distance is NaN (no common variables with non-zero weights),
238        // the distance is set to infinity to be sorted as a neighbor at the end
239        if(std::isnan((*distances)(i,j))) 
240          (*distances)(i,j)=std::numeric_limits<double>::infinity();
241      }
242    }
243  }
244 
245 
246  template <typename Distance, typename NeighborWeighting>
247  u_int KNN<Distance, NeighborWeighting>::k() const
248  {
249    return k_;
250  }
251
252  template <typename Distance, typename NeighborWeighting>
253  void KNN<Distance, NeighborWeighting>::k(u_int k)
254  {
255    k_=k;
256  }
257
258
259  template <typename Distance, typename NeighborWeighting>
260  KNN<Distance, NeighborWeighting>* 
261  KNN<Distance, NeighborWeighting>::make_classifier() const 
262  {     
263    KNN* knn=new KNN<Distance, NeighborWeighting>();
264    knn->k(this->k());
265    return knn;
266  }
267 
268 
269  template <typename Distance, typename NeighborWeighting>
270  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
271                                               const Target& target)
272  {   
273    utility::yat_assert<std::runtime_error>
274      (data.columns()==target.size(),
275       "KNN::train called with different sizes of target and data");
276    // k has to be at most the number of training samples.
277    if(data.columns()<k_) 
278      k_=data.columns();
279    data_=&data;
280    target_=&target;
281    trained_=true;
282  }
283
284  template <typename Distance, typename NeighborWeighting>
285  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
286                                               const Target& target)
287  {   
288    utility::yat_assert<std::runtime_error>
289      (data.columns()==target.size(),
290       "KNN::train called with different sizes of target and data");
291    // k has to be at most the number of training samples.
292    if(data.columns()<k_) 
293      k_=data.columns();
294    data_=&data;
295    target_=&target;
296    trained_=true;
297  }
298
299
300  template <typename Distance, typename NeighborWeighting>
301  void KNN<Distance, NeighborWeighting>::predict(const DataLookup2D& test,
302                                                 utility::Matrix& prediction) const
303  {   
304    utility::yat_assert<std::runtime_error>(data_->rows()==test.rows(),"KNN::predict different number of rows in training and test data");
305
306    utility::Matrix* distances=calculate_distances(test);
307   
308    prediction.resize(target_->nof_classes(),test.columns(),0.0);
309    for(size_t sample=0;sample<distances->columns();sample++) {
310      std::vector<size_t> k_index;
311      utility::VectorConstView dist=distances->column_const_view(sample);
312      utility::sort_smallest_index(k_index,k_,dist);
313      utility::VectorView pred=prediction.column_view(sample);
314      weighting_(dist,k_index,*target_,pred);
315    }
316    delete distances;
317
318    // classes for which there are no training samples should be set
319    // to nan in the predictions
320    for(size_t c=0;c<target_->nof_classes(); c++) 
321      if(!target_->size(c)) 
322        for(size_t j=0;j<prediction.columns();j++)
323          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
324  }
325
326}}} // of namespace classifier, yat, and theplu
327
328#endif
329
Note: See TracBrowser for help on using the repository browser.