Changeset 1112 for trunk/yat


Ignore:
Timestamp:
Feb 21, 2008, 3:59:30 PM (14 years ago)
Author:
Markus Ringnér
Message:

Mostly related to #295 and #182

Location:
trunk/yat/classifier
Files:
6 added
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/KNN.h

    r1107 r1112  
    2727#include "DataLookup1D.h"
    2828#include "DataLookupWeighted1D.h"
     29#include "KNN_Uniform.h"
    2930#include "MatrixLookup.h"
    3031#include "MatrixLookupWeighted.h"
     
    4344
    4445  ///
    45   /// @brief Class for Nearest Centroid Classification.
    46   ///
    47  
    48  
    49   template <typename Distance>
     46  /// @brief Class for Nearest Neigbor Classification.
     47  ///
     48  /// The template argument Distance should be a class implementing
     49  /// the concept \ref concept_distance.
     50  /// The template argument NeigborWeighting should be a class implementing
     51  /// the concept \ref concept_neighbor_weighting.
     52
     53  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
    5054  class KNN : public SupervisedClassifier
    5155  {
     
    7478
    7579    ///
    76     /// Default number of neighbours (k) is set to 3.
    77     ///
    78     /// @return the number of neighbours
     80    /// Default number of neighbors (k) is set to 3.
     81    ///
     82    /// @return the number of neighbors
    7983    ///
    8084    u_int k() const;
    8185
    8286    ///
    83     /// @brief sets the number of neighbours, k.
     87    /// @brief sets the number of neighbors, k.
    8488    ///
    8589    void k(u_int);
     
    99103   
    100104    ///
    101     /// For each sample, calculate the number of neighbours for each
     105    /// For each sample, calculate the number of neighbors for each
    102106    /// class.
    103107    ///
     
    112116    const DataLookup2D& data_;
    113117
    114     // The number of neighbours
     118    // The number of neighbors
    115119    u_int k_;
    116120
    117121    Distance distance_;
     122
     123    NeighborWeighting weighting_;
     124
    118125    ///
    119126    /// Calculates the distances between a data set and the training
     
    123130    ///
    124131    utility::matrix* calculate_distances(const DataLookup2D&) const;
     132
    125133    void calculate_unweighted(const MatrixLookup&,
    126134                              const MatrixLookup&,
     
    134142  // templates
    135143 
    136   template <typename Distance>
    137   KNN<Distance>::KNN(const MatrixLookup& data, const Target& target)
     144  template <typename Distance, typename NeighborWeighting>
     145  KNN<Distance, NeighborWeighting>::KNN(const MatrixLookup& data, const Target& target)
    138146    : SupervisedClassifier(target), data_(data),k_(3)
    139147  {
     
    141149
    142150
    143   template <typename Distance>
    144   KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target)
     151  template <typename Distance, typename NeighborWeighting>
     152  KNN<Distance, NeighborWeighting>::KNN
     153  (const MatrixLookupWeighted& data, const Target& target)
    145154    : SupervisedClassifier(target), data_(data),k_(3)
    146155  {
    147156  }
    148157 
    149   template <typename Distance>
    150   KNN<Distance>::~KNN()   
    151   {
    152   }
    153  
    154   template <typename Distance>
    155   utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& test) const
     158  template <typename Distance, typename NeighborWeighting>
     159  KNN<Distance, NeighborWeighting>::~KNN()   
     160  {
     161  }
     162 
     163  template <typename Distance, typename NeighborWeighting>
     164  utility::matrix* KNN<Distance, NeighborWeighting>::calculate_distances
     165  (const DataLookup2D& test) const
    156166  {
    157167    // matrix with training samples as rows and test samples as columns
     
    197207  }
    198208
    199   template <typename Distance>
    200   void  KNN<Distance>:: calculate_unweighted(const MatrixLookup& training,
    201                                             const MatrixLookup& test,
    202                                             utility::matrix* distances) const
     209  template <typename Distance, typename NeighborWeighting>
     210  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
     211  (const MatrixLookup& training, const MatrixLookup& test,
     212  utility::matrix* distances) const
    203213  {
    204214    for(size_t i=0; i<training.columns(); i++) {
     
    212222  }
    213223 
    214   template <typename Distance>
    215   void  KNN<Distance>:: calculate_weighted(const MatrixLookupWeighted& training,
    216                                            const MatrixLookupWeighted& test,
    217                                            utility::matrix* distances) const
     224  template <typename Distance, typename NeighborWeighting>
     225  void 
     226  KNN<Distance, NeighborWeighting>::calculate_weighted
     227  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
     228   utility::matrix* distances) const
    218229  {
    219230    for(size_t i=0; i<training.columns(); i++) {
     
    221232      for(size_t j=0; j<test.columns(); j++) {
    222233        classifier::DataLookupWeighted1D test1(test,j,false);
    223         (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
     234        (*distances)(i,j) = distance_(training1.begin(), training1.end(),
     235                                      test1.begin());
    224236        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    225237      }
     
    228240
    229241 
    230   template <typename Distance>
    231   const DataLookup2D& KNN<Distance>::data(void) const
     242  template <typename Distance, typename NeighborWeighting>
     243  const DataLookup2D& KNN<Distance, NeighborWeighting>::data(void) const
    232244  {
    233245    return data_;
     
    235247 
    236248 
    237   template <typename Distance>
    238   u_int KNN<Distance>::k() const
     249  template <typename Distance, typename NeighborWeighting>
     250  u_int KNN<Distance, NeighborWeighting>::k() const
    239251  {
    240252    return k_;
    241253  }
    242254
    243   template <typename Distance>
    244   void KNN<Distance>::k(u_int k)
     255  template <typename Distance, typename NeighborWeighting>
     256  void KNN<Distance, NeighborWeighting>::k(u_int k)
    245257  {
    246258    k_=k;
     
    248260
    249261
    250   template <typename Distance>
     262  template <typename Distance, typename NeighborWeighting>
    251263  SupervisedClassifier*
    252   KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const
     264  KNN<Distance, NeighborWeighting>::make_classifier(const DataLookup2D& data,
     265                                                    const Target& target) const
    253266  {     
    254267    KNN* knn=0;
    255268    try {
    256269      if(data.weighted()) {
    257         knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
    258                               target);
     270        knn=new KNN<Distance, NeighborWeighting>
     271          (dynamic_cast<const MatrixLookupWeighted&>(data),target);
    259272      } 
    260273      else {
    261         knn=new KNN<Distance>(dynamic_cast<const MatrixLookup&>(data),
    262                               target);
     274        knn=new KNN<Distance, NeighborWeighting>
     275          (dynamic_cast<const MatrixLookup&>(data),target);
    263276      }
    264277      knn->k(this->k());
    265278    }
    266279    catch (std::bad_cast) {
    267       std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class.";
     280      std::string str = "Error in KNN<Distance, NeighborWeighting>";
     281      str += "::make_classifier: DataLookup2D of unexpected class.";
    268282      throw std::runtime_error(str);
    269283    }
     
    272286 
    273287 
    274   template <typename Distance>
    275   void KNN<Distance>::train()
     288  template <typename Distance, typename NeighborWeighting>
     289  void KNN<Distance, NeighborWeighting>::train()
    276290  {   
    277291    trained_=true;
     
    279293
    280294
    281   template <typename Distance>
    282   void KNN<Distance>::predict(const DataLookup2D& test,                     
    283                               utility::matrix& prediction) const
     295  template <typename Distance, typename NeighborWeighting>
     296  void KNN<Distance, NeighborWeighting>::predict(const DataLookup2D& test,
     297                                                 utility::matrix& prediction) const
    284298  {   
    285299    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
     
    287301    utility::matrix* distances=calculate_distances(test);
    288302   
    289     // for each test sample (column in distances) find the closest
    290     // training samples
    291303    prediction.resize(target_.nof_classes(),test.columns(),0.0);
    292304    for(size_t sample=0;sample<distances->columns();sample++) {
    293305      std::vector<size_t> k_index;
    294       utility::sort_smallest_index(k_index,k_,
    295                                    distances->column_const_view(sample));
    296       for(size_t j=0;j<k_index.size();j++) {
    297         prediction(target_(k_index[j]),sample)++;
    298       }
    299     }
    300     prediction*=(1.0/k_);
     306      utility::VectorConstView dist=distances->column_const_view(sample);
     307      utility::sort_smallest_index(k_index,k_,dist);
     308      utility::VectorView pred=prediction.column_view(sample);
     309      weighting_(dist,k_index,target_,pred);
     310    }
    301311    delete distances;
    302312  }
  • trunk/yat/classifier/Makefile.am

    r1079 r1112  
    4242  Kernel_SEV.cc \
    4343  KernelLookup.cc \
     44  KNN_Uniform.cc \
     45  KNN_ReciprocalDistance.cc \
     46  KNN_ReciprocalRank.cc \
    4447  MatrixLookup.cc \
    4548  MatrixLookupWeighted.cc \
     
    7780  Kernel_SEV.h \
    7881  KNN.h \
     82  KNN_Uniform.h \
    7983  MatrixLookup.h \
    8084  MatrixLookupWeighted.h \
  • trunk/yat/classifier/NCC.h

    r1098 r1112  
    5656  /// @brief Class for Nearest Centroid Classification.
    5757  ///
    58 
     58  /// The template argument Distance should be a class implementing
     59  /// the concept \ref concept_distance.
     60  ///
    5961  template <typename Distance>
    6062  class NCC : public SupervisedClassifier
Note: See TracChangeset for help on using the changeset viewer.