source: trunk/yat/classifier/KNN.h @ 1144

Last change on this file was revision 1144, checked in by Markus Ringnér, 2008-02-25

Fixes #334

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 9.9 KB
#ifndef _theplu_yat_classifier_knn_
#define _theplu_yat_classifier_knn_

// $Id: KNN.h 1144 2008-02-25 16:51:58Z markus $

/*
  Copyright (C) 2007 Peter Johansson, Markus Ringnér

  This file is part of the yat library, http://trac.thep.lu.se/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 2 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
  02111-1307, USA.
*/

#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>    // std::numeric_limits (used in predict)
#include <map>
#include <stdexcept>
#include <string>    // std::string (used in error messages)
#include <vector>    // std::vector (used in predict)

namespace theplu {
namespace yat {
namespace classifier {

  ///
  /// @brief Class for Nearest Neighbor Classification.
  ///
  /// The template argument Distance should be a class modelling
  /// the concept \ref concept_distance.
  /// The template argument NeighborWeighting should be a class modelling
  /// the concept \ref concept_neighbor_weighting.
  ///
  /// A usage sketch is given after the class declaration below.
  ///
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {

  public:
    ///
    /// Constructor taking the training data and the target
    /// as input.
    ///
    KNN(const MatrixLookup&, const Target&);


    ///
    /// Constructor taking the training data with weights and the
    /// target as input.
    ///
    KNN(const MatrixLookupWeighted&, const Target&);

    virtual ~KNN();

    //
    // @return the training data
    //
    const DataLookup2D& data(void) const;


    ///
    /// Default number of neighbors (k) is set to 3.
    ///
    /// @return the number of neighbors
    ///
    u_int k() const;

    ///
    /// @brief Sets the number of neighbors, k. If the number of
    /// training samples is smaller than \a k_in, k is set to the
    /// number of training samples.
    ///
    void k(u_int k_in);


    KNN<Distance,NeighborWeighting>* make_classifier(const DataLookup2D&,
                                                      const Target&) const;

    ///
    /// Train the classifier using the training data.
    /// This function does nothing but is required by the interface.
    ///
    void train();

    ///
    /// @brief Make predictions for a set of test samples.
    ///
    /// For each test sample, the k nearest training samples are
    /// identified and their (possibly weighted) votes are accumulated
    /// per class. The prediction matrix is resized to one row per
    /// class and one column per test sample; classes without training
    /// samples are set to NaN.
    ///
    void predict(const DataLookup2D&, utility::Matrix&) const;


  private:

    // data_ has to be of type DataLookup2D to accommodate both
    // MatrixLookup and MatrixLookupWeighted
    const DataLookup2D& data_;

    // The number of neighbors
    u_int k_;

    Distance distance_;

    NeighborWeighting weighting_;

    ///
    /// Calculates the distances between a data set and the training
    /// data. The rows are training and the columns test samples,
    /// respectively. The returned distance matrix is dynamically
    /// allocated and needs to be deleted by the caller.
    ///
    utility::Matrix* calculate_distances(const DataLookup2D&) const;

    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;
  };
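
  // --------------------------------------------------------------------
  // Usage sketch (editorial illustration; not part of the original
  // header). SquaredEuclidean below is a hypothetical functor shown only
  // to illustrate the shape of the Distance concept as it is used in
  // calculate_unweighted(): it is called with a begin/end iterator pair
  // over one column and a begin iterator over another, and returns a
  // double. A real Distance policy must also cope with the weighted
  // lookups used in calculate_weighted(), see \ref concept_distance.
  //
  //   struct SquaredEuclidean {
  //     template<typename It1, typename It2>
  //     double operator()(It1 first, It1 last, It2 first2) const
  //     {
  //       double d=0;
  //       for ( ; first!=last; ++first, ++first2)
  //         d += (*first-*first2)*(*first-*first2);
  //       return d;
  //     }
  //   };
  //
  //   // training: one column per sample; target: the class labels
  //   KNN<SquaredEuclidean> knn(training, target); // NeighborWeighting
  //                                                // defaults to KNN_Uniform
  //   knn.k(5);                       // use the 5 nearest neighbors
  //   knn.train();                    // no-op, required by the interface
  //   utility::Matrix prediction;
  //   knn.predict(test, prediction);  // prediction(c,j): votes for class c
  //                                   // for test sample j
  // --------------------------------------------------------------------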


  // templates

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN(const MatrixLookup& data, const Target& target)
    : SupervisedClassifier(target), data_(data), k_(3)
  {
    utility::yat_assert<std::runtime_error>
      (data.columns()==target.size(),
       "KNN::KNN called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data_.columns()<k_)
      k_=data_.columns();
  }


  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN
  (const MatrixLookupWeighted& data, const Target& target)
    : SupervisedClassifier(target), data_(data), k_(3)
  {
    utility::yat_assert<std::runtime_error>
      (data.columns()==target.size(),
       "KNN::KNN called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data_.columns()<k_)
      k_=data_.columns();
  }

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::~KNN()
  {
  }

  template <typename Distance, typename NeighborWeighting>
  utility::Matrix* KNN<Distance, NeighborWeighting>::calculate_distances
  (const DataLookup2D& test) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::Matrix* distances =
      new utility::Matrix(data_.columns(),test.columns());


    // unweighted test data
    if(const MatrixLookup* test_unweighted =
       dynamic_cast<const MatrixLookup*>(&test)) {
      // unweighted training data
      if(const MatrixLookup* training_unweighted =
         dynamic_cast<const MatrixLookup*>(&data_))
        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
      // weighted training data
      else if(const MatrixLookupWeighted* training_weighted =
              dynamic_cast<const MatrixLookupWeighted*>(&data_))
        calculate_weighted(*training_weighted,
                           MatrixLookupWeighted(*test_unweighted),
                           distances);
      // Training data can not be of incorrect type
    }
    // weighted test data
    else if (const MatrixLookupWeighted* test_weighted =
             dynamic_cast<const MatrixLookupWeighted*>(&test)) {
      // unweighted training data
      if(const MatrixLookup* training_unweighted =
         dynamic_cast<const MatrixLookup*>(&data_)) {
        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
                           *test_weighted,distances);
      }
      // weighted training data
      else if(const MatrixLookupWeighted* training_weighted =
              dynamic_cast<const MatrixLookupWeighted*>(&data_))
        calculate_weighted(*training_weighted,*test_weighted,distances);
      // Training data can not be of incorrect type
    }
    else {
      std::string str;
      str = "Error in KNN::calculate_distances: test data has to be "
            "either MatrixLookup or MatrixLookupWeighted";
      throw std::runtime_error(str);
    }
    return distances;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::calculate_unweighted
  (const MatrixLookup& training, const MatrixLookup& test,
   utility::Matrix* distances) const
  {
    for(size_t i=0; i<training.columns(); i++) {
      classifier::DataLookup1D training1(training,i,false);
      for(size_t j=0; j<test.columns(); j++) {
        classifier::DataLookup1D test1(test,j,false);
        // distance between training sample i and test sample j
        (*distances)(i,j) = distance_(training1.begin(), training1.end(),
                                      test1.begin());
        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
      }
    }
  }

  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::calculate_weighted
  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
   utility::Matrix* distances) const
  {
    for(size_t i=0; i<training.columns(); i++) {
      classifier::DataLookupWeighted1D training1(training,i,false);
      for(size_t j=0; j<test.columns(); j++) {
        classifier::DataLookupWeighted1D test1(test,j,false);
        (*distances)(i,j) = distance_(training1.begin(), training1.end(),
                                      test1.begin());
      }
    }
  }
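
  // Note (editorial): calculate_weighted() above invokes the Distance
  // functor with DataLookupWeighted1D iterators, whereas
  // calculate_unweighted() uses plain DataLookup1D iterators. A class
  // modelling \ref concept_distance is therefore expected to handle
  // both unweighted and weighted ranges.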


  template <typename Distance, typename NeighborWeighting>
  const DataLookup2D& KNN<Distance, NeighborWeighting>::data(void) const
  {
    return data_;
  }


  template <typename Distance, typename NeighborWeighting>
  u_int KNN<Distance, NeighborWeighting>::k() const
  {
    return k_;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::k(u_int k_in)
  {
    k_=k_in;
    if(k_>data_.columns())
      k_=data_.columns();
  }


  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>*
  KNN<Distance, NeighborWeighting>::make_classifier(const DataLookup2D& data,
                                                    const Target& target) const
  {
    KNN* knn=0;
    try {
      if(data.weighted()) {
        knn=new KNN<Distance, NeighborWeighting>
          (dynamic_cast<const MatrixLookupWeighted&>(data),target);
      }
      else {
        knn=new KNN<Distance, NeighborWeighting>
          (dynamic_cast<const MatrixLookup&>(data),target);
      }
      knn->k(this->k());
    }
    catch (const std::bad_cast&) {
      std::string str = "Error in KNN<Distance, NeighborWeighting>";
      str += "::make_classifier: DataLookup2D of unexpected class.";
      throw std::runtime_error(str);
    }
    return knn;
  }


  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train()
  {
    trained_=true;
  }


  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::predict(const DataLookup2D& test,
                                                 utility::Matrix& prediction) const
  {
    utility::yat_assert<std::runtime_error>
      (data_.rows()==test.rows(),
       "KNN::predict called with different number of rows in training and test data");

    utility::Matrix* distances=calculate_distances(test);

    // one row per class, one column per test sample
    prediction.resize(target_.nof_classes(),test.columns(),0.0);
    for(size_t sample=0;sample<distances->columns();sample++) {
      std::vector<size_t> k_index;
      utility::VectorConstView dist=distances->column_const_view(sample);
      // indices of the k_ smallest distances, i.e., the k_ nearest
      // training samples
      utility::sort_smallest_index(k_index,k_,dist);
      utility::VectorView pred=prediction.column_view(sample);
      // let the NeighborWeighting policy turn the neighbors into votes
      weighting_(dist,k_index,target_,pred);
    }
    delete distances;

    // classes for which there are no training samples should be set
    // to nan in the predictions
    for(size_t c=0;c<target_.nof_classes(); c++)
      if(!target_.size(c))
        for(size_t j=0;j<prediction.columns();j++)
          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
  }
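
  // Editorial sketch of how a NeighborWeighting policy is invoked from
  // predict() above: it receives the distances from one test sample to
  // all training samples, the indices of the k nearest ones, the training
  // targets, and the prediction column to fill in. InverseDistance below
  // is hypothetical (its parameter types are assumed from the call site);
  // the default, KNN_Uniform, gives each neighbor one vote for its class.
  //
  //   struct InverseDistance {
  //     void operator()(const utility::VectorConstView& distance,
  //                     const std::vector<size_t>& k_sorted,
  //                     const Target& target,
  //                     utility::VectorView& prediction) const
  //     {
  //       for (size_t i=0; i<k_sorted.size(); ++i)
  //         prediction(target(k_sorted[i])) += 1.0/distance(k_sorted[i]);
  //     }
  //   };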

}}} // of namespace classifier, yat, and theplu

#endif