source: trunk/yat/classifier/KNN.h @ 916

Last change on this file since 916 was 916, checked in by Peter, 14 years ago

Sorry, this commit is a bit too big.

Adding yat_assert. The yat asserts are turned on by passing a
'-DYAT_DEBUG' flag to the preprocessor, provided that normal cassert is
turned on. This flag is activated for developers running configure with
--enable-debug. The motivation is that we can use these yat_asserts in
header files, and the yat_asserts will be invisible to the normal user
even if they use C asserts.

Added output operator in DataLookup2D and removed output operator in
MatrixLookup.

Removed template function add_values in Averager and its weighted version.

Added function to AveragerWeighted taking iterators to four ranges.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 5.0 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 916 2007-09-30 00:50:10Z peter $
5
#include "DataLookupWeighted1D.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/statistics/vector_distance.h"
#include "yat/utility/matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
16
17namespace theplu {
18namespace yat {
19namespace classifier {
20
21  ///
22  /// @brief Class for Nearest Centroid Classification.
23  ///
24 
25 
26  template <typename Distance>
27  class KNN : public SupervisedClassifier
28  {
29   
30  public:
31    ///
32    /// Constructor taking the training data with weights, the target
33    /// vector, the distance measure, and a weight matrix for the
34    /// training data as input.
35    ///
36    KNN(const MatrixLookupWeighted&, const Target&);
37
38    virtual ~KNN();
39   
40    //
41    // @return the training data
42    //
43    const DataLookup2D& data(void) const;
44
45
46    ///
47    /// Default number of neighbours (k) is set to 3.
48    ///
49    /// @return the number of neighbours
50    ///
51    u_int k() const;
52
53    ///
54    /// @brief sets the number of neighbours, k.
55    ///
56    void k(u_int);
57
58
59    SupervisedClassifier* make_classifier(const DataLookup2D&, 
60                                          const Target&) const;
61   
62    ///
63    /// Train the classifier using the training data. Centroids are
64    /// calculated for each class.
65    ///
66    /// @return true if training succedeed.
67    ///
68    bool train();
69
70   
71    ///
72    /// Calculate the distance to each centroid for test samples
73    ///
74    void predict(const DataLookup2D&, utility::matrix&) const;
75
76
77  private:
78
79    // data_ has to be of type DataLookup2D to accomodate both
80    // MatrixLookup and MatrixLookupWeighted
81    const DataLookup2D& data_;
82
83    // The number of neighbours
84    u_int k_;
85
86    ///
87    /// Calculates the distances between a data set and the training
88    /// data. The rows are training and the columns test samples,
89    /// respectively. The returned distance matrix is dynamically
90    /// generated and needs to be deleted by the caller.
91    ///
92    utility::matrix* calculate_distances(const DataLookup2D&) const;
93  };
94
95  ///
96  /// The output operator for the KNN class.
97  ///
98  //  std::ostream& operator<< (std::ostream&, const KNN&);
99 
100 
101  // templates
102 
103  template <typename Distance>
104  KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) 
105    : SupervisedClassifier(target), data_(data),k_(3)
106  {
107  }
108 
109  template <typename Distance>
110  KNN<Distance>::~KNN()   
111  {
112  }
113 
114  template <typename Distance>
115  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& input) const
116  {
117    const MatrixLookupWeighted* weighted_data = 
118      dynamic_cast<const MatrixLookupWeighted*>(&data_);
119    const MatrixLookupWeighted* weighted_input = 
120      dynamic_cast<const MatrixLookupWeighted*>(&input); 
121
122    // matrix with training samples as rows and test samples as columns
123    utility::matrix* distances = 
124      new utility::matrix(data_.columns(),input.columns());
125     
126    if(weighted_data && weighted_input) {
127      for(size_t i=0; i<data_.columns(); i++) {
128        classifier::DataLookupWeighted1D training(*weighted_data,i,false);
129        for(size_t j=0; j<input.columns(); j++) {
130          classifier::DataLookupWeighted1D test(*weighted_input,j,false);
131          yat_assert(training.size()==test.size());
132          (*distances)(i,j)=statistics::vector_distance(training.begin(),training.end(),test.begin(),typename statistics::vector_distance_traits<Distance>::distance());
133          yat_assert(!std::isnan((*distances)(i,j)));
134        }
135      }
136    }
137    else {
138      std::string str;
139      str = "Error in KNN::calculate_distances: Only MatrixLookupWeighted supported still.";
140      throw std::runtime_error(str);
141    }
142    return distances;
143  }
144 
145  template <typename Distance>
146  const DataLookup2D& KNN<Distance>::data(void) const
147  {
148    return data_;
149  }
150 
151 
152  template <typename Distance>
153  u_int KNN<Distance>::k() const
154  {
155    return k_;
156  }
157
158  template <typename Distance>
159  void KNN<Distance>::k(u_int k)
160  {
161    k_=k;
162  }
163
164
165  template <typename Distance>
166  SupervisedClassifier* 
167  KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
168  {     
169    KNN* knn=0;
170    if(data.weighted()) {
171      knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
172                            target);
173    }
174    knn->k(this->k());
175    return knn;
176  }
177 
178 
179  template <typename Distance>
180  bool KNN<Distance>::train()
181  {   
182    trained_=true;
183    return trained_;
184  }
185
186
187  template <typename Distance>
188  void KNN<Distance>::predict(const DataLookup2D& input,                   
189                              utility::matrix& prediction) const
190  {   
191    utility::matrix* distances=calculate_distances(input);
192   
193    // for each test sample (column in distances) find the closest
194    // training samples
195    prediction.clone(utility::matrix(target_.nof_classes(),input.columns(),0.0));
196    for(size_t sample=0;sample<distances->columns();sample++) {
197      std::vector<size_t> k_index;
198      utility::sort_smallest_index(k_index,k_,utility::vector(*distances,sample,false));
199      for(size_t j=0;j<k_index.size();j++) {
200        prediction(target_(k_index[j]),sample)++;
201      }
202    }
203    delete distances;
204  }
205
206}}} // of namespace classifier, yat, and theplu
207
208#endif
209
Note: See TracBrowser for help on using the repository browser.