source: trunk/yat/classifier/KNN.h @ 1107

Last change on this file since 1107 was 1107, checked in by Markus Ringnér, 14 years ago

Ticket #259 fixed for KNN

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 8.3 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1107 2008-02-19 15:23:52Z markus $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <map>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>
39
40namespace theplu {
41namespace yat {
42namespace classifier {
43
44  ///
45  /// @brief Class for Nearest Centroid Classification.
46  ///
47 
48 
49  template <typename Distance>
50  class KNN : public SupervisedClassifier
51  {
52   
53  public:
54    ///
55    /// Constructor taking the training data and the target   
56    /// as input.
57    ///
58    KNN(const MatrixLookup&, const Target&);
59
60
61    ///
62    /// Constructor taking the training data with weights and the
63    /// target as input.
64    ///
65    KNN(const MatrixLookupWeighted&, const Target&);
66
67    virtual ~KNN();
68   
69    //
70    // @return the training data
71    //
72    const DataLookup2D& data(void) const;
73
74
75    ///
76    /// Default number of neighbours (k) is set to 3.
77    ///
78    /// @return the number of neighbours
79    ///
80    u_int k() const;
81
82    ///
83    /// @brief sets the number of neighbours, k.
84    ///
85    void k(u_int);
86
87
88    SupervisedClassifier* make_classifier(const DataLookup2D&, 
89                                          const Target&) const;
90   
91    ///
92    /// Train the classifier using the training data.
93    /// This function does nothing but is required by the interface.
94    ///
95    /// @return true if training succedeed.
96    ///
97    void train();
98
99   
100    ///
101    /// For each sample, calculate the number of neighbours for each
102    /// class.
103    ///
104    ///
105    void predict(const DataLookup2D&, utility::matrix&) const;
106
107
108  private:
109
110    // data_ has to be of type DataLookup2D to accomodate both
111    // MatrixLookup and MatrixLookupWeighted
112    const DataLookup2D& data_;
113
114    // The number of neighbours
115    u_int k_;
116
117    Distance distance_;
118    ///
119    /// Calculates the distances between a data set and the training
120    /// data. The rows are training and the columns test samples,
121    /// respectively. The returned distance matrix is dynamically
122    /// generated and needs to be deleted by the caller.
123    ///
124    utility::matrix* calculate_distances(const DataLookup2D&) const;
125    void calculate_unweighted(const MatrixLookup&,
126                              const MatrixLookup&,
127                              utility::matrix*) const;
128    void calculate_weighted(const MatrixLookupWeighted&,
129                            const MatrixLookupWeighted&,
130                            utility::matrix*) const;
131  };
132 
133 
134  // templates
135 
136  template <typename Distance>
137  KNN<Distance>::KNN(const MatrixLookup& data, const Target& target) 
138    : SupervisedClassifier(target), data_(data),k_(3)
139  {
140  }
141
142
143  template <typename Distance>
144  KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) 
145    : SupervisedClassifier(target), data_(data),k_(3)
146  {
147  }
148 
149  template <typename Distance>
150  KNN<Distance>::~KNN()   
151  {
152  }
153 
154  template <typename Distance>
155  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& test) const
156  {
157    // matrix with training samples as rows and test samples as columns
158    utility::matrix* distances = 
159      new utility::matrix(data_.columns(),test.columns());
160   
161   
162    // unweighted test data
163    if(const MatrixLookup* test_unweighted = 
164       dynamic_cast<const MatrixLookup*>(&test)) {     
165      // unweighted training data
166      if(const MatrixLookup* training_unweighted = 
167         dynamic_cast<const MatrixLookup*>(&data_)) 
168        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
169      // weighted training data
170      else if(const MatrixLookupWeighted* training_weighted = 
171              dynamic_cast<const MatrixLookupWeighted*>(&data_)) 
172        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
173                           distances);             
174      // Training data can not be of incorrect type
175    }
176    // weighted test data
177    else if (const MatrixLookupWeighted* test_weighted = 
178             dynamic_cast<const MatrixLookupWeighted*>(&test)) {     
179      // unweighted training data
180      if(const MatrixLookup* training_unweighted = 
181         dynamic_cast<const MatrixLookup*>(&data_)) {
182        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
183                           *test_weighted,distances);
184      }
185      // weighted training data
186      else if(const MatrixLookupWeighted* training_weighted = 
187              dynamic_cast<const MatrixLookupWeighted*>(&data_)) 
188        calculate_weighted(*training_weighted,*test_weighted,distances);             
189      // Training data can not be of incorrect type
190    } 
191    else {
192      std::string str;
193      str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted";
194      throw std::runtime_error(str);
195    }
196    return distances;
197  }
198
199  template <typename Distance>
200  void  KNN<Distance>:: calculate_unweighted(const MatrixLookup& training,
201                                             const MatrixLookup& test,
202                                             utility::matrix* distances) const
203  {
204    for(size_t i=0; i<training.columns(); i++) {
205      classifier::DataLookup1D training1(training,i,false);
206      for(size_t j=0; j<test.columns(); j++) {
207        classifier::DataLookup1D test1(test,j,false);
208        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
209        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
210      }
211    }
212  }
213 
214  template <typename Distance>
215  void  KNN<Distance>:: calculate_weighted(const MatrixLookupWeighted& training,
216                                           const MatrixLookupWeighted& test,
217                                           utility::matrix* distances) const
218  {
219    for(size_t i=0; i<training.columns(); i++) {
220      classifier::DataLookupWeighted1D training1(training,i,false);
221      for(size_t j=0; j<test.columns(); j++) {
222        classifier::DataLookupWeighted1D test1(test,j,false);
223        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
224        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
225      }
226    }
227  }
228
229 
230  template <typename Distance>
231  const DataLookup2D& KNN<Distance>::data(void) const
232  {
233    return data_;
234  }
235 
236 
237  template <typename Distance>
238  u_int KNN<Distance>::k() const
239  {
240    return k_;
241  }
242
243  template <typename Distance>
244  void KNN<Distance>::k(u_int k)
245  {
246    k_=k;
247  }
248
249
250  template <typename Distance>
251  SupervisedClassifier* 
252  KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
253  {     
254    KNN* knn=0;
255    try {
256      if(data.weighted()) {
257        knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
258                              target);
259      } 
260      else {
261        knn=new KNN<Distance>(dynamic_cast<const MatrixLookup&>(data),
262                              target);
263      }
264      knn->k(this->k());
265    }
266    catch (std::bad_cast) {
267      std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class.";
268      throw std::runtime_error(str);
269    }
270    return knn;
271  }
272 
273 
274  template <typename Distance>
275  void KNN<Distance>::train()
276  {   
277    trained_=true;
278  }
279
280
281  template <typename Distance>
282  void KNN<Distance>::predict(const DataLookup2D& test,                     
283                              utility::matrix& prediction) const
284  {   
285    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
286
287    utility::matrix* distances=calculate_distances(test);
288   
289    // for each test sample (column in distances) find the closest
290    // training samples
291    prediction.resize(target_.nof_classes(),test.columns(),0.0);
292    for(size_t sample=0;sample<distances->columns();sample++) {
293      std::vector<size_t> k_index;
294      utility::sort_smallest_index(k_index,k_,
295                                   distances->column_const_view(sample));
296      for(size_t j=0;j<k_index.size();j++) {
297        prediction(target_(k_index[j]),sample)++;
298      }
299    }
300    prediction*=(1.0/k_);
301    delete distances;
302  }
303
304}}} // of namespace classifier, yat, and theplu
305
306#endif
307
Note: See TracBrowser for help on using the repository browser.