source: trunk/yat/classifier/KNN.h @ 1031

Last change on this file since 1031 was 1031, checked in by Markus Ringnér, 14 years ago

Fixes #272

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 7.2 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 1031 2008-02-04 15:44:44Z markus $
5
6/*
7  Copyright (C) 2007 Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "DataLookupWeighted1D.h"
28#include "MatrixLookup.h"
29#include "MatrixLookupWeighted.h"
30#include "SupervisedClassifier.h"
31#include "Target.h"
32#include "yat/statistics/distance.h"
33#include "yat/utility/matrix.h"
34#include "yat/utility/yat_assert.h"
35
36#include <cmath>
37#include <map>
38#include <stdexcept>
39
40namespace theplu {
41namespace yat {
42namespace classifier {
43
44  ///
45  /// @brief Class for Nearest Centroid Classification.
46  ///
47 
48 
49  template <typename Distance>
50  class KNN : public SupervisedClassifier
51  {
52   
53  public:
54    ///
55    /// Constructor taking the training data and the target   
56    /// as input.
57    ///
58    KNN(const MatrixLookup&, const Target&);
59
60
61    ///
62    /// Constructor taking the training data with weights and the
63    /// target as input.
64    ///
65    KNN(const MatrixLookupWeighted&, const Target&);
66
67    virtual ~KNN();
68   
69    //
70    // @return the training data
71    //
72    const DataLookup2D& data(void) const;
73
74
75    ///
76    /// Default number of neighbours (k) is set to 3.
77    ///
78    /// @return the number of neighbours
79    ///
80    u_int k() const;
81
82    ///
83    /// @brief sets the number of neighbours, k.
84    ///
85    void k(u_int);
86
87
88    SupervisedClassifier* make_classifier(const DataLookup2D&, 
89                                          const Target&) const;
90   
91    ///
92    /// Train the classifier using the training data. Centroids are
93    /// calculated for each class.
94    ///
95    /// @return true if training succedeed.
96    ///
97    bool train();
98
99   
100    ///
101    /// Calculate the distance to each centroid for test samples
102    ///
103    void predict(const DataLookup2D&, utility::matrix&) const;
104
105
106  private:
107
108    // data_ has to be of type DataLookup2D to accomodate both
109    // MatrixLookup and MatrixLookupWeighted
110    const DataLookup2D& data_;
111
112    // The number of neighbours
113    u_int k_;
114
115    ///
116    /// Calculates the distances between a data set and the training
117    /// data. The rows are training and the columns test samples,
118    /// respectively. The returned distance matrix is dynamically
119    /// generated and needs to be deleted by the caller.
120    ///
121    utility::matrix* calculate_distances(const DataLookup2D&) const;
122  };
123 
124 
125  // templates
126 
127  template <typename Distance>
128  KNN<Distance>::KNN(const MatrixLookup& data, const Target& target) 
129    : SupervisedClassifier(target), data_(data),k_(3)
130  {
131  }
132
133
134  template <typename Distance>
135  KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) 
136    : SupervisedClassifier(target), data_(data),k_(3)
137  {
138  }
139 
140  template <typename Distance>
141  KNN<Distance>::~KNN()   
142  {
143  }
144 
145  template <typename Distance>
146  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& test) const
147  {
148    // matrix with training samples as rows and test samples as columns
149    utility::matrix* distances = 
150      new utility::matrix(data_.columns(),test.columns());
151   
152    // unweighted test data
153    if(const MatrixLookup* test_unweighted = 
154       dynamic_cast<const MatrixLookup*>(&test)) {     
155      for(size_t i=0; i<data_.columns(); i++) {
156        for(size_t j=0; j<test.columns(); j++) {
157          classifier::DataLookup1D test(*test_unweighted,j,false);
158          (*distances)(i,j) =
159            statistics::distance(classifier::DataLookup1D(data_,
160                                                          i,false).begin(),
161                                 classifier::DataLookup1D(data_,
162                                                          i,false).end(),
163                                 test.begin(), 
164                                 typename statistics::
165                                 distance_traits<Distance>::distance());
166          utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
167        }
168      }
169    }
170    // weighted test data
171    else {
172      const MatrixLookupWeighted* data_weighted = 
173        dynamic_cast<const MatrixLookupWeighted*>(&data_);
174      const MatrixLookupWeighted* test_weighted = 
175        dynamic_cast<const MatrixLookupWeighted*>(&test);               
176      if(data_weighted && test_weighted) {
177        for(size_t i=0; i<data_.columns(); i++) {
178          classifier::DataLookupWeighted1D training(*data_weighted,i,false);
179          for(size_t j=0; j<test.columns(); j++) {
180            classifier::DataLookupWeighted1D test(*test_weighted,j,false);
181            utility::yat_assert<std::runtime_error>(training.size()==test.size());
182            (*distances)(i,j) =
183              statistics::distance(training.begin(),training.end(),
184                                   test.begin(), typename statistics::distance_traits<Distance>::distance());
185            utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
186          }
187        }
188      }
189      else {
190        std::string str;
191        str = "Error in KNN::calculate_distances: Only support when training and test data both are either MatrixLookup or MatrixLookupWeighted";
192        throw std::runtime_error(str);
193      }
194    }
195    return distances;
196  }
197 
198  template <typename Distance>
199  const DataLookup2D& KNN<Distance>::data(void) const
200  {
201    return data_;
202  }
203 
204 
205  template <typename Distance>
206  u_int KNN<Distance>::k() const
207  {
208    return k_;
209  }
210
211  template <typename Distance>
212  void KNN<Distance>::k(u_int k)
213  {
214    k_=k;
215  }
216
217
218  template <typename Distance>
219  SupervisedClassifier* 
220  KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
221  {     
222    KNN* knn=0;
223    try {
224      if(data.weighted()) {
225        knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
226                              target);
227      } 
228      else {
229        knn=new KNN<Distance>(dynamic_cast<const MatrixLookup&>(data),
230                              target);
231      }
232      knn->k(this->k());
233    }
234    catch (std::bad_cast) {
235      std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class.";
236      throw std::runtime_error(str);
237    }
238    return knn;
239  }
240 
241 
242  template <typename Distance>
243  bool KNN<Distance>::train()
244  {   
245    trained_=true;
246    return trained_;
247  }
248
249
250  template <typename Distance>
251  void KNN<Distance>::predict(const DataLookup2D& test,                     
252                              utility::matrix& prediction) const
253  {   
254    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
255
256    utility::matrix* distances=calculate_distances(test);
257   
258    // for each test sample (column in distances) find the closest
259    // training samples
260    prediction.clone(utility::matrix(target_.nof_classes(),test.columns(),0.0));
261    for(size_t sample=0;sample<distances->columns();sample++) {
262      std::vector<size_t> k_index;
263      utility::sort_smallest_index(k_index,k_,
264                                   distances->column_const_view(sample));
265      for(size_t j=0;j<k_index.size();j++) {
266        prediction(target_(k_index[j]),sample)++;
267      }
268    }
269    prediction*=(1.0/k_);
270    delete distances;
271  }
272
273}}} // of namespace classifier, yat, and theplu
274
275#endif
276
Note: See TracBrowser for help on using the repository browser.