Ignore:
Timestamp:
Feb 4, 2008, 4:44:44 PM (14 years ago)
Author:
Markus Ringnér
Message:

Fixes #272

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/KNN.h

    r1028 r1031  
    3030#include "SupervisedClassifier.h"
    3131#include "Target.h"
    32 #include "yat/statistics/vector_distance.h"
     32#include "yat/statistics/distance.h"
    3333#include "yat/utility/matrix.h"
    3434#include "yat/utility/yat_assert.h"
     
    144144 
    145145  template <typename Distance>
    146   utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& input) const
     146  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& test) const
    147147  {
    148148    // matrix with training samples as rows and test samples as columns
    149149    utility::matrix* distances =
    150       new utility::matrix(data_.columns(),input.columns());
    151    
    152     // if both training and test are unweighted: unweighted
    153     // calculations are used.
    154     const MatrixLookup* test_unweighted =
    155       dynamic_cast<const MatrixLookup*>(&input);     
    156     if(test_unweighted && !data_.weighted()) {
    157       const MatrixLookup* data_unweighted =
    158         dynamic_cast<const MatrixLookup*>(&data_);     
     150      new utility::matrix(data_.columns(),test.columns());
     151   
     152    // unweighted test data
     153    if(const MatrixLookup* test_unweighted =
     154       dynamic_cast<const MatrixLookup*>(&test)) {     
    159155      for(size_t i=0; i<data_.columns(); i++) {
    160         classifier::DataLookup1D training(*data_unweighted,i,false);
    161         for(size_t j=0; j<input.columns(); j++) {
     156        for(size_t j=0; j<test.columns(); j++) {
    162157          classifier::DataLookup1D test(*test_unweighted,j,false);
    163           utility::yat_assert<std::runtime_error>(training.size()==test.size());
    164158          (*distances)(i,j) =
    165             statistics::vector_distance(training.begin(),training.end(),
    166                                         test.begin(), typename statistics::vector_distance_traits<Distance>::distance());
     159            statistics::distance(classifier::DataLookup1D(data_,
     160                                                          i,false).begin(),
     161                                 classifier::DataLookup1D(data_,
     162                                                          i,false).end(),
     163                                 test.begin(),
     164                                 typename statistics::
     165                                 distance_traits<Distance>::distance());
    167166          utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    168167        }
    169168      }
    170169    }
    171     // if either training or test is weighted: weighted calculations
    172     // are used.
     170    // weighted test data
    173171    else {
    174172      const MatrixLookupWeighted* data_weighted =
    175173        dynamic_cast<const MatrixLookupWeighted*>(&data_);
    176174      const MatrixLookupWeighted* test_weighted =
    177         dynamic_cast<const MatrixLookupWeighted*>(&input);               
     175        dynamic_cast<const MatrixLookupWeighted*>(&test);               
    178176      if(data_weighted && test_weighted) {
    179177        for(size_t i=0; i<data_.columns(); i++) {
    180178          classifier::DataLookupWeighted1D training(*data_weighted,i,false);
    181           for(size_t j=0; j<input.columns(); j++) {
     179          for(size_t j=0; j<test.columns(); j++) {
    182180            classifier::DataLookupWeighted1D test(*test_weighted,j,false);
    183181            utility::yat_assert<std::runtime_error>(training.size()==test.size());
    184182            (*distances)(i,j) =
    185               statistics::vector_distance(training.begin(),training.end(),
    186                                           test.begin(), typename statistics::vector_distance_traits<Distance>::distance());
     183              statistics::distance(training.begin(),training.end(),
     184                                   test.begin(), typename statistics::distance_traits<Distance>::distance());
    187185            utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    188186          }
     
    251249
    252250  template <typename Distance>
    253   void KNN<Distance>::predict(const DataLookup2D& input,                   
     251  void KNN<Distance>::predict(const DataLookup2D& test,                     
    254252                              utility::matrix& prediction) const
    255253  {   
    256     utility::matrix* distances=calculate_distances(input);
     254    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
     255
     256    utility::matrix* distances=calculate_distances(test);
    257257   
    258258    // for each test sample (column in distances) find the closest
    259259    // training samples
    260     prediction.clone(utility::matrix(target_.nof_classes(),input.columns(),0.0));
     260    prediction.clone(utility::matrix(target_.nof_classes(),test.columns(),0.0));
    261261    for(size_t sample=0;sample<distances->columns();sample++) {
    262262      std::vector<size_t> k_index;
Note: See TracChangeset for help on using the changeset viewer.