source: trunk/yat/classifier/KNN.h @ 936

Last change on this file since 936 was 936, checked in by Peter, 14 years ago

reimplementing yat_assert as a throwing function

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 5.2 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 936 2007-10-05 23:02:08Z peter $
5
6#include "DataLookupWeighted1D.h"
7#include "MatrixLookupWeighted.h"
8#include "SupervisedClassifier.h"
9#include "Target.h"
10#include "yat/statistics/vector_distance.h"
11#include "yat/utility/matrix.h"
12#include "yat/utility/yat_assert.h"
13
14#include <cmath>
15#include <map>
16#include <stdexcept>
17
18namespace theplu {
19namespace yat {
20namespace classifier {
21
22  ///
23  /// @brief Class for Nearest Centroid Classification.
24  ///
25 
26 
27  template <typename Distance>
28  class KNN : public SupervisedClassifier
29  {
30   
31  public:
32    ///
33    /// Constructor taking the training data with weights, the target
34    /// vector, the distance measure, and a weight matrix for the
35    /// training data as input.
36    ///
37    KNN(const MatrixLookupWeighted&, const Target&);
38
39    virtual ~KNN();
40   
41    //
42    // @return the training data
43    //
44    const DataLookup2D& data(void) const;
45
46
47    ///
48    /// Default number of neighbours (k) is set to 3.
49    ///
50    /// @return the number of neighbours
51    ///
52    u_int k() const;
53
54    ///
55    /// @brief sets the number of neighbours, k.
56    ///
57    void k(u_int);
58
59
60    SupervisedClassifier* make_classifier(const DataLookup2D&, 
61                                          const Target&) const;
62   
63    ///
64    /// Train the classifier using the training data. Centroids are
65    /// calculated for each class.
66    ///
67    /// @return true if training succedeed.
68    ///
69    bool train();
70
71   
72    ///
73    /// Calculate the distance to each centroid for test samples
74    ///
75    void predict(const DataLookup2D&, utility::matrix&) const;
76
77
78  private:
79
80    // data_ has to be of type DataLookup2D to accomodate both
81    // MatrixLookup and MatrixLookupWeighted
82    const DataLookup2D& data_;
83
84    // The number of neighbours
85    u_int k_;
86
87    ///
88    /// Calculates the distances between a data set and the training
89    /// data. The rows are training and the columns test samples,
90    /// respectively. The returned distance matrix is dynamically
91    /// generated and needs to be deleted by the caller.
92    ///
93    utility::matrix* calculate_distances(const DataLookup2D&) const;
94  };
95
96  ///
97  /// The output operator for the KNN class.
98  ///
99  //  std::ostream& operator<< (std::ostream&, const KNN&);
100 
101 
102  // templates
103 
104  template <typename Distance>
105  KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) 
106    : SupervisedClassifier(target), data_(data),k_(3)
107  {
108  }
109 
110  template <typename Distance>
111  KNN<Distance>::~KNN()   
112  {
113  }
114 
115  template <typename Distance>
116  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& input) const
117  {
118    const MatrixLookupWeighted* weighted_data = 
119      dynamic_cast<const MatrixLookupWeighted*>(&data_);
120    const MatrixLookupWeighted* weighted_input = 
121      dynamic_cast<const MatrixLookupWeighted*>(&input); 
122
123    // matrix with training samples as rows and test samples as columns
124    utility::matrix* distances = 
125      new utility::matrix(data_.columns(),input.columns());
126     
127    if(weighted_data && weighted_input) {
128      for(size_t i=0; i<data_.columns(); i++) {
129        classifier::DataLookupWeighted1D training(*weighted_data,i,false);
130        for(size_t j=0; j<input.columns(); j++) {
131          classifier::DataLookupWeighted1D test(*weighted_input,j,false);
132          utility::yat_assert<std::runtime_error>(training.size()==test.size());
133          (*distances)(i,j) =
134            statistics::vector_distance(training.begin(),training.end(),
135                                        test.begin(), typename statistics::vector_distance_traits<Distance>::distance());
136          utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
137        }
138      }
139    }
140    else {
141      std::string str;
142      str = "Error in KNN::calculate_distances: Only MatrixLookupWeighted supported still.";
143      throw std::runtime_error(str);
144    }
145    return distances;
146  }
147 
148  template <typename Distance>
149  const DataLookup2D& KNN<Distance>::data(void) const
150  {
151    return data_;
152  }
153 
154 
155  template <typename Distance>
156  u_int KNN<Distance>::k() const
157  {
158    return k_;
159  }
160
161  template <typename Distance>
162  void KNN<Distance>::k(u_int k)
163  {
164    k_=k;
165  }
166
167
168  template <typename Distance>
169  SupervisedClassifier* 
170  KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
171  {     
172    KNN* knn=0;
173    try {
174      if(data.weighted()) {
175        knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
176                              target);
177      } 
178      knn->k(this->k());
179    }
180    catch (std::bad_cast) {
181      std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class.";
182      throw std::runtime_error(str);
183    }
184    return knn;
185  }
186 
187 
188  template <typename Distance>
189  bool KNN<Distance>::train()
190  {   
191    trained_=true;
192    return trained_;
193  }
194
195
196  template <typename Distance>
197  void KNN<Distance>::predict(const DataLookup2D& input,                   
198                              utility::matrix& prediction) const
199  {   
200    utility::matrix* distances=calculate_distances(input);
201   
202    // for each test sample (column in distances) find the closest
203    // training samples
204    prediction.clone(utility::matrix(target_.nof_classes(),input.columns(),0.0));
205    for(size_t sample=0;sample<distances->columns();sample++) {
206      std::vector<size_t> k_index;
207      utility::sort_smallest_index(k_index,k_,utility::vector(*distances,sample,false));
208      for(size_t j=0;j<k_index.size();j++) {
209        prediction(target_(k_index[j]),sample)++;
210      }
211    }
212    delete distances;
213  }
214
215}}} // of namespace classifier, yat, and theplu
216
217#endif
218
Note: See TracBrowser for help on using the repository browser.