source: trunk/yat/classifier/KNN.h @ 1162

Last change on this file since 1162 was 1162, checked in by Peter, 14 years ago

removing trained_

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 9.8 KB
#ifndef _theplu_yat_classifier_knn_
#define _theplu_yat_classifier_knn_

// $Id: KNN.h 1162 2008-02-26 16:24:11Z peter $

/*
  Copyright (C) 2007 Peter Johansson, Markus Ringnér

  This file is part of the yat library, http://trac.thep.lu.se/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 2 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
  02111-1307, USA.
*/

#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>

namespace theplu {
namespace yat {
namespace classifier {

  ///
  /// @brief Class for Nearest Neighbor Classification.
  ///
  /// The template argument Distance should be a class modelling
  /// the concept \ref concept_distance.
  /// The template argument NeighborWeighting should be a class modelling
  /// the concept \ref concept_neighbor_weighting.
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {

  public:
    ///
    /// @brief Constructor
    ///
    KNN(void);


    ///
    /// @brief Constructor using a particular Distance object.
    ///
    KNN(const Distance&);


    ///
    /// @brief Destructor
    ///
    virtual ~KNN();


    ///
    /// Default number of neighbors (k) is set to 3.
    ///
    /// @return the number of neighbors
    ///
    u_int k() const;

    ///
    /// @brief Sets the number of neighbors, k.
    ///
    void k(u_int k_in);

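    ///
    /// @brief Create an untrained copy of this classifier.
    ///
    /// The returned classifier is allocated with new and uses the
    /// same number of neighbors, k, as this one.
    ///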
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;

    ///
    /// Train the classifier using training data and target.
    ///
    /// If the number of training samples is smaller than k, k is set
    /// to the number of training samples.
    ///
    void train(const MatrixLookup&, const Target&);

    ///
    /// Train the classifier using weighted training data and target.
    ///
    /// If the number of training samples is smaller than k, k is set
    /// to the number of training samples.
    ///
    void train(const MatrixLookupWeighted&, const Target&);

    ///
    /// For each sample, calculate the number of neighbors for each
    /// class.
    ///
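    /// The prediction matrix is resized to one row per class and one
    /// column per test sample; element (c,j) holds the (possibly
    /// weighted) number of the k nearest training samples of test
    /// sample j that belong to class c.
    ///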
    void predict(const MatrixLookup&, utility::Matrix&) const;

    ///
    /// For each sample, calculate the number of neighbors for each
    /// class.
    ///
    void predict(const MatrixLookupWeighted&, utility::Matrix&) const;



  private:

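    // Training data and target; train() sets exactly one of data_ml_
    // and data_mlw_ depending on whether the training data is weighted.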
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    const Target* target_;

    // The number of neighbors
    u_int k_;

    Distance distance_;

    NeighborWeighting weighting_;

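    // Fill distances(i,j) with the distance, according to distance_,
    // between training sample i (column i of the first argument) and
    // test sample j (column j of the second argument).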
    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

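    // Given the matrix of training-to-test distances, find the k
    // nearest training samples for each test sample and let
    // weighting_ translate their classes into a prediction.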
    void predict_common(const utility::Matrix& distances,
                        utility::Matrix& prediction) const;

  };


  // templates

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN()
    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
  {
  }

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
  {
  }


  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::~KNN()
  {
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::calculate_unweighted
  (const MatrixLookup& training, const MatrixLookup& test,
   utility::Matrix* distances) const
  {
    for(size_t i=0; i<training.columns(); i++) {
      for(size_t j=0; j<test.columns(); j++) {
        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i),
                                      test.begin_column(j));
        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
      }
    }
  }


  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::calculate_weighted
  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
   utility::Matrix* distances) const
  {
    for(size_t i=0; i<training.columns(); i++) {
      for(size_t j=0; j<test.columns(); j++) {
        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i),
                                      test.begin_column(j));
        // If the distance is NaN (no variables in common with non-zero
        // weights), set it to infinity so that this training sample is
        // sorted last among the potential neighbors.
        if(std::isnan((*distances)(i,j)))
          (*distances)(i,j)=std::numeric_limits<double>::infinity();
      }
    }
  }


  template <typename Distance, typename NeighborWeighting>
  u_int KNN<Distance, NeighborWeighting>::k() const
  {
    return k_;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::k(u_int k)
  {
    k_=k;
  }


  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>*
  KNN<Distance, NeighborWeighting>::make_classifier() const
  {
    KNN* knn=new KNN<Distance, NeighborWeighting>();
    knn->k(this->k());
    return knn;
  }


  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
                                               const Target& target)
  {
    utility::yat_assert<std::runtime_error>
      (data.columns()==target.size(),
       "KNN::train called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data.columns()<k_)
      k_=data.columns();
    data_ml_=&data;
    data_mlw_=0;
    target_=&target;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
                                               const Target& target)
  {
    utility::yat_assert<std::runtime_error>
      (data.columns()==target.size(),
       "KNN::train called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data.columns()<k_)
      k_=data.columns();
    data_ml_=0;
    data_mlw_=&data;
    target_=&target;
  }



  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
                                                 utility::Matrix& prediction) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::Matrix* distances = 0;
    // unweighted training data
    if(data_ml_ && !data_mlw_) {
      utility::yat_assert<std::runtime_error>
        (data_ml_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_ml_->columns(),test.columns());
      calculate_unweighted(*data_ml_,test,distances);
    }
    else if (data_mlw_ && !data_ml_) {
      // weighted training data
      utility::yat_assert<std::runtime_error>
        (data_mlw_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
                         distances);
    }
    else {
      throw std::runtime_error("KNN::predict no training data");
    }

    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    predict_common(*distances,prediction);
    if(distances)
      delete distances;
  }


  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
                                                 utility::Matrix& prediction) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::Matrix* distances=0;
    // unweighted training data
    if(data_ml_ && !data_mlw_) {
      utility::yat_assert<std::runtime_error>
        (data_ml_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_ml_->columns(),test.columns());
      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
    }
    // weighted training data
    else if (data_mlw_ && !data_ml_) {
      utility::yat_assert<std::runtime_error>
        (data_mlw_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
      calculate_weighted(*data_mlw_,test,distances);
    }
    else {
      throw std::runtime_error("KNN::predict no training data");
    }

    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    predict_common(*distances,prediction);

    if(distances)
      delete distances;
  }


  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::predict_common
  (const utility::Matrix& distances, utility::Matrix& prediction) const
  {
    for(size_t sample=0;sample<distances.columns();sample++) {
      std::vector<size_t> k_index;
      utility::VectorConstView dist=distances.column_const_view(sample);
      utility::sort_smallest_index(k_index,k_,dist);
      utility::VectorView pred=prediction.column_view(sample);
      weighting_(dist,k_index,*target_,pred);
    }

    // classes for which there are no training samples should be set
    // to nan in the predictions
    for(size_t c=0;c<target_->nof_classes(); c++)
      if(!target_->size(c))
        for(size_t j=0;j<prediction.columns();j++)
          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
  }



}}} // of namespace classifier, yat, and theplu

#endif