source: trunk/yat/classifier/KNN.h @ 1157

Last change on this file since 1157 was 1157, checked in by Markus Ringnér, 14 years ago

Refs #318

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 9.5 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1157 2008-02-26 13:25:19Z markus $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
40
41namespace theplu {
42namespace yat {
43namespace classifier {
44
45  ///
46  /// @brief Class for Nearest Neigbor Classification.
47  ///
48  /// The template argument Distance should be a class modelling
49  /// the concept \ref concept_distance.
50  /// The template argument NeigborWeighting should be a class modelling
51  /// the concept \ref concept_neighbor_weighting.
52
53  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
54  class KNN : public SupervisedClassifier
55  {
56   
57  public:
58    ///
59    /// @brief Constructor
60    ///
61    KNN(void);
62
63
64    ///
65    /// @brief Destructor
66    ///
67    virtual ~KNN();
68   
69
70    ///
71    /// Default number of neighbors (k) is set to 3.
72    ///
73    /// @return the number of neighbors
74    ///
75    u_int k() const;
76
77    ///
78    /// @brief sets the number of neighbors, k.
79    ///
80    void k(u_int k_in);
81
82
83    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
84   
85    ///
86    /// Train the classifier using training data and target.
87    ///
88    /// If the number of training samples set is smaller than \a k_in,
89    /// k is set to the number of training samples.
90    ///
91    void train(const MatrixLookup&, const Target&);
92
93    ///
94    /// Train the classifier using weighted training data and target.
95    ///
96    void train(const MatrixLookupWeighted&, const Target&);
97
98   
99    ///
100    /// For each sample, calculate the number of neighbors for each
101    /// class.
102    ///
103    ///
104    void predict(const DataLookup2D&, utility::Matrix&) const;
105
106
107  private:
108
109    // data_ has to be of type DataLookup2D to accomodate both
110    // MatrixLookup and MatrixLookupWeighted
111    const DataLookup2D* data_;
112    const Target* target_;
113
114    // The number of neighbors
115    u_int k_;
116
117    Distance distance_;
118
119    NeighborWeighting weighting_;
120
121    ///
122    /// Calculates the distances between a data set and the training
123    /// data. The rows are training and the columns test samples,
124    /// respectively. The returned distance matrix is dynamically
125    /// generated and needs to be deleted by the caller.
126    ///
127    utility::Matrix* calculate_distances(const DataLookup2D&) const;
128
129    void calculate_unweighted(const MatrixLookup&,
130                              const MatrixLookup&,
131                              utility::Matrix*) const;
132    void calculate_weighted(const MatrixLookupWeighted&,
133                            const MatrixLookupWeighted&,
134                            utility::Matrix*) const;
135  };
136 
137 
138  // templates
139 
140  template <typename Distance, typename NeighborWeighting>
141  KNN<Distance, NeighborWeighting>::KNN() 
142    : SupervisedClassifier(),data_(0),target_(0),k_(3)
143  {
144  }
145
146 
147  template <typename Distance, typename NeighborWeighting>
148  KNN<Distance, NeighborWeighting>::~KNN()   
149  {
150  }
151 
152  template <typename Distance, typename NeighborWeighting>
153  utility::Matrix* KNN<Distance, NeighborWeighting>::calculate_distances
154  (const DataLookup2D& test) const
155  {
156    // matrix with training samples as rows and test samples as columns
157    utility::Matrix* distances = 
158      new utility::Matrix(data_->columns(),test.columns());
159   
160   
161    // unweighted test data
162    if(const MatrixLookup* test_unweighted = 
163       dynamic_cast<const MatrixLookup*>(&test)) {     
164      // unweighted training data
165      if(const MatrixLookup* training_unweighted = 
166         dynamic_cast<const MatrixLookup*>(data_)) 
167        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
168      // weighted training data
169      else if(const MatrixLookupWeighted* training_weighted = 
170              dynamic_cast<const MatrixLookupWeighted*>(data_)) 
171        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
172                           distances);             
173      // Training data can not be of incorrect type
174    }
175    // weighted test data
176    else if (const MatrixLookupWeighted* test_weighted = 
177             dynamic_cast<const MatrixLookupWeighted*>(&test)) {     
178      // unweighted training data
179      if(const MatrixLookup* training_unweighted = 
180         dynamic_cast<const MatrixLookup*>(data_)) {
181        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
182                           *test_weighted,distances);
183      }
184      // weighted training data
185      else if(const MatrixLookupWeighted* training_weighted = 
186              dynamic_cast<const MatrixLookupWeighted*>(data_)) 
187        calculate_weighted(*training_weighted,*test_weighted,distances);             
188      // Training data can not be of incorrect type
189    } 
190    else {
191      std::string str;
192      str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted";
193      throw std::runtime_error(str);
194    }
195    return distances;
196  }
197
198  template <typename Distance, typename NeighborWeighting>
199  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
200  (const MatrixLookup& training, const MatrixLookup& test,
201   utility::Matrix* distances) const
202  {
203    for(size_t i=0; i<training.columns(); i++) {
204      classifier::DataLookup1D training1(training,i,false);
205      for(size_t j=0; j<test.columns(); j++) {
206        classifier::DataLookup1D test1(test,j,false);
207        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
208        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
209      }
210    }
211  }
212 
213  template <typename Distance, typename NeighborWeighting>
214  void 
215  KNN<Distance, NeighborWeighting>::calculate_weighted
216  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
217   utility::Matrix* distances) const
218  {
219    for(size_t i=0; i<training.columns(); i++) {
220      classifier::DataLookupWeighted1D training1(training,i,false);
221      for(size_t j=0; j<test.columns(); j++) {
222        classifier::DataLookupWeighted1D test1(test,j,false);
223        (*distances)(i,j) = distance_(training1.begin(), training1.end(), 
224                                      test1.begin());
225        // If the distance is NaN (no common variables with non-zero weights),
226        // the distance is set to infinity to be sorted as a neighbor at the end
227        if(std::isnan((*distances)(i,j))) 
228          (*distances)(i,j)=std::numeric_limits<double>::infinity();
229      }
230    }
231  }
232 
233 
234  template <typename Distance, typename NeighborWeighting>
235  u_int KNN<Distance, NeighborWeighting>::k() const
236  {
237    return k_;
238  }
239
240  template <typename Distance, typename NeighborWeighting>
241  void KNN<Distance, NeighborWeighting>::k(u_int k)
242  {
243    k_=k;
244  }
245
246
247  template <typename Distance, typename NeighborWeighting>
248  KNN<Distance, NeighborWeighting>* 
249  KNN<Distance, NeighborWeighting>::make_classifier() const 
250  {     
251    KNN* knn=new KNN<Distance, NeighborWeighting>();
252    knn->k(this->k());
253    return knn;
254  }
255 
256 
257  template <typename Distance, typename NeighborWeighting>
258  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 
259                                               const Target& target)
260  {   
261    utility::yat_assert<std::runtime_error>
262      (data.columns()==target.size(),
263       "KNN::train called with different sizes of target and data");
264    // k has to be at most the number of training samples.
265    if(data.columns()<k_) 
266      k_=data.columns();
267    data_=&data;
268    target_=&target;
269    trained_=true;
270  }
271
272  template <typename Distance, typename NeighborWeighting>
273  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 
274                                               const Target& target)
275  {   
276    utility::yat_assert<std::runtime_error>
277      (data.columns()==target.size(),
278       "KNN::train called with different sizes of target and data");
279    // k has to be at most the number of training samples.
280    if(data.columns()<k_) 
281      k_=data.columns();
282    data_=&data;
283    target_=&target;
284    trained_=true;
285  }
286
287
288  template <typename Distance, typename NeighborWeighting>
289  void KNN<Distance, NeighborWeighting>::predict(const DataLookup2D& test,
290                                                 utility::Matrix& prediction) const
291  {   
292    utility::yat_assert<std::runtime_error>(data_->rows()==test.rows(),"KNN::predict different number of rows in training and test data");
293
294    utility::Matrix* distances=calculate_distances(test);
295   
296    prediction.resize(target_->nof_classes(),test.columns(),0.0);
297    for(size_t sample=0;sample<distances->columns();sample++) {
298      std::vector<size_t> k_index;
299      utility::VectorConstView dist=distances->column_const_view(sample);
300      utility::sort_smallest_index(k_index,k_,dist);
301      utility::VectorView pred=prediction.column_view(sample);
302      weighting_(dist,k_index,*target_,pred);
303    }
304    delete distances;
305
306    // classes for which there are no training samples should be set
307    // to nan in the predictions
308    for(size_t c=0;c<target_->nof_classes(); c++) 
309      if(!target_->size(c)) 
310        for(size_t j=0;j<prediction.columns();j++)
311          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
312  }
313
314}}} // of namespace classifier, yat, and theplu
315
316#endif
317
Note: See TracBrowser for help on using the repository browser.