source: trunk/yat/classifier/KNN.h @ 1156

Last change on this file since 1156 was 1156, checked in by Markus Ringnér, 14 years ago

Refs. #335, fixed for KNN

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 10.2 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1156 2008-02-26 08:46:49Z markus $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>
40
41namespace theplu {
42namespace yat {
43namespace classifier {
44
45  ///
46  /// @brief Class for Nearest Neigbor Classification.
47  ///
48  /// The template argument Distance should be a class modelling
49  /// the concept \ref concept_distance.
50  /// The template argument NeigborWeighting should be a class modelling
51  /// the concept \ref concept_neighbor_weighting.
52
53  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
54  class KNN : public SupervisedClassifier
55  {
56   
57  public:
58    ///
59    /// Constructor taking the training data and the target   
60    /// as input.
61    ///
62    KNN(const MatrixLookup&, const Target&);
63
64
65    ///
66    /// Constructor taking the training data with weights and the
67    /// target as input.
68    ///
69    KNN(const MatrixLookupWeighted&, const Target&);
70
71    virtual ~KNN();
72   
73    //
74    // @return the training data
75    //
76    const DataLookup2D& data(void) const;
77
78
79    ///
80    /// Default number of neighbors (k) is set to 3.
81    ///
82    /// @return the number of neighbors
83    ///
84    u_int k() const;
85
86    ///
87    /// @brief sets the number of neighbors, k. If the number of
88    /// training samples set is smaller than \a k_in, k is set to the number of
89    /// training samples.
90    ///
91    void k(u_int k_in);
92
93
94    KNN<Distance,NeighborWeighting>* make_classifier(const DataLookup2D&, 
95                         const Target&) const;
96   
97    ///
98    /// Train the classifier using the training data.
99    /// This function does nothing but is required by the interface.
100    ///
101    void train();
102
103   
104    ///
105    /// For each sample, calculate the number of neighbors for each
106    /// class.
107    ///
108    ///
109    void predict(const DataLookup2D&, utility::Matrix&) const;
110
111
112  private:
113
114    // data_ has to be of type DataLookup2D to accomodate both
115    // MatrixLookup and MatrixLookupWeighted
116    const DataLookup2D& data_;
117
118    // The number of neighbors
119    u_int k_;
120
121    Distance distance_;
122
123    NeighborWeighting weighting_;
124
125    ///
126    /// Calculates the distances between a data set and the training
127    /// data. The rows are training and the columns test samples,
128    /// respectively. The returned distance matrix is dynamically
129    /// generated and needs to be deleted by the caller.
130    ///
131    utility::Matrix* calculate_distances(const DataLookup2D&) const;
132
133    void calculate_unweighted(const MatrixLookup&,
134                              const MatrixLookup&,
135                              utility::Matrix*) const;
136    void calculate_weighted(const MatrixLookupWeighted&,
137                            const MatrixLookupWeighted&,
138                            utility::Matrix*) const;
139  };
140 
141 
142  // templates
143 
144  template <typename Distance, typename NeighborWeighting>
145  KNN<Distance, NeighborWeighting>::KNN(const MatrixLookup& data, const Target& target) 
146    : SupervisedClassifier(target), data_(data),k_(3)
147  {
148    utility::yat_assert<std::runtime_error>
149      (data.columns()==target.size(),
150       "KNN::KNN called with different sizes of target and data");
151    // k has to be at most the number of training samples.
152    if(data_.columns()>k_) 
153      k_=data_.columns();
154  }
155
156
157  template <typename Distance, typename NeighborWeighting>
158  KNN<Distance, NeighborWeighting>::KNN
159  (const MatrixLookupWeighted& data, const Target& target) 
160    : SupervisedClassifier(target), data_(data),k_(3)
161  {
162    utility::yat_assert<std::runtime_error>
163      (data.columns()==target.size(),
164       "KNN::KNN called with different sizes of target and data");
165    if(data_.columns()>k_) 
166      k_=data_.columns();
167  }
168 
  ///
  /// @brief Destructor. Nothing to release: data_ is a reference to
  /// caller-owned data, and the functor members clean up themselves.
  ///
  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::~KNN()   
  {
  }
173 
174  template <typename Distance, typename NeighborWeighting>
175  utility::Matrix* KNN<Distance, NeighborWeighting>::calculate_distances
176  (const DataLookup2D& test) const
177  {
178    // matrix with training samples as rows and test samples as columns
179    utility::Matrix* distances = 
180      new utility::Matrix(data_.columns(),test.columns());
181   
182   
183    // unweighted test data
184    if(const MatrixLookup* test_unweighted = 
185       dynamic_cast<const MatrixLookup*>(&test)) {     
186      // unweighted training data
187      if(const MatrixLookup* training_unweighted = 
188         dynamic_cast<const MatrixLookup*>(&data_)) 
189        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
190      // weighted training data
191      else if(const MatrixLookupWeighted* training_weighted = 
192              dynamic_cast<const MatrixLookupWeighted*>(&data_)) 
193        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
194                           distances);             
195      // Training data can not be of incorrect type
196    }
197    // weighted test data
198    else if (const MatrixLookupWeighted* test_weighted = 
199             dynamic_cast<const MatrixLookupWeighted*>(&test)) {     
200      // unweighted training data
201      if(const MatrixLookup* training_unweighted = 
202         dynamic_cast<const MatrixLookup*>(&data_)) {
203        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
204                           *test_weighted,distances);
205      }
206      // weighted training data
207      else if(const MatrixLookupWeighted* training_weighted = 
208              dynamic_cast<const MatrixLookupWeighted*>(&data_)) 
209        calculate_weighted(*training_weighted,*test_weighted,distances);             
210      // Training data can not be of incorrect type
211    } 
212    else {
213      std::string str;
214      str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted";
215      throw std::runtime_error(str);
216    }
217    return distances;
218  }
219
220  template <typename Distance, typename NeighborWeighting>
221  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
222  (const MatrixLookup& training, const MatrixLookup& test,
223   utility::Matrix* distances) const
224  {
225    for(size_t i=0; i<training.columns(); i++) {
226      classifier::DataLookup1D training1(training,i,false);
227      for(size_t j=0; j<test.columns(); j++) {
228        classifier::DataLookup1D test1(test,j,false);
229        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
230        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
231      }
232    }
233  }
234 
235  template <typename Distance, typename NeighborWeighting>
236  void 
237  KNN<Distance, NeighborWeighting>::calculate_weighted
238  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
239   utility::Matrix* distances) const
240  {
241    for(size_t i=0; i<training.columns(); i++) {
242      classifier::DataLookupWeighted1D training1(training,i,false);
243      for(size_t j=0; j<test.columns(); j++) {
244        classifier::DataLookupWeighted1D test1(test,j,false);
245        (*distances)(i,j) = distance_(training1.begin(), training1.end(), 
246                                      test1.begin());
247        // If the distance is NaN (no common variables with non-zero weights),
248        // the distance is set to infinity to be sorted as a neighbor at the end
249        if(std::isnan((*distances)(i,j))) 
250          (*distances)(i,j)=std::numeric_limits<double>::infinity();
251      }
252    }
253  }
254
255 
  ///
  /// @return the training data passed to the constructor
  ///
  template <typename Distance, typename NeighborWeighting>
  const DataLookup2D& KNN<Distance, NeighborWeighting>::data(void) const
  {
    return data_;
  }
261 
262 
  ///
  /// @return the number of neighbors used in prediction
  ///
  template <typename Distance, typename NeighborWeighting>
  u_int KNN<Distance, NeighborWeighting>::k() const
  {
    return k_;
  }
268
269  template <typename Distance, typename NeighborWeighting>
270  void KNN<Distance, NeighborWeighting>::k(u_int k)
271  {
272    k_=k;
273    if(k_>data_.columns())
274      k_=data_.columns();
275  }
276
277
278  template <typename Distance, typename NeighborWeighting>
279  KNN<Distance, NeighborWeighting>* 
280  KNN<Distance, NeighborWeighting>::make_classifier(const DataLookup2D& data, 
281                                                    const Target& target) const 
282  {     
283    KNN* knn=0;
284    try {
285      if(data.weighted()) {
286        knn=new KNN<Distance, NeighborWeighting>
287          (dynamic_cast<const MatrixLookupWeighted&>(data),target);
288      } 
289      else {
290        knn=new KNN<Distance, NeighborWeighting>
291          (dynamic_cast<const MatrixLookup&>(data),target);
292      }
293      knn->k(this->k());
294    }
295    catch (std::bad_cast) {
296      std::string str = "Error in KNN<Distance, NeighborWeighting>"; 
297      str += "::make_classifier: DataLookup2D of unexpected class.";
298      throw std::runtime_error(str);
299    }
300    return knn;
301  }
302 
303 
  ///
  /// @brief Mark the classifier as trained. KNN is a lazy learner,
  /// so no actual computation happens here; trained_ is inherited
  /// from SupervisedClassifier.
  ///
  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train()
  {   
    trained_=true;
  }
309
310
  ///
  /// @brief Predict class weights for each test sample.
  ///
  /// \a prediction is resized to (number of classes) x (number of
  /// test samples) and filled column by column: the k training
  /// samples nearest to each test sample are found and the
  /// NeighborWeighting functor converts their distances and targets
  /// into per-class values. Classes with no training samples get NaN.
  ///
  /// @throws std::runtime_error if training and test data have a
  /// different number of rows (variables)
  ///
  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::predict(const DataLookup2D& test,
                                                 utility::Matrix& prediction) const
  {   
    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows(),"KNN::predict different number of rows in training and test data");

    // caller-owned heap matrix: training samples as rows, test
    // samples as columns.
    // NOTE(review): if resize() or weighting_ below throws, distances
    // leaks -- consider a scoped owner; confirm exception behavior.
    utility::Matrix* distances=calculate_distances(test);
   
    prediction.resize(target_.nof_classes(),test.columns(),0.0);
    for(size_t sample=0;sample<distances->columns();sample++) {
      std::vector<size_t> k_index;
      // indices of the k_ smallest distances for this test sample
      utility::VectorConstView dist=distances->column_const_view(sample);
      utility::sort_smallest_index(k_index,k_,dist);
      // let the weighting scheme turn neighbors into class votes
      utility::VectorView pred=prediction.column_view(sample);
      weighting_(dist,k_index,target_,pred);
    }
    delete distances;

    // classes for which there are no training samples should be set
    // to nan in the predictions
    for(size_t c=0;c<target_.nof_classes(); c++) 
      if(!target_.size(c)) 
        for(size_t j=0;j<prediction.columns();j++)
          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
  }
336
337}}} // of namespace classifier, yat, and theplu
338
339#endif
340
Note: See TracBrowser for help on using the repository browser.