source: trunk/yat/classifier/KNN.h @ 948

Last change on this file since 948 was 948, checked in by Markus Ringnér, 14 years ago

Adding support and checks for intended lookups in classifiers

#ifndef _theplu_yat_classifier_knn_
#define _theplu_yat_classifier_knn_

// $Id: KNN.h 948 2007-10-08 14:06:53Z markus $

#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/statistics/vector_distance.h"
#include "yat/utility/matrix.h"
#include "yat/utility/vector.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <map>
#include <stdexcept>
#include <vector>

namespace theplu {
namespace yat {
namespace classifier {

  ///
  /// @brief Class for k-nearest-neighbour (KNN) classification.
  ///
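  /// A minimal usage sketch (illustrative only; how the lookups and the
  /// target are constructed, and the Distance type used, are assumptions
  /// rather than code taken from the library):
  ///
  /// \code
  /// MatrixLookup training(...);               // training samples as columns
  /// Target target(...);                       // class label for each column
  /// KNN<SomeDistance> knn(training, target);  // SomeDistance is hypothetical
  /// knn.k(5);                                 // override the default of 3 neighbours
  /// knn.train();
  /// utility::matrix result;
  /// knn.predict(test_data, result);           // rows: classes, columns: test samples
  /// \endcode
  ///
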
  template <typename Distance>
  class KNN : public SupervisedClassifier
  {

  public:
    ///
    /// Constructor taking the training data and the target
    /// as input.
    ///
    KNN(const MatrixLookup&, const Target&);


    ///
    /// Constructor taking the training data with weights and the
    /// target as input.
    ///
    KNN(const MatrixLookupWeighted&, const Target&);

    virtual ~KNN();

    ///
    /// @return the training data
    ///
    const DataLookup2D& data(void) const;


    ///
    /// Default number of neighbours (k) is set to 3.
    ///
    /// @return the number of neighbours
    ///
    u_int k() const;

    ///
    /// @brief sets the number of neighbours, k.
    ///
    void k(u_int);


    ///
    /// @brief Create a new KNN classifier of the same kind for the given
    /// data and target, using the same number of neighbours as this one.
    ///
    SupervisedClassifier* make_classifier(const DataLookup2D&,
                                          const Target&) const;

    ///
    /// @brief Train the classifier.
    ///
    /// KNN is a lazy learner: no model is built from the training data here,
    /// the call only marks the classifier as trained.
    ///
    /// @return true if training succeeded.
    ///
    bool train();


    ///
    /// For each column (test sample) of the input, compute the distances to
    /// all training samples, find the k nearest ones and let them vote on the
    /// class. The prediction matrix gets one row per class and one column per
    /// test sample; element (c, s) is the fraction of the k nearest neighbours
    /// of test sample s that belong to class c.
    ///
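    ///
    /// Worked example (illustrative values, not library output): with k = 3,
    /// if the three nearest training samples of a test sample have classes
    /// {0, 0, 1}, that sample's column becomes 2/3 for class 0, 1/3 for
    /// class 1, and 0 for all other classes.
    ///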
    void predict(const DataLookup2D&, utility::matrix&) const;

  private:

    // data_ has to be of type DataLookup2D to accommodate both
    // MatrixLookup and MatrixLookupWeighted
    const DataLookup2D& data_;

    // The number of neighbours
    u_int k_;

    ///
    /// Calculates the distances between a data set and the training
    /// data. The rows are training and the columns test samples,
    /// respectively. The returned distance matrix is dynamically
    /// generated and needs to be deleted by the caller.
    ///
    utility::matrix* calculate_distances(const DataLookup2D&) const;
  };


  // templates

  template <typename Distance>
  KNN<Distance>::KNN(const MatrixLookup& data, const Target& target)
    : SupervisedClassifier(target), data_(data), k_(3)
  {
  }


  template <typename Distance>
  KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target)
    : SupervisedClassifier(target), data_(data), k_(3)
  {
  }

  template <typename Distance>
  KNN<Distance>::~KNN()
  {
  }

  template <typename Distance>
  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& input) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::matrix* distances =
      new utility::matrix(data_.columns(),input.columns());

    // if both training and test are unweighted: unweighted
    // calculations are used.
    const MatrixLookup* test_unweighted =
      dynamic_cast<const MatrixLookup*>(&input);
    if(test_unweighted && !data_.weighted()) {
      const MatrixLookup* data_unweighted =
        dynamic_cast<const MatrixLookup*>(&data_);
      for(size_t i=0; i<data_.columns(); i++) {
        classifier::DataLookup1D training(*data_unweighted,i,false);
        for(size_t j=0; j<input.columns(); j++) {
          classifier::DataLookup1D test(*test_unweighted,j,false);
          utility::yat_assert<std::runtime_error>(training.size()==test.size());
          (*distances)(i,j) =
            statistics::vector_distance(training.begin(), training.end(), test.begin(),
                                        typename statistics::vector_distance_traits<Distance>::distance());
          utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
        }
      }
    }
    // if either training or test is weighted: weighted calculations
    // are used.
    else {
      const MatrixLookupWeighted* data_weighted =
        dynamic_cast<const MatrixLookupWeighted*>(&data_);
      const MatrixLookupWeighted* test_weighted =
        dynamic_cast<const MatrixLookupWeighted*>(&input);
      if(data_weighted && test_weighted) {
        for(size_t i=0; i<data_.columns(); i++) {
          classifier::DataLookupWeighted1D training(*data_weighted,i,false);
          for(size_t j=0; j<input.columns(); j++) {
            classifier::DataLookupWeighted1D test(*test_weighted,j,false);
            utility::yat_assert<std::runtime_error>(training.size()==test.size());
            (*distances)(i,j) =
              statistics::vector_distance(training.begin(), training.end(), test.begin(),
                                          typename statistics::vector_distance_traits<Distance>::distance());
            utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
          }
        }
      }
      else {
        std::string str;
        str = "Error in KNN::calculate_distances: Only support when training and test data both are either MatrixLookup or MatrixLookupWeighted";
        throw std::runtime_error(str);
      }
    }
    return distances;
  }

  template <typename Distance>
  const DataLookup2D& KNN<Distance>::data(void) const
  {
    return data_;
  }


  template <typename Distance>
  u_int KNN<Distance>::k() const
  {
    return k_;
  }

  template <typename Distance>
  void KNN<Distance>::k(u_int k)
  {
    k_=k;
  }


  template <typename Distance>
  SupervisedClassifier*
  KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const
  {
    KNN* knn=0;
    try {
      if(data.weighted()) {
        knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
                              target);
      }
      else {
        knn=new KNN<Distance>(dynamic_cast<const MatrixLookup&>(data),
                              target);
      }
      knn->k(this->k());
    }
    catch (const std::bad_cast&) {
      std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class.";
      throw std::runtime_error(str);
    }
    return knn;
  }


  template <typename Distance>
  bool KNN<Distance>::train()
  {
    // KNN is a lazy learner: nothing is precomputed from the training data,
    // the classifier is only flagged as trained.
    trained_=true;
    return trained_;
  }


  template <typename Distance>
  void KNN<Distance>::predict(const DataLookup2D& input,
                              utility::matrix& prediction) const
  {
    utility::matrix* distances=calculate_distances(input);

    // for each test sample (column in distances) find the k closest
    // training samples and count their class votes
    prediction.clone(utility::matrix(target_.nof_classes(),input.columns(),0.0));
    for(size_t sample=0;sample<distances->columns();sample++) {
      std::vector<size_t> k_index;
      utility::sort_smallest_index(k_index,k_,utility::vector(*distances,sample,false));
      for(size_t j=0;j<k_index.size();j++) {
        prediction(target_(k_index[j]),sample)++;
      }
    }
    // convert the vote counts into fractions of k
    prediction*=(1.0/k_);
    delete distances;
  }

}}} // of namespace classifier, yat, and theplu

#endif