Changeset 1031 for trunk/yat/classifier


Ignore:
Timestamp:
Feb 4, 2008, 4:44:44 PM (16 years ago)
Author:
Markus Ringnér
Message:

Fixes #272

Location:
trunk/yat/classifier
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/IGP.h

    r1009 r1031  
    3030#include "yat/utility/vector.h"
    3131#include "yat/utility/yat_assert.h"
    32 #include "yat/statistics/vector_distance.h"
     32#include "yat/statistics/distance.h"
    3333
    3434#include <cmath>
     
    9595        DataLookup1D b(matrix_,j,false);
    9696        double dist=statistics::
    97           vector_distance(a.begin,a.end(),b.begin(),
    98                           statistics::vector_distance_traits<Distance>::distace());
     97          distance(a.begin,a.end(),b.begin(),
     98                          statistics::distance_traits<Distance>::distace());
    9999        if(j!=i && dist<mindist) {
    100100          mindist=dist;
  • trunk/yat/classifier/KNN.h

    r1028 r1031  
    3030#include "SupervisedClassifier.h"
    3131#include "Target.h"
    32 #include "yat/statistics/vector_distance.h"
     32#include "yat/statistics/distance.h"
    3333#include "yat/utility/matrix.h"
    3434#include "yat/utility/yat_assert.h"
     
    144144 
    145145  template <typename Distance>
    146   utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& input) const
     146  utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& test) const
    147147  {
    148148    // matrix with training samples as rows and test samples as columns
    149149    utility::matrix* distances =
    150       new utility::matrix(data_.columns(),input.columns());
    151    
    152     // if both training and test are unweighted: unweighted
    153     // calculations are used.
    154     const MatrixLookup* test_unweighted =
    155       dynamic_cast<const MatrixLookup*>(&input);     
    156     if(test_unweighted && !data_.weighted()) {
    157       const MatrixLookup* data_unweighted =
    158         dynamic_cast<const MatrixLookup*>(&data_);     
     150      new utility::matrix(data_.columns(),test.columns());
     151   
     152    // unweighted test data
     153    if(const MatrixLookup* test_unweighted =
     154       dynamic_cast<const MatrixLookup*>(&test)) {     
    159155      for(size_t i=0; i<data_.columns(); i++) {
    160         classifier::DataLookup1D training(*data_unweighted,i,false);
    161         for(size_t j=0; j<input.columns(); j++) {
     156        for(size_t j=0; j<test.columns(); j++) {
    162157          classifier::DataLookup1D test(*test_unweighted,j,false);
    163           utility::yat_assert<std::runtime_error>(training.size()==test.size());
    164158          (*distances)(i,j) =
    165             statistics::vector_distance(training.begin(),training.end(),
    166                                         test.begin(), typename statistics::vector_distance_traits<Distance>::distance());
     159            statistics::distance(classifier::DataLookup1D(data_,
     160                                                          i,false).begin(),
     161                                 classifier::DataLookup1D(data_,
     162                                                          i,false).end(),
     163                                 test.begin(),
     164                                 typename statistics::
     165                                 distance_traits<Distance>::distance());
    167166          utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    168167        }
    169168      }
    170169    }
    171     // if either training or test is weighted: weighted calculations
    172     // are used.
     170    // weighted test data
    173171    else {
    174172      const MatrixLookupWeighted* data_weighted =
    175173        dynamic_cast<const MatrixLookupWeighted*>(&data_);
    176174      const MatrixLookupWeighted* test_weighted =
    177         dynamic_cast<const MatrixLookupWeighted*>(&input);               
     175        dynamic_cast<const MatrixLookupWeighted*>(&test);               
    178176      if(data_weighted && test_weighted) {
    179177        for(size_t i=0; i<data_.columns(); i++) {
    180178          classifier::DataLookupWeighted1D training(*data_weighted,i,false);
    181           for(size_t j=0; j<input.columns(); j++) {
     179          for(size_t j=0; j<test.columns(); j++) {
    182180            classifier::DataLookupWeighted1D test(*test_weighted,j,false);
    183181            utility::yat_assert<std::runtime_error>(training.size()==test.size());
    184182            (*distances)(i,j) =
    185               statistics::vector_distance(training.begin(),training.end(),
    186                                           test.begin(), typename statistics::vector_distance_traits<Distance>::distance());
     183              statistics::distance(training.begin(),training.end(),
     184                                   test.begin(), typename statistics::distance_traits<Distance>::distance());
    187185            utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    188186          }
     
    251249
    252250  template <typename Distance>
    253   void KNN<Distance>::predict(const DataLookup2D& input,                   
     251  void KNN<Distance>::predict(const DataLookup2D& test,                     
    254252                              utility::matrix& prediction) const
    255253  {   
    256     utility::matrix* distances=calculate_distances(input);
     254    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
     255
     256    utility::matrix* distances=calculate_distances(test);
    257257   
    258258    // for each test sample (column in distances) find the closest
    259259    // training samples
    260     prediction.clone(utility::matrix(target_.nof_classes(),input.columns(),0.0));
     260    prediction.clone(utility::matrix(target_.nof_classes(),test.columns(),0.0));
    261261    for(size_t sample=0;sample<distances->columns();sample++) {
    262262      std::vector<size_t> k_index;
  • trunk/yat/classifier/NCC.h

    r1013 r1031  
    3737#include "yat/statistics/Averager.h"
    3838#include "yat/statistics/AveragerWeighted.h"
    39 #include "yat/statistics/vector_distance.h"
     39#include "yat/statistics/distance.h"
    4040
    4141#include "yat/utility/Iterator.h"
     
    112112    // MatrixLookup and MatrixLookupWeighted
    113113    const DataLookup2D& data_;
    114 
     114    bool centroids_nan_;
    115115  };
    116116
     
    125125  template <typename Distance>
    126126  NCC<Distance>::NCC(const MatrixLookup& data, const Target& target)
    127     : SupervisedClassifier(target), centroids_(0), data_(data)
     127    : SupervisedClassifier(target), centroids_(0), data_(data), centroids_nan_(false)
    128128  {
    129129  }
     
    131131  template <typename Distance>
    132132  NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
    133     : SupervisedClassifier(target), centroids_(0), data_(data)
     133    : SupervisedClassifier(target), centroids_(0), data_(data), centroids_nan_(false)
    134134  {
    135135  }
     
    169169                              target);
    170170      }
    171       ncc->centroids_=0;
    172171    }
    173172    catch (std::bad_cast) {
     
    198197        for(size_t c=0;c<target_.nof_classes();c++) {
    199198          (*centroids_)(i,c) = class_averager[c].mean();
     199          if(class_averager[c].sum_w()==0)
     200            centroids_nan_=true;
    200201        }
    201202      }
     
    230231    prediction.clone(utility::matrix(centroids_->columns(), test.columns()));       
    231232
    232     // unweighted test data
     233    // unweighted test data and no nan's in centroids
     234    // Markus: Should test centroid_nan_ here!!!
    233235    if (const MatrixLookup* test_unweighted =
    234236        dynamic_cast<const MatrixLookup*>(&test)) {
     
    240242          utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
    241243          prediction(k,j)=statistics::
    242             vector_distance(in.begin(),in.end(),centroid.begin(),
    243                             typename statistics::vector_distance_traits<Distance>::distance());
     244            distance(in.begin(),in.end(),centroid.begin(),
     245                            typename statistics::distance_traits<Distance>::distance());
    244246        }
    245247      }
     
    255257          utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
    256258          prediction(k,j)=statistics::
    257             vector_distance(in.begin(),in.end(),centroid.begin(),
    258                             typename statistics::vector_distance_traits<Distance>::distance());
     259            distance(in.begin(),in.end(),centroid.begin(),
     260                            typename statistics::distance_traits<Distance>::distance());
    259261        }
    260262      }
Note: See TracChangeset for help on using the changeset viewer.