source: trunk/yat/classifier/KNN.h @ 904

Last change on this file since 904 was 904, checked in by Peter, 14 years ago

removing assert from header file

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.8 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 904 2007-09-27 18:30:12Z peter $
5
#include "DataLookupWeighted1D.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/statistics/vector_distance.h"
#include "yat/utility/matrix.h"

#include <map>
#include <stdexcept>
#include <string>
#include <vector>
14
15namespace theplu {
16namespace yat {
17namespace classifier {
18
19  ///
20  /// @brief Class for Nearest Centroid Classification.
21  ///
22 
23 
24  template <typename Distance>
25  class KNN : public SupervisedClassifier
26  {
27   
28  public:
29    ///
30    /// Constructor taking the training data with weights, the target
31    /// vector, the distance measure, and a weight matrix for the
32    /// training data as input.
33    ///
34    KNN(const MatrixLookupWeighted&, const Target&);
35
36    virtual ~KNN();
37   
38    //
39    // @return the training data
40    //
41    const DataLookup2D& data(void) const;
42
43
44    ///
45    /// Default number of neighbours (k) is set to 3.
46    ///
47    /// @return the number of neighbours
48    ///
49    u_int k() const;
50
51    ///
52    /// @brief sets the number of neighbours, k.
53    ///
54    void k(u_int);
55
56
57    SupervisedClassifier* make_classifier(const DataLookup2D&, 
58                                          const Target&) const;
59   
60    ///
61    /// Train the classifier using the training data. Centroids are
62    /// calculated for each class.
63    ///
64    /// @return true if training succedeed.
65    ///
66    bool train();
67
68   
69    ///
70    /// Calculate the distance to each centroid for test samples
71    ///
72    void predict(const DataLookup2D&, utility::matrix&) const;
73
74
75  private:
76
77    // data_ has to be of type DataLookup2D to accomodate both
78    // MatrixLookup and MatrixLookupWeighted
79    const DataLookup2D& data_;
80
81    // The number of neighbours
82    u_int k_;
83
84    ///
85    /// Calculates the distances between a data set and the training
86    /// data. The rows are training and the columns test samples,
87    /// respectively. The returned distance matrix is dynamically
88    /// generated and needs to be deleted by the caller.
89    ///
90    utility::matrix* calculate_distances(const DataLookup2D&) const;
91  };
92
93  ///
94  /// The output operator for the KNN class.
95  ///
96  //  std::ostream& operator<< (std::ostream&, const KNN&);
97 
98 
99  // templates
100 
101  template <typename Distance>
102  KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) 
103    : SupervisedClassifier(target), data_(data),k_(3)
104  {
105  }
106 
107  template <typename Distance>
108  KNN<Distance>::~KNN()   
109  {
110  }
111 
112  template <typename Distance>
113  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& input) const
114  {
115    const MatrixLookupWeighted* weighted_data = 
116      dynamic_cast<const MatrixLookupWeighted*>(&data_);
117    const MatrixLookupWeighted* weighted_input = 
118      dynamic_cast<const MatrixLookupWeighted*>(&input); 
119
120    // matrix with training samples as rows and test samples as columns
121    utility::matrix* distances = new utility::matrix(data_.columns(),input.columns());
122     
123    if(weighted_data && weighted_input) {
124      for(size_t i=0; i<data_.columns(); i++) {
125        classifier::DataLookupWeighted1D training(*weighted_data,i,false);
126        for(size_t j=0; j<input.columns(); j++) {
127          classifier::DataLookupWeighted1D test(*weighted_input,j,false);
128          (*distances)(i,j)=statistics::vector_distance(training.begin(),training.end(),test.begin(),typename statistics::vector_distance_traits<Distance>::distance());
129        }
130      }
131    }
132    else {
133      std::string str;
134      str = "Error in KNN::calculate_distances: Only MatrixLookupWeighted supported still.";
135      throw std::runtime_error(str);
136    }
137    return distances;
138  }
139 
140  template <typename Distance>
141  const DataLookup2D& KNN<Distance>::data(void) const
142  {
143    return data_;
144  }
145 
146 
147  template <typename Distance>
148  u_int KNN<Distance>::k() const
149  {
150    return k_;
151  }
152
153  template <typename Distance>
154  void KNN<Distance>::k(u_int k)
155  {
156    k_=k;
157  }
158
159
160  template <typename Distance>
161  SupervisedClassifier* 
162  KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
163  {     
164    KNN* knn=0;
165    if(data.weighted()) {
166      knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
167                            target);
168    }
169    knn->k(this->k());
170    return knn;
171  }
172 
173 
174  template <typename Distance>
175  bool KNN<Distance>::train()
176  {   
177    trained_=true;
178    return trained_;
179  }
180
181
182  template <typename Distance>
183  void KNN<Distance>::predict(const DataLookup2D& input,                   
184                              utility::matrix& prediction) const
185  {   
186    utility::matrix* distances=calculate_distances(input);
187   
188    // for each test sample (column in distances) find the closest training samples
189    prediction.clone(utility::matrix(target_.nof_classes(), input.columns(),0.0));
190    for(size_t sample=0;sample<distances->columns();sample++) {
191      std::vector<size_t> k_index;
192      utility::sort_smallest_index(k_index,k_,utility::vector(*distances,sample,false));
193      for(size_t j=0;j<k_index.size();j++) {
194        prediction(target_(k_index[j]),sample)++;
195      }
196    }
197    delete distances;
198  }
199
200}}} // of namespace classifier, yat, and theplu
201
202#endif
203
Note: See TracBrowser for help on using the repository browser.