source: trunk/yat/classifier/KNN.h @ 902

Last change on this file since 902 was 902, checked in by Markus Ringnér, 14 years ago

Fixing mistake in last revision (forgot to add KNN.h)

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.9 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 902 2007-09-27 09:01:16Z markus $
5
#include "DataLookupWeighted1D.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/statistics/vector_distance.h"
#include "yat/utility/matrix.h"

#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
14
15namespace theplu {
16namespace yat {
17namespace classifier {
18
19  ///
20  /// @brief Class for Nearest Centroid Classification.
21  ///
22 
23 
24  template <typename Distance>
25  class KNN : public SupervisedClassifier
26  {
27   
28  public:
29    ///
30    /// Constructor taking the training data with weights, the target
31    /// vector, the distance measure, and a weight matrix for the
32    /// training data as input.
33    ///
34    KNN(const MatrixLookupWeighted&, const Target&);
35
36    virtual ~KNN();
37   
38    //
39    // @return the training data
40    //
41    const DataLookup2D& data(void) const;
42
43
44    ///
45    /// Default number of neighbours (k) is set to 3.
46    ///
47    /// @return the number of neighbours
48    ///
49    u_int k() const;
50
51    ///
52    /// @brief sets the number of neighbours, k.
53    ///
54    void k(u_int);
55
56
57    SupervisedClassifier* make_classifier(const DataLookup2D&, 
58                                          const Target&) const;
59   
60    ///
61    /// Train the classifier using the training data. Centroids are
62    /// calculated for each class.
63    ///
64    /// @return true if training succedeed.
65    ///
66    bool train();
67
68   
69    ///
70    /// Calculate the distance to each centroid for test samples
71    ///
72    void predict(const DataLookup2D&, utility::matrix&) const;
73
74
75  private:
76
77    // data_ has to be of type DataLookup2D to accomodate both
78    // MatrixLookup and MatrixLookupWeighted
79    const DataLookup2D& data_;
80
81    // The number of neighbours
82    u_int k_;
83
84    ///
85    /// Calculates the distances between a data set and the training
86    /// data. The rows are training and the columns test samples,
87    /// respectively. The returned distance matrix is dynamically
88    /// generated and needs to be deleted by the caller.
89    ///
90    utility::matrix* calculate_distances(const DataLookup2D&) const;
91  };
92
93  ///
94  /// The output operator for the KNN class.
95  ///
96  //  std::ostream& operator<< (std::ostream&, const KNN&);
97 
98 
99  // templates
100 
101  template <typename Distance>
102  KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) 
103    : SupervisedClassifier(target), data_(data),k_(3)
104  {
105  }
106 
107  template <typename Distance>
108  KNN<Distance>::~KNN()   
109  {
110  }
111 
112  template <typename Distance>
113  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& input) const
114  {
115    assert(input.rows()==data_.rows());
116   
117    const MatrixLookupWeighted* weighted_data = 
118      dynamic_cast<const MatrixLookupWeighted*>(&data_);
119    const MatrixLookupWeighted* weighted_input = 
120      dynamic_cast<const MatrixLookupWeighted*>(&input); 
121
122    // matrix with training samples as rows and test samples as columns
123    utility::matrix* distances = new utility::matrix(data_.columns(),input.columns());
124     
125    if(weighted_data && weighted_input) {
126      for(size_t i=0; i<data_.columns(); i++) {
127        classifier::DataLookupWeighted1D training(*weighted_data,i,false);
128        for(size_t j=0; j<input.columns(); j++) {
129          classifier::DataLookupWeighted1D test(*weighted_input,j,false);
130          (*distances)(i,j)=statistics::vector_distance(training.begin(),training.end(),test.begin(),typename statistics::vector_distance_traits<Distance>::distance());
131        }
132      }
133    }
134    else {
135      std::string str;
136      str = "Error in KNN::calculate_distances: Only MatrixLookupWeighted supported still.";
137      throw std::runtime_error(str);
138    }
139    return distances;
140  }
141 
142  template <typename Distance>
143  const DataLookup2D& KNN<Distance>::data(void) const
144  {
145    return data_;
146  }
147 
148 
149  template <typename Distance>
150  u_int KNN<Distance>::k() const
151  {
152    return k_;
153  }
154
155  template <typename Distance>
156  void KNN<Distance>::k(u_int k)
157  {
158    k_=k;
159  }
160
161
162  template <typename Distance>
163  SupervisedClassifier* 
164  KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
165  {     
166    KNN* knn=0;
167    if(data.weighted()) {
168      knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
169                            target);
170    }
171    knn->k(this->k());
172    return knn;
173  }
174 
175 
176  template <typename Distance>
177  bool KNN<Distance>::train()
178  {   
179    trained_=true;
180    return trained_;
181  }
182
183
184  template <typename Distance>
185  void KNN<Distance>::predict(const DataLookup2D& input,                   
186                              utility::matrix& prediction) const
187  {   
188    assert(input.rows()==data_.rows());
189   
190    utility::matrix* distances=calculate_distances(input);
191   
192    // for each test sample (column in distances) find the closest training samples
193    prediction.clone(utility::matrix(target_.nof_classes(), input.columns(),0.0));
194    for(size_t sample=0;sample<distances->columns();sample++) {
195      std::vector<size_t> k_index;
196      utility::sort_smallest_index(k_index,k_,utility::vector(*distances,sample,false));
197      for(size_t j=0;j<k_index.size();j++) {
198        prediction(target_(k_index[j]),sample)++;
199      }
200    }
201    delete distances;
202  }
203
204}}} // of namespace classifier, yat, and theplu
205
206#endif
207
Note: See TracBrowser for help on using the repository browser.