source: trunk/yat/classifier/KNN.h @ 3553

Last change on this file since 3553 was 3553, checked in by Peter, 6 years ago

reactivate concept assert

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 13.6 KB
Line 
1#ifndef _theplu_yat_classifier_knn_
2#define _theplu_yat_classifier_knn_
3
4// $Id: KNN.h 3553 2017-01-03 07:56:01Z peter $
5
6/*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8  Copyright (C) 2009, 2010 Peter Johansson
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24*/
25
26#include "DataLookup1D.h"
27#include "DataLookupWeighted1D.h"
28#include "KNN_Uniform.h"
29#include "MatrixLookup.h"
30#include "MatrixLookupWeighted.h"
31#include "SupervisedClassifier.h"
32#include "Target.h"
33#include "yat/utility/concept_check.h"
34#include "yat/utility/Exception.h"
35#include "yat/utility/Matrix.h"
36#include "yat/utility/Vector.h"
37#include "yat/utility/VectorConstView.h"
38#include "yat/utility/VectorView.h"
39#include "yat/utility/yat_assert.h"
40
41#include <boost/concept_check.hpp>
42
43#include <cmath>
44#include <limits>
45#include <map>
46#include <stdexcept>
47#include <vector>
48
49namespace theplu {
50namespace yat {
51namespace classifier {
52
  /**
     \brief Nearest Neighbor Classifier

     A sample is predicted based on the classes of its k nearest
     neighbors among the training data samples. KNN supports using
     different measures, for example, Euclidean distance, to define
     distance between samples. KNN also supports using different ways to
     weight the votes of the k nearest neighbors. For example, using a
     uniform vote a test sample gets a vote for each class which is the
     number of nearest neighbors belonging to the class.

     The template argument Distance should be a class modelling the
     concept \ref concept_distance. The template argument
     NeighborWeighting should be a class modelling the concept \ref
     concept_neighbor_weighting.
  */
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {

  public:
    /**
       \brief Default constructor.

       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
       constructors.
    */
    KNN(void);


    /**
       \brief Constructor using an initialized distance measure.

       The number of nearest neighbors (k) is set to
       3. NeighborWeighting is initialized using its default
       constructor. This constructor should be used if Distance has
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */
    KNN(const Distance&);


    /**
       Destructor
    */
    virtual ~KNN();


    /**
       \brief Get the number of nearest neighbors.
       \return The number of neighbors.
    */
    unsigned int k() const;

    /**
       \brief Set the number of nearest neighbors.

       Sets the number of neighbors to \a k_in.
    */
    void k(unsigned int k_in);


    /**
       \brief Create an untrained copy of this classifier.

       \return A dynamically allocated KNN with the same distance
       measure, neighbor weighting, and k as this object, but without
       training data. The caller takes ownership of the returned
       object.
    */
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;

    /**
       \brief Make predictions for unweighted test data.

       Predictions are calculated and returned in \a results.  For
       each sample in \a data, \a results contains the weighted number
       of nearest neighbors which belong to each class. Numbers of
       nearest neighbors are weighted according to
       NeighborWeighting. If a class has no training samples NaN's are
       returned for this class in \a results.
    */
    void predict(const MatrixLookup& data , utility::Matrix& results) const;

    /**
        \brief Make predictions for weighted test data.

        Predictions are calculated and returned in \a results. For
        each sample in \a data, \a results contains the weighted
        number of nearest neighbors which belong to each class as in
        predict(const MatrixLookup& data, utility::Matrix& results).
        If a test and training sample pair has no variables with
        non-zero weights in common, there are no variables which can
        be used to calculate the distance between the two samples. In
        this case the distance between the two is set to infinity.
    */
    void predict(const MatrixLookupWeighted& data,
                 utility::Matrix& results) const;


    /**
       \brief Train the KNN using unweighted training data with known
       targets.

       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
       and \a targets as copying these would make the %classifier
       slow. If the number of training samples set is smaller than k,
       k is set to the number of training samples.

       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
       unless it is trained again.
    */
    void train(const MatrixLookup& data, const Target& targets);

    /**
       \brief Train the KNN using weighted training data with known targets.

       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);

  private:

    // Non-owning pointers to the training data set by train(); at
    // most one of data_ml_ and data_mlw_ is non-null at a time.
    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    Distance distance_;
    NeighborWeighting weighting_;

    // Fill distances(i,j) with the distance from training column i to
    // test column j, for unweighted data.
    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    // Same as calculate_unweighted but for weighted data; pairs with
    // no common non-zero weights get distance infinity.
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    // Translate a matrix of training-vs-test distances into per-class
    // votes for each test sample (column of distances).
    void predict_common(const utility::Matrix& distances,
                        utility::Matrix& prediction) const;

  };
193
194
  /**
     \brief Concept check for a \ref concept_neighbor_weighting

     This class is intended to be used in a <a
     href="\boost_url/concept_check/using_concept_check.htm">
     BOOST_CONCEPT_ASSERT </a>

     \code
     template<class NeighborWeighting>
     void some_function(double x)
     {
     BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
     ...
     }
     \endcode

     \since New in yat 0.7
  */
  template <class T>
  class NeighborWeightingConcept
    : public boost::DefaultConstructible<T>, public boost::Assignable<T>
  {
  public:
    /**
       \brief function doing the concept test
     */
    BOOST_CONCEPT_USAGE(NeighborWeightingConcept)
    {
      // Exercise the expression a neighbor weighting must support:
      // operator()(distance, k_sorted, target, prediction)
      T neighbor_weighting;
      utility::Vector vec;
      const utility::VectorBase& distance(vec);
      utility::VectorMutable& prediction(vec);
      std::vector<size_t> k_sorted;
      Target target;
      neighbor_weighting(distance, k_sorted, target, prediction);
    }
  private:
  };
233
234  // template implementation
235
236  template <typename Distance, typename NeighborWeighting>
237  KNN<Distance, NeighborWeighting>::KNN()
238    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
239  {
240    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
241    BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
242  }
243
244  template <typename Distance, typename NeighborWeighting>
245  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
246    : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3),
247      distance_(dist)
248  {
249    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
250    BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
251  }
252
253
  // Destructor; KNN only stores non-owning pointers to the training
  // data, so there is nothing to release.
  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::~KNN()
  {
  }
258
259
260  template <typename Distance, typename NeighborWeighting>
261  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
262  (const MatrixLookup& training, const MatrixLookup& test,
263   utility::Matrix* distances) const
264  {
265    for(size_t i=0; i<training.columns(); i++) {
266      for(size_t j=0; j<test.columns(); j++) {
267        (*distances)(i,j) = distance_(training.begin_column(i),
268                                      training.end_column(i),
269                                      test.begin_column(j));
270        YAT_ASSERT(!std::isnan((*distances)(i,j)));
271      }
272    }
273  }
274
275
276  template <typename Distance, typename NeighborWeighting>
277  void
278  KNN<Distance, NeighborWeighting>::calculate_weighted
279  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
280   utility::Matrix* distances) const
281  {
282    for(size_t i=0; i<training.columns(); i++) {
283      for(size_t j=0; j<test.columns(); j++) {
284        (*distances)(i,j) = distance_(training.begin_column(i),
285                                      training.end_column(i),
286                                      test.begin_column(j));
287        // If the distance is NaN (no common variables with non-zero weights),
288        // the distance is set to infinity to be sorted as a neighbor at the end
289        if(std::isnan((*distances)(i,j)))
290          (*distances)(i,j)=std::numeric_limits<double>::infinity();
291      }
292    }
293  }
294
295
  // Accessor: the number of nearest neighbors used in predictions.
  template <typename Distance, typename NeighborWeighting>
  unsigned int KNN<Distance, NeighborWeighting>::k() const
  {
    return k_;
  }
301
302  template <typename Distance, typename NeighborWeighting>
303  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
304  {
305    k_=k;
306  }
307
308
309  template <typename Distance, typename NeighborWeighting>
310  KNN<Distance, NeighborWeighting>*
311  KNN<Distance, NeighborWeighting>::make_classifier() const
312  {
313    // All private members should be copied here to generate an
314    // identical but untrained classifier
315    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
316    knn->weighting_=this->weighting_;
317    knn->k(this->k());
318    return knn;
319  }
320
321
322  template <typename Distance, typename NeighborWeighting>
323  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
324                                               const Target& target)
325  {
326    utility::yat_assert<utility::runtime_error>
327      (data.columns()==target.size(),
328       "KNN::train called with different sizes of target and data");
329    // k has to be at most the number of training samples.
330    if(data.columns()<k_)
331      k_=data.columns();
332    data_ml_=&data;
333    data_mlw_=0;
334    target_=&target;
335  }
336
337  template <typename Distance, typename NeighborWeighting>
338  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
339                                               const Target& target)
340  {
341    utility::yat_assert<utility::runtime_error>
342      (data.columns()==target.size(),
343       "KNN::train called with different sizes of target and data");
344    // k has to be at most the number of training samples.
345    if(data.columns()<k_)
346      k_=data.columns();
347    data_ml_=0;
348    data_mlw_=&data;
349    target_=&target;
350  }
351
352
353  template <typename Distance, typename NeighborWeighting>
354  void
355  KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
356                                            utility::Matrix& prediction) const
357  {
358    // matrix with training samples as rows and test samples as columns
359    utility::Matrix* distances = 0;
360    // unweighted training data
361    if(data_ml_ && !data_mlw_) {
362      utility::yat_assert<utility::runtime_error>
363        (data_ml_->rows()==test.rows(),
364         "KNN::predict different number of rows in training and test data");
365      distances=new utility::Matrix(data_ml_->columns(),test.columns());
366      calculate_unweighted(*data_ml_,test,distances);
367    }
368    else if (data_mlw_ && !data_ml_) {
369      // weighted training data
370      utility::yat_assert<utility::runtime_error>
371        (data_mlw_->rows()==test.rows(),
372         "KNN::predict different number of rows in training and test data");
373      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
374      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
375                         distances);
376    }
377    else {
378      throw utility::runtime_error("KNN::predict no training data");
379    }
380
381    prediction.resize(target_->nof_classes(),test.columns(),0.0);
382    predict_common(*distances,prediction);
383    if(distances)
384      delete distances;
385  }
386
387  template <typename Distance, typename NeighborWeighting>
388  void
389  KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
390                                            utility::Matrix& prediction) const
391  {
392    // matrix with training samples as rows and test samples as columns
393    utility::Matrix* distances=0;
394    // unweighted training data
395    if(data_ml_ && !data_mlw_) {
396      utility::yat_assert<utility::runtime_error>
397        (data_ml_->rows()==test.rows(),
398         "KNN::predict different number of rows in training and test data");
399      distances=new utility::Matrix(data_ml_->columns(),test.columns());
400      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
401    }
402    // weighted training data
403    else if (data_mlw_ && !data_ml_) {
404      utility::yat_assert<utility::runtime_error>
405        (data_mlw_->rows()==test.rows(),
406         "KNN::predict different number of rows in training and test data");
407      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
408      calculate_weighted(*data_mlw_,test,distances);
409    }
410    else {
411      throw utility::runtime_error("KNN::predict no training data");
412    }
413
414    prediction.resize(target_->nof_classes(),test.columns(),0.0);
415    predict_common(*distances,prediction);
416
417    if(distances)
418      delete distances;
419  }
420
421  template <typename Distance, typename NeighborWeighting>
422  void KNN<Distance, NeighborWeighting>::predict_common
423  (const utility::Matrix& distances, utility::Matrix& prediction) const
424  {
425    for(size_t sample=0;sample<distances.columns();sample++) {
426      std::vector<size_t> k_index;
427      utility::VectorConstView dist=distances.column_const_view(sample);
428      utility::sort_smallest_index(k_index,k_,dist);
429      utility::VectorView pred=prediction.column_view(sample);
430      weighting_(dist,k_index,*target_,pred);
431    }
432
433    // classes for which there are no training samples should be set
434    // to nan in the predictions
435    for(size_t c=0;c<target_->nof_classes(); c++)
436      if(!target_->size(c))
437        for(size_t j=0;j<prediction.columns();j++)
438          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
439  }
440}}} // of namespace classifier, yat, and theplu
441
442#endif
Note: See TracBrowser for help on using the repository browser.