source: trunk/yat/classifier/KNN.h @ 1028

Last change on this file since 1028 was 1028, checked in by Peter, 14 years ago

documentation for VectorConstView and renaming of the view functions in matrix

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 7.4 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1028 2008-02-03 01:53:29Z peter $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
#include "DataLookupWeighted1D.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/statistics/vector_distance.h"
#include "yat/utility/matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <map>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>
39
40namespace theplu {
41namespace yat {
42namespace classifier {
43
44  ///
45  /// @brief Class for Nearest Centroid Classification.
46  ///
47 
48 
49  template <typename Distance>
50  class KNN : public SupervisedClassifier
51  {
52   
53  public:
54    ///
55    /// Constructor taking the training data and the target   
56    /// as input.
57    ///
58    KNN(const MatrixLookup&, const Target&);
59
60
61    ///
62    /// Constructor taking the training data with weights and the
63    /// target as input.
64    ///
65    KNN(const MatrixLookupWeighted&, const Target&);
66
67    virtual ~KNN();
68   
69    //
70    // @return the training data
71    //
72    const DataLookup2D& data(void) const;
73
74
75    ///
76    /// Default number of neighbours (k) is set to 3.
77    ///
78    /// @return the number of neighbours
79    ///
80    u_int k() const;
81
82    ///
83    /// @brief sets the number of neighbours, k.
84    ///
85    void k(u_int);
86
87
88    SupervisedClassifier* make_classifier(const DataLookup2D&, 
89                                          const Target&) const;
90   
91    ///
92    /// Train the classifier using the training data. Centroids are
93    /// calculated for each class.
94    ///
95    /// @return true if training succedeed.
96    ///
97    bool train();
98
99   
100    ///
101    /// Calculate the distance to each centroid for test samples
102    ///
103    void predict(const DataLookup2D&, utility::matrix&) const;
104
105
106  private:
107
108    // data_ has to be of type DataLookup2D to accomodate both
109    // MatrixLookup and MatrixLookupWeighted
110    const DataLookup2D& data_;
111
112    // The number of neighbours
113    u_int k_;
114
115    ///
116    /// Calculates the distances between a data set and the training
117    /// data. The rows are training and the columns test samples,
118    /// respectively. The returned distance matrix is dynamically
119    /// generated and needs to be deleted by the caller.
120    ///
121    utility::matrix* calculate_distances(const DataLookup2D&) const;
122  };
123 
124 
125  // templates
126 
127  template <typename Distance>
128  KNN<Distance>::KNN(const MatrixLookup& data, const Target& target) 
129    : SupervisedClassifier(target), data_(data),k_(3)
130  {
131  }
132
133
134  template <typename Distance>
135  KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) 
136    : SupervisedClassifier(target), data_(data),k_(3)
137  {
138  }
139 
140  template <typename Distance>
141  KNN<Distance>::~KNN()   
142  {
143  }
144 
145  template <typename Distance>
146  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& input) const
147  {
148    // matrix with training samples as rows and test samples as columns
149    utility::matrix* distances = 
150      new utility::matrix(data_.columns(),input.columns());
151   
152    // if both training and test are unweighted: unweighted
153    // calculations are used.
154    const MatrixLookup* test_unweighted = 
155      dynamic_cast<const MatrixLookup*>(&input);     
156    if(test_unweighted && !data_.weighted()) {
157      const MatrixLookup* data_unweighted = 
158        dynamic_cast<const MatrixLookup*>(&data_);     
159      for(size_t i=0; i<data_.columns(); i++) {
160        classifier::DataLookup1D training(*data_unweighted,i,false);
161        for(size_t j=0; j<input.columns(); j++) {
162          classifier::DataLookup1D test(*test_unweighted,j,false);
163          utility::yat_assert<std::runtime_error>(training.size()==test.size());
164          (*distances)(i,j) =
165            statistics::vector_distance(training.begin(),training.end(),
166                                        test.begin(), typename statistics::vector_distance_traits<Distance>::distance());
167          utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
168        }
169      }
170    }
171    // if either training or test is weighted: weighted calculations
172    // are used.
173    else {
174      const MatrixLookupWeighted* data_weighted = 
175        dynamic_cast<const MatrixLookupWeighted*>(&data_);
176      const MatrixLookupWeighted* test_weighted = 
177        dynamic_cast<const MatrixLookupWeighted*>(&input);               
178      if(data_weighted && test_weighted) {
179        for(size_t i=0; i<data_.columns(); i++) {
180          classifier::DataLookupWeighted1D training(*data_weighted,i,false);
181          for(size_t j=0; j<input.columns(); j++) {
182            classifier::DataLookupWeighted1D test(*test_weighted,j,false);
183            utility::yat_assert<std::runtime_error>(training.size()==test.size());
184            (*distances)(i,j) =
185              statistics::vector_distance(training.begin(),training.end(),
186                                          test.begin(), typename statistics::vector_distance_traits<Distance>::distance());
187            utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
188          }
189        }
190      }
191      else {
192        std::string str;
193        str = "Error in KNN::calculate_distances: Only support when training and test data both are either MatrixLookup or MatrixLookupWeighted";
194        throw std::runtime_error(str);
195      }
196    }
197    return distances;
198  }
199 
200  template <typename Distance>
201  const DataLookup2D& KNN<Distance>::data(void) const
202  {
203    return data_;
204  }
205 
206 
207  template <typename Distance>
208  u_int KNN<Distance>::k() const
209  {
210    return k_;
211  }
212
213  template <typename Distance>
214  void KNN<Distance>::k(u_int k)
215  {
216    k_=k;
217  }
218
219
220  template <typename Distance>
221  SupervisedClassifier* 
222  KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
223  {     
224    KNN* knn=0;
225    try {
226      if(data.weighted()) {
227        knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
228                              target);
229      } 
230      else {
231        knn=new KNN<Distance>(dynamic_cast<const MatrixLookup&>(data),
232                              target);
233      }
234      knn->k(this->k());
235    }
236    catch (std::bad_cast) {
237      std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class.";
238      throw std::runtime_error(str);
239    }
240    return knn;
241  }
242 
243 
244  template <typename Distance>
245  bool KNN<Distance>::train()
246  {   
247    trained_=true;
248    return trained_;
249  }
250
251
252  template <typename Distance>
253  void KNN<Distance>::predict(const DataLookup2D& input,                   
254                              utility::matrix& prediction) const
255  {   
256    utility::matrix* distances=calculate_distances(input);
257   
258    // for each test sample (column in distances) find the closest
259    // training samples
260    prediction.clone(utility::matrix(target_.nof_classes(),input.columns(),0.0));
261    for(size_t sample=0;sample<distances->columns();sample++) {
262      std::vector<size_t> k_index;
263      utility::sort_smallest_index(k_index,k_,
264                                   distances->column_const_view(sample));
265      for(size_t j=0;j<k_index.size();j++) {
266        prediction(target_(k_index[j]),sample)++;
267      }
268    }
269    prediction*=(1.0/k_);
270    delete distances;
271  }
272
273}}} // of namespace classifier, yat, and theplu
274
275#endif
276
Note: See TracBrowser for help on using the repository browser.