source: trunk/yat/classifier/KNN.h @ 1121

Last change on this file since 1121 was 1121, checked in by Peter, 14 years ago

fixes #308

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 9.1 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1121 2008-02-22 15:29:56Z peter $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <map>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>
40
41namespace theplu {
42namespace yat {
43namespace classifier {
44
45  ///
46  /// @brief Class for Nearest Neigbor Classification.
47  ///
48  /// The template argument Distance should be a class modelling
49  /// the concept \ref concept_distance.
50  /// The template argument NeigborWeighting should be a class modelling
51  /// the concept \ref concept_neighbor_weighting.
52
53  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
54  class KNN : public SupervisedClassifier
55  {
56   
57  public:
58    ///
59    /// Constructor taking the training data and the target   
60    /// as input.
61    ///
62    KNN(const MatrixLookup&, const Target&);
63
64
65    ///
66    /// Constructor taking the training data with weights and the
67    /// target as input.
68    ///
69    KNN(const MatrixLookupWeighted&, const Target&);
70
71    virtual ~KNN();
72   
73    //
74    // @return the training data
75    //
76    const DataLookup2D& data(void) const;
77
78
79    ///
80    /// Default number of neighbors (k) is set to 3.
81    ///
82    /// @return the number of neighbors
83    ///
84    u_int k() const;
85
86    ///
87    /// @brief sets the number of neighbors, k.
88    ///
89    void k(u_int);
90
91
92    SupervisedClassifier* make_classifier(const DataLookup2D&, 
93                                          const Target&) const;
94   
95    ///
96    /// Train the classifier using the training data.
97    /// This function does nothing but is required by the interface.
98    ///
99    /// @return true if training succedeed.
100    ///
101    void train();
102
103   
104    ///
105    /// For each sample, calculate the number of neighbors for each
106    /// class.
107    ///
108    ///
109    void predict(const DataLookup2D&, utility::Matrix&) const;
110
111
112  private:
113
114    // data_ has to be of type DataLookup2D to accomodate both
115    // MatrixLookup and MatrixLookupWeighted
116    const DataLookup2D& data_;
117
118    // The number of neighbors
119    u_int k_;
120
121    Distance distance_;
122
123    NeighborWeighting weighting_;
124
125    ///
126    /// Calculates the distances between a data set and the training
127    /// data. The rows are training and the columns test samples,
128    /// respectively. The returned distance matrix is dynamically
129    /// generated and needs to be deleted by the caller.
130    ///
131    utility::Matrix* calculate_distances(const DataLookup2D&) const;
132
133    void calculate_unweighted(const MatrixLookup&,
134                              const MatrixLookup&,
135                              utility::Matrix*) const;
136    void calculate_weighted(const MatrixLookupWeighted&,
137                            const MatrixLookupWeighted&,
138                            utility::Matrix*) const;
139  };
140 
141 
142  // templates
143 
144  template <typename Distance, typename NeighborWeighting>
145  KNN<Distance, NeighborWeighting>::KNN(const MatrixLookup& data, const Target& target) 
146    : SupervisedClassifier(target), data_(data),k_(3)
147  {
148  }
149
150
151  template <typename Distance, typename NeighborWeighting>
152  KNN<Distance, NeighborWeighting>::KNN
153  (const MatrixLookupWeighted& data, const Target& target) 
154    : SupervisedClassifier(target), data_(data),k_(3)
155  {
156  }
157 
158  template <typename Distance, typename NeighborWeighting>
159  KNN<Distance, NeighborWeighting>::~KNN()   
160  {
161  }
162 
163  template <typename Distance, typename NeighborWeighting>
164  utility::Matrix* KNN<Distance, NeighborWeighting>::calculate_distances
165  (const DataLookup2D& test) const
166  {
167    // matrix with training samples as rows and test samples as columns
168    utility::Matrix* distances = 
169      new utility::Matrix(data_.columns(),test.columns());
170   
171   
172    // unweighted test data
173    if(const MatrixLookup* test_unweighted = 
174       dynamic_cast<const MatrixLookup*>(&test)) {     
175      // unweighted training data
176      if(const MatrixLookup* training_unweighted = 
177         dynamic_cast<const MatrixLookup*>(&data_)) 
178        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
179      // weighted training data
180      else if(const MatrixLookupWeighted* training_weighted = 
181              dynamic_cast<const MatrixLookupWeighted*>(&data_)) 
182        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
183                           distances);             
184      // Training data can not be of incorrect type
185    }
186    // weighted test data
187    else if (const MatrixLookupWeighted* test_weighted = 
188             dynamic_cast<const MatrixLookupWeighted*>(&test)) {     
189      // unweighted training data
190      if(const MatrixLookup* training_unweighted = 
191         dynamic_cast<const MatrixLookup*>(&data_)) {
192        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
193                           *test_weighted,distances);
194      }
195      // weighted training data
196      else if(const MatrixLookupWeighted* training_weighted = 
197              dynamic_cast<const MatrixLookupWeighted*>(&data_)) 
198        calculate_weighted(*training_weighted,*test_weighted,distances);             
199      // Training data can not be of incorrect type
200    } 
201    else {
202      std::string str;
203      str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted";
204      throw std::runtime_error(str);
205    }
206    return distances;
207  }
208
209  template <typename Distance, typename NeighborWeighting>
210  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
211  (const MatrixLookup& training, const MatrixLookup& test,
212   utility::Matrix* distances) const
213  {
214    for(size_t i=0; i<training.columns(); i++) {
215      classifier::DataLookup1D training1(training,i,false);
216      for(size_t j=0; j<test.columns(); j++) {
217        classifier::DataLookup1D test1(test,j,false);
218        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
219        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
220      }
221    }
222  }
223 
224  template <typename Distance, typename NeighborWeighting>
225  void 
226  KNN<Distance, NeighborWeighting>::calculate_weighted
227  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
228   utility::Matrix* distances) const
229  {
230    for(size_t i=0; i<training.columns(); i++) {
231      classifier::DataLookupWeighted1D training1(training,i,false);
232      for(size_t j=0; j<test.columns(); j++) {
233        classifier::DataLookupWeighted1D test1(test,j,false);
234        (*distances)(i,j) = distance_(training1.begin(), training1.end(), 
235                                      test1.begin());
236        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
237      }
238    }
239  }
240
241 
242  template <typename Distance, typename NeighborWeighting>
243  const DataLookup2D& KNN<Distance, NeighborWeighting>::data(void) const
244  {
245    return data_;
246  }
247 
248 
249  template <typename Distance, typename NeighborWeighting>
250  u_int KNN<Distance, NeighborWeighting>::k() const
251  {
252    return k_;
253  }
254
255  template <typename Distance, typename NeighborWeighting>
256  void KNN<Distance, NeighborWeighting>::k(u_int k)
257  {
258    k_=k;
259  }
260
261
262  template <typename Distance, typename NeighborWeighting>
263  SupervisedClassifier* 
264  KNN<Distance, NeighborWeighting>::make_classifier(const DataLookup2D& data, 
265                                                    const Target& target) const 
266  {     
267    KNN* knn=0;
268    try {
269      if(data.weighted()) {
270        knn=new KNN<Distance, NeighborWeighting>
271          (dynamic_cast<const MatrixLookupWeighted&>(data),target);
272      } 
273      else {
274        knn=new KNN<Distance, NeighborWeighting>
275          (dynamic_cast<const MatrixLookup&>(data),target);
276      }
277      knn->k(this->k());
278    }
279    catch (std::bad_cast) {
280      std::string str = "Error in KNN<Distance, NeighborWeighting>"; 
281      str += "::make_classifier: DataLookup2D of unexpected class.";
282      throw std::runtime_error(str);
283    }
284    return knn;
285  }
286 
287 
288  template <typename Distance, typename NeighborWeighting>
289  void KNN<Distance, NeighborWeighting>::train()
290  {   
291    trained_=true;
292  }
293
294
295  template <typename Distance, typename NeighborWeighting>
296  void KNN<Distance, NeighborWeighting>::predict(const DataLookup2D& test,
297                                                 utility::Matrix& prediction) const
298  {   
299    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
300
301    utility::Matrix* distances=calculate_distances(test);
302   
303    prediction.resize(target_.nof_classes(),test.columns(),0.0);
304    for(size_t sample=0;sample<distances->columns();sample++) {
305      std::vector<size_t> k_index;
306      utility::VectorConstView dist=distances->column_const_view(sample);
307      utility::sort_smallest_index(k_index,k_,dist);
308      utility::VectorView pred=prediction.column_view(sample);
309      weighting_(dist,k_index,target_,pred);
310    }
311    delete distances;
312  }
313
314}}} // of namespace classifier, yat, and theplu
315
316#endif
317
Note: See TracBrowser for help on using the repository browser.