source: trunk/yat/classifier/KNN.h @ 1124

Last change on this file since 1124 was 1124, checked in by Peter, 14 years ago

train returns nothing, removed docs saying else

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 9.1 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1124 2008-02-22 18:48:31Z peter $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "DataLookup1D.h"
28#include "DataLookupWeighted1D.h"
29#include "KNN_Uniform.h"
30#include "MatrixLookup.h"
31#include "MatrixLookupWeighted.h"
32#include "SupervisedClassifier.h"
33#include "Target.h"
34#include "yat/utility/Matrix.h"
35#include "yat/utility/yat_assert.h"
36
37#include <cmath>
38#include <map>
39#include <stdexcept>
40
41namespace theplu {
42namespace yat {
43namespace classifier {
44
45  ///
46  /// @brief Class for Nearest Neigbor Classification.
47  ///
48  /// The template argument Distance should be a class modelling
49  /// the concept \ref concept_distance.
50  /// The template argument NeigborWeighting should be a class modelling
51  /// the concept \ref concept_neighbor_weighting.
52
53  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
54  class KNN : public SupervisedClassifier
55  {
56   
57  public:
58    ///
59    /// Constructor taking the training data and the target   
60    /// as input.
61    ///
62    KNN(const MatrixLookup&, const Target&);
63
64
65    ///
66    /// Constructor taking the training data with weights and the
67    /// target as input.
68    ///
69    KNN(const MatrixLookupWeighted&, const Target&);
70
71    virtual ~KNN();
72   
73    //
74    // @return the training data
75    //
76    const DataLookup2D& data(void) const;
77
78
79    ///
80    /// Default number of neighbors (k) is set to 3.
81    ///
82    /// @return the number of neighbors
83    ///
84    u_int k() const;
85
86    ///
87    /// @brief sets the number of neighbors, k.
88    ///
89    void k(u_int);
90
91
92    SupervisedClassifier* make_classifier(const DataLookup2D&, 
93                                          const Target&) const;
94   
95    ///
96    /// Train the classifier using the training data.
97    /// This function does nothing but is required by the interface.
98    ///
99    void train();
100
101   
102    ///
103    /// For each sample, calculate the number of neighbors for each
104    /// class.
105    ///
106    ///
107    void predict(const DataLookup2D&, utility::Matrix&) const;
108
109
110  private:
111
112    // data_ has to be of type DataLookup2D to accomodate both
113    // MatrixLookup and MatrixLookupWeighted
114    const DataLookup2D& data_;
115
116    // The number of neighbors
117    u_int k_;
118
119    Distance distance_;
120
121    NeighborWeighting weighting_;
122
123    ///
124    /// Calculates the distances between a data set and the training
125    /// data. The rows are training and the columns test samples,
126    /// respectively. The returned distance matrix is dynamically
127    /// generated and needs to be deleted by the caller.
128    ///
129    utility::Matrix* calculate_distances(const DataLookup2D&) const;
130
131    void calculate_unweighted(const MatrixLookup&,
132                              const MatrixLookup&,
133                              utility::Matrix*) const;
134    void calculate_weighted(const MatrixLookupWeighted&,
135                            const MatrixLookupWeighted&,
136                            utility::Matrix*) const;
137  };
138 
139 
140  // templates
141 
142  template <typename Distance, typename NeighborWeighting>
143  KNN<Distance, NeighborWeighting>::KNN(const MatrixLookup& data, const Target& target) 
144    : SupervisedClassifier(target), data_(data),k_(3)
145  {
146  }
147
148
149  template <typename Distance, typename NeighborWeighting>
150  KNN<Distance, NeighborWeighting>::KNN
151  (const MatrixLookupWeighted& data, const Target& target) 
152    : SupervisedClassifier(target), data_(data),k_(3)
153  {
154  }
155 
156  template <typename Distance, typename NeighborWeighting>
157  KNN<Distance, NeighborWeighting>::~KNN()   
158  {
159  }
160 
161  template <typename Distance, typename NeighborWeighting>
162  utility::Matrix* KNN<Distance, NeighborWeighting>::calculate_distances
163  (const DataLookup2D& test) const
164  {
165    // matrix with training samples as rows and test samples as columns
166    utility::Matrix* distances = 
167      new utility::Matrix(data_.columns(),test.columns());
168   
169   
170    // unweighted test data
171    if(const MatrixLookup* test_unweighted = 
172       dynamic_cast<const MatrixLookup*>(&test)) {     
173      // unweighted training data
174      if(const MatrixLookup* training_unweighted = 
175         dynamic_cast<const MatrixLookup*>(&data_)) 
176        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
177      // weighted training data
178      else if(const MatrixLookupWeighted* training_weighted = 
179              dynamic_cast<const MatrixLookupWeighted*>(&data_)) 
180        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
181                           distances);             
182      // Training data can not be of incorrect type
183    }
184    // weighted test data
185    else if (const MatrixLookupWeighted* test_weighted = 
186             dynamic_cast<const MatrixLookupWeighted*>(&test)) {     
187      // unweighted training data
188      if(const MatrixLookup* training_unweighted = 
189         dynamic_cast<const MatrixLookup*>(&data_)) {
190        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
191                           *test_weighted,distances);
192      }
193      // weighted training data
194      else if(const MatrixLookupWeighted* training_weighted = 
195              dynamic_cast<const MatrixLookupWeighted*>(&data_)) 
196        calculate_weighted(*training_weighted,*test_weighted,distances);             
197      // Training data can not be of incorrect type
198    } 
199    else {
200      std::string str;
201      str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted";
202      throw std::runtime_error(str);
203    }
204    return distances;
205  }
206
207  template <typename Distance, typename NeighborWeighting>
208  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
209  (const MatrixLookup& training, const MatrixLookup& test,
210   utility::Matrix* distances) const
211  {
212    for(size_t i=0; i<training.columns(); i++) {
213      classifier::DataLookup1D training1(training,i,false);
214      for(size_t j=0; j<test.columns(); j++) {
215        classifier::DataLookup1D test1(test,j,false);
216        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
217        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
218      }
219    }
220  }
221 
222  template <typename Distance, typename NeighborWeighting>
223  void 
224  KNN<Distance, NeighborWeighting>::calculate_weighted
225  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
226   utility::Matrix* distances) const
227  {
228    for(size_t i=0; i<training.columns(); i++) {
229      classifier::DataLookupWeighted1D training1(training,i,false);
230      for(size_t j=0; j<test.columns(); j++) {
231        classifier::DataLookupWeighted1D test1(test,j,false);
232        (*distances)(i,j) = distance_(training1.begin(), training1.end(), 
233                                      test1.begin());
234        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
235      }
236    }
237  }
238
239 
240  template <typename Distance, typename NeighborWeighting>
241  const DataLookup2D& KNN<Distance, NeighborWeighting>::data(void) const
242  {
243    return data_;
244  }
245 
246 
247  template <typename Distance, typename NeighborWeighting>
248  u_int KNN<Distance, NeighborWeighting>::k() const
249  {
250    return k_;
251  }
252
253  template <typename Distance, typename NeighborWeighting>
254  void KNN<Distance, NeighborWeighting>::k(u_int k)
255  {
256    k_=k;
257  }
258
259
260  template <typename Distance, typename NeighborWeighting>
261  SupervisedClassifier* 
262  KNN<Distance, NeighborWeighting>::make_classifier(const DataLookup2D& data, 
263                                                    const Target& target) const 
264  {     
265    KNN* knn=0;
266    try {
267      if(data.weighted()) {
268        knn=new KNN<Distance, NeighborWeighting>
269          (dynamic_cast<const MatrixLookupWeighted&>(data),target);
270      } 
271      else {
272        knn=new KNN<Distance, NeighborWeighting>
273          (dynamic_cast<const MatrixLookup&>(data),target);
274      }
275      knn->k(this->k());
276    }
277    catch (std::bad_cast) {
278      std::string str = "Error in KNN<Distance, NeighborWeighting>"; 
279      str += "::make_classifier: DataLookup2D of unexpected class.";
280      throw std::runtime_error(str);
281    }
282    return knn;
283  }
284 
285 
286  template <typename Distance, typename NeighborWeighting>
287  void KNN<Distance, NeighborWeighting>::train()
288  {   
289    trained_=true;
290  }
291
292
293  template <typename Distance, typename NeighborWeighting>
294  void KNN<Distance, NeighborWeighting>::predict(const DataLookup2D& test,
295                                                 utility::Matrix& prediction) const
296  {   
297    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
298
299    utility::Matrix* distances=calculate_distances(test);
300   
301    prediction.resize(target_.nof_classes(),test.columns(),0.0);
302    for(size_t sample=0;sample<distances->columns();sample++) {
303      std::vector<size_t> k_index;
304      utility::VectorConstView dist=distances->column_const_view(sample);
305      utility::sort_smallest_index(k_index,k_,dist);
306      utility::VectorView pred=prediction.column_view(sample);
307      weighting_(dist,k_index,target_,pred);
308    }
309    delete distances;
310  }
311
312}}} // of namespace classifier, yat, and theplu
313
314#endif
315
Note: See TracBrowser for help on using the repository browser.