source: trunk/yat/classifier/KNN.h @ 1090

Last change on this file since 1090 was 1050, checked in by Peter, 13 years ago

Simplifying distance structure

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 6.9 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1050 2008-02-07 18:47:34Z peter $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "DataLookup1D.h"
28#include "DataLookupWeighted1D.h"
29#include "MatrixLookup.h"
30#include "MatrixLookupWeighted.h"
31#include "SupervisedClassifier.h"
32#include "Target.h"
33#include "yat/utility/matrix.h"
34#include "yat/utility/yat_assert.h"
35
36#include <cmath>
37#include <map>
38#include <stdexcept>
39
40namespace theplu {
41namespace yat {
42namespace classifier {
43
44  ///
45  /// @brief Class for Nearest Centroid Classification.
46  ///
47 
48 
49  template <typename Distance>
50  class KNN : public SupervisedClassifier
51  {
52   
53  public:
54    ///
55    /// Constructor taking the training data and the target   
56    /// as input.
57    ///
58    KNN(const MatrixLookup&, const Target&);
59
60
61    ///
62    /// Constructor taking the training data with weights and the
63    /// target as input.
64    ///
65    KNN(const MatrixLookupWeighted&, const Target&);
66
67    virtual ~KNN();
68   
69    //
70    // @return the training data
71    //
72    const DataLookup2D& data(void) const;
73
74
75    ///
76    /// Default number of neighbours (k) is set to 3.
77    ///
78    /// @return the number of neighbours
79    ///
80    u_int k() const;
81
82    ///
83    /// @brief sets the number of neighbours, k.
84    ///
85    void k(u_int);
86
87
88    SupervisedClassifier* make_classifier(const DataLookup2D&, 
89                                          const Target&) const;
90   
91    ///
92    /// Train the classifier using the training data. Centroids are
93    /// calculated for each class.
94    ///
95    /// @return true if training succedeed.
96    ///
97    void train();
98
99   
100    ///
101    /// Calculate the distance to each centroid for test samples
102    ///
103    void predict(const DataLookup2D&, utility::matrix&) const;
104
105
106  private:
107
108    // data_ has to be of type DataLookup2D to accomodate both
109    // MatrixLookup and MatrixLookupWeighted
110    const DataLookup2D& data_;
111
112    // The number of neighbours
113    u_int k_;
114
115    Distance distance_;
116    ///
117    /// Calculates the distances between a data set and the training
118    /// data. The rows are training and the columns test samples,
119    /// respectively. The returned distance matrix is dynamically
120    /// generated and needs to be deleted by the caller.
121    ///
122    utility::matrix* calculate_distances(const DataLookup2D&) const;
123  };
124 
125 
126  // templates
127 
128  template <typename Distance>
129  KNN<Distance>::KNN(const MatrixLookup& data, const Target& target) 
130    : SupervisedClassifier(target), data_(data),k_(3)
131  {
132  }
133
134
135  template <typename Distance>
136  KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) 
137    : SupervisedClassifier(target), data_(data),k_(3)
138  {
139  }
140 
141  template <typename Distance>
142  KNN<Distance>::~KNN()   
143  {
144  }
145 
146  template <typename Distance>
147  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& test) const
148  {
149    // matrix with training samples as rows and test samples as columns
150    utility::matrix* distances = 
151      new utility::matrix(data_.columns(),test.columns());
152   
153    // unweighted test data
154    if(const MatrixLookup* test_unweighted = 
155       dynamic_cast<const MatrixLookup*>(&test)) {     
156      for(size_t i=0; i<data_.columns(); i++) {
157        for(size_t j=0; j<test.columns(); j++) {
158          classifier::DataLookup1D test(*test_unweighted,j,false);
159          classifier::DataLookup1D tmp(data_,i,false);
160          (*distances)(i,j) = distance_(tmp.begin(), tmp.end(), test.begin());
161          utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
162        }
163      }
164    }
165    // weighted test data
166    else {
167      const MatrixLookupWeighted* data_weighted = 
168        dynamic_cast<const MatrixLookupWeighted*>(&data_);
169      const MatrixLookupWeighted* test_weighted = 
170        dynamic_cast<const MatrixLookupWeighted*>(&test);               
171      if(data_weighted && test_weighted) {
172        for(size_t i=0; i<data_.columns(); i++) {
173          classifier::DataLookupWeighted1D training(*data_weighted,i,false);
174          for(size_t j=0; j<test.columns(); j++) {
175            classifier::DataLookupWeighted1D test(*test_weighted,j,false);
176            utility::yat_assert<std::runtime_error>(training.size()==test.size());
177            (*distances)(i,j) = distance_(training.begin(), training.end(),
178                                          test.begin());
179            utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
180          }
181        }
182      }
183      else {
184        std::string str;
185        str = "Error in KNN::calculate_distances: Only support when training and test data both are either MatrixLookup or MatrixLookupWeighted";
186        throw std::runtime_error(str);
187      }
188    }
189    return distances;
190  }
191 
192  template <typename Distance>
193  const DataLookup2D& KNN<Distance>::data(void) const
194  {
195    return data_;
196  }
197 
198 
199  template <typename Distance>
200  u_int KNN<Distance>::k() const
201  {
202    return k_;
203  }
204
205  template <typename Distance>
206  void KNN<Distance>::k(u_int k)
207  {
208    k_=k;
209  }
210
211
212  template <typename Distance>
213  SupervisedClassifier* 
214  KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
215  {     
216    KNN* knn=0;
217    try {
218      if(data.weighted()) {
219        knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
220                              target);
221      } 
222      else {
223        knn=new KNN<Distance>(dynamic_cast<const MatrixLookup&>(data),
224                              target);
225      }
226      knn->k(this->k());
227    }
228    catch (std::bad_cast) {
229      std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class.";
230      throw std::runtime_error(str);
231    }
232    return knn;
233  }
234 
235 
236  template <typename Distance>
237  void KNN<Distance>::train()
238  {   
239    trained_=true;
240  }
241
242
243  template <typename Distance>
244  void KNN<Distance>::predict(const DataLookup2D& test,                     
245                              utility::matrix& prediction) const
246  {   
247    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
248
249    utility::matrix* distances=calculate_distances(test);
250   
251    // for each test sample (column in distances) find the closest
252    // training samples
253    prediction.clone(utility::matrix(target_.nof_classes(),test.columns(),0.0));
254    for(size_t sample=0;sample<distances->columns();sample++) {
255      std::vector<size_t> k_index;
256      utility::sort_smallest_index(k_index,k_,
257                                   distances->column_const_view(sample));
258      for(size_t j=0;j<k_index.size();j++) {
259        prediction(target_(k_index[j]),sample)++;
260      }
261    }
262    prediction*=(1.0/k_);
263    delete distances;
264  }
265
266}}} // of namespace classifier, yat, and theplu
267
268#endif
269
Note: See TracBrowser for help on using the repository browser.