Changeset 1107


Ignore:
Timestamp:
Feb 19, 2008, 4:23:52 PM (13 years ago)
Author:
Markus Ringnér
Message:

Ticket #259 fixed for KNN

Location:
trunk
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/knn_test.cc

    r1052 r1107  
    2323
    2424#include "yat/classifier/KNN.h"
     25#include "yat/classifier/MatrixLookup.h"
    2526#include "yat/classifier/MatrixLookupWeighted.h"
    2627#include "yat/statistics/EuclideanDistance.h"
    27 #include "yat/statistics/PearsonDistance.h"
    2828#include "yat/utility/matrix.h"
    2929
     
    3838
    3939using namespace theplu::yat;
     40
     41double deviation(const utility::matrix& a, const utility::matrix& b) {
     42  double sl=0;
     43  for (size_t i=0; i<a.rows(); i++){
     44    for (size_t j=0; j<a.columns(); j++){
     45      sl += fabs(a(i,j)-b(i,j));
     46    }
     47  }
     48  sl /= (a.columns()*a.rows());
     49  return sl;
     50}
    4051
    4152int main(const int argc,const char* argv[])
     
    5263  *error << "testing knn" << std::endl;
    5364  bool ok = true;
     65
     66  ////////////////////////////////////////////////////////////////
     67  // A test of training and predictions using unweighted data
     68  ////////////////////////////////////////////////////////////////
     69  *error << "test of predictions using unweighted training and test data\n";
     70  utility::matrix data1(3,4);
     71  for(size_t i=0;i<3;i++) {
     72    data1(i,0)=3-i;
     73    data1(i,1)=5-i;
     74    data1(i,2)=i+1;
     75    data1(i,3)=i+3;
     76  }
     77  std::vector<std::string> vec1(4, "pos");
     78  vec1[0]="neg";
     79  vec1[1]="neg";
    5480 
    55   std::ifstream is("data/sorlie_centroid_data.txt");
    56   utility::matrix data(is,'\t');
    57   is.close();
     81  classifier::MatrixLookup ml1(data1);
     82  classifier::Target target1(vec1);
    5883 
    59   is.open("data/sorlie_centroid_classes.txt");
    60   classifier::Target targets(is);
    61   is.close();
     84  classifier::KNN<statistics::EuclideanDistance> knn1(ml1,target1);
     85  knn1.k(3);
     86  knn1.train();
     87  utility::matrix prediction1;
     88  knn1.predict(ml1,prediction1);
     89  double slack_bound=2e-7;
     90  utility::matrix result1(2,4);
     91  result1(0,0)=result1(0,1)=result1(1,2)=result1(1,3)=2.0/3.0;
     92  result1(0,2)=result1(0,3)=result1(1,0)=result1(1,1)=1.0/3.0;
     93  double slack = deviation(prediction1,result1);
     94  if (slack > slack_bound || std::isnan(slack)){
     95    *error << "Difference to expected prediction too large\n";
     96    *error << "slack: " << slack << std::endl;
     97    *error << "expected less than " << slack_bound << std::endl;
     98    ok = false;
     99  }
     100 
    62101
    63   // Generate weight matrix with 0 for missing values and 1 for others.
    64   utility::matrix weights(data.rows(),data.columns(),0.0);
    65   utility::nan(data,weights);
     102  ////////////////////////////////////////////////////////////////
     103  // A test of training unweighted and test weighted
     104  ////////////////////////////////////////////////////////////////
     105  *error << "test of predictions using unweighted training and weighted test data\n";
     106  utility::matrix weights1(3,4,1.0);
     107  weights1(2,0)=0;
     108  classifier::MatrixLookupWeighted mlw1(data1,weights1);
     109  knn1.predict(mlw1,prediction1);
     110  result1(0,0)=1.0/3.0;
     111  result1(1,0)=2.0/3.0;
     112  slack = deviation(prediction1,result1);
     113  if (slack > slack_bound || std::isnan(slack)){
     114    *error << "Difference to expected prediction too large\n";
     115    *error << "slack: " << slack << std::endl;
     116    *error << "expected less than " << slack_bound << std::endl;
     117    ok = false;
     118  }
    66119
    67   classifier::MatrixLookupWeighted dataviewweighted(data,weights);
    68   classifier::KNN<statistics::PearsonDistance> knn(dataviewweighted,targets);
    69   *error << "training KNN" << std::endl;
    70   knn.train();
    71  
    72   utility::matrix prediction;
    73   knn.predict(dataviewweighted,prediction);
    74   *error << prediction << std::endl;
    75  
     120  ////////////////////////////////////////////////////////////////
     121  // A test of training and test both weighted
     122  ////////////////////////////////////////////////////////////////
     123  *error << "test of predictions using weighted training and test data\n";
     124  weights1(0,1)=0;
     125  utility::matrix weights2(3,4,1.0);
     126  weights2(2,3)=0;
     127  classifier::MatrixLookupWeighted mlw2(data1,weights2);
     128  classifier::KNN<statistics::EuclideanDistance> knn2(mlw2,target1);
     129  knn2.k(3);
     130  knn2.train();
     131  knn2.predict(mlw1,prediction1);
     132  result1(0,1)=1.0/3.0;
     133  result1(1,1)=2.0/3.0;
     134  slack = deviation(prediction1,result1);
     135  if (slack > slack_bound || std::isnan(slack)){
     136    *error << "Difference to expected prediction too large\n";
     137    *error << "slack: " << slack << std::endl;
     138    *error << "expected less than " << slack_bound << std::endl;
     139    ok = false;
     140  }
     141
     142
    76143  if(!ok) {
    77144    *error << "knn_test failed" << std::endl;
  • trunk/yat/classifier/KNN.h

    r1098 r1107  
    9090   
    9191    ///
    92     /// Train the classifier using the training data. Centroids are
    93     /// calculated for each class.
     92    /// Train the classifier using the training data.
     93    /// This function does nothing but is required by the interface.
    9494    ///
    9595    /// @return true if training succedeed.
     
    9999   
    100100    ///
    101     /// Calculate the distance to each centroid for test samples
     101    /// For each sample, calculate the number of neighbours for each
     102    /// class.
     103    ///
    102104    ///
    103105    void predict(const DataLookup2D&, utility::matrix&) const;
     
    121123    ///
    122124    utility::matrix* calculate_distances(const DataLookup2D&) const;
     125    void calculate_unweighted(const MatrixLookup&,
     126                              const MatrixLookup&,
     127                              utility::matrix*) const;
     128    void calculate_weighted(const MatrixLookupWeighted&,
     129                            const MatrixLookupWeighted&,
     130                            utility::matrix*) const;
    123131  };
    124132 
     
    151159      new utility::matrix(data_.columns(),test.columns());
    152160   
     161   
    153162    // unweighted test data
    154163    if(const MatrixLookup* test_unweighted =
    155164       dynamic_cast<const MatrixLookup*>(&test)) {     
    156       for(size_t i=0; i<data_.columns(); i++) {
    157         for(size_t j=0; j<test.columns(); j++) {
    158           classifier::DataLookup1D test(*test_unweighted,j,false);
    159           classifier::DataLookup1D tmp(data_,i,false);
    160           (*distances)(i,j) = distance_(tmp.begin(), tmp.end(), test.begin());
    161           utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    162         }
    163       }
     165      // unweighted training data
     166      if(const MatrixLookup* training_unweighted =
     167         dynamic_cast<const MatrixLookup*>(&data_))
     168        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
     169      // weighted training data
     170      else if(const MatrixLookupWeighted* training_weighted =
     171              dynamic_cast<const MatrixLookupWeighted*>(&data_))
     172        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
     173                           distances);             
     174      // Training data can not be of incorrect type
    164175    }
    165176    // weighted test data
     177    else if (const MatrixLookupWeighted* test_weighted =
     178             dynamic_cast<const MatrixLookupWeighted*>(&test)) {     
     179      // unweighted training data
     180      if(const MatrixLookup* training_unweighted =
     181         dynamic_cast<const MatrixLookup*>(&data_)) {
     182        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
     183                           *test_weighted,distances);
     184      }
     185      // weighted training data
     186      else if(const MatrixLookupWeighted* training_weighted =
     187              dynamic_cast<const MatrixLookupWeighted*>(&data_))
     188        calculate_weighted(*training_weighted,*test_weighted,distances);             
     189      // Training data can not be of incorrect type
     190    }
    166191    else {
    167       const MatrixLookupWeighted* data_weighted =
    168         dynamic_cast<const MatrixLookupWeighted*>(&data_);
    169       const MatrixLookupWeighted* test_weighted =
    170         dynamic_cast<const MatrixLookupWeighted*>(&test);               
    171       if(data_weighted && test_weighted) {
    172         for(size_t i=0; i<data_.columns(); i++) {
    173           classifier::DataLookupWeighted1D training(*data_weighted,i,false);
    174           for(size_t j=0; j<test.columns(); j++) {
    175             classifier::DataLookupWeighted1D test(*test_weighted,j,false);
    176             utility::yat_assert<std::runtime_error>(training.size()==test.size());
    177             (*distances)(i,j) = distance_(training.begin(), training.end(),
    178                                           test.begin());
    179             utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    180           }
    181         }
    182       }
    183       else {
    184         std::string str;
    185         str = "Error in KNN::calculate_distances: Only support when training and test data both are either MatrixLookup or MatrixLookupWeighted";
    186         throw std::runtime_error(str);
    187       }
     192      std::string str;
     193      str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted";
     194      throw std::runtime_error(str);
    188195    }
    189196    return distances;
    190197  }
     198
     199  template <typename Distance>
     200  void  KNN<Distance>:: calculate_unweighted(const MatrixLookup& training,
     201                                             const MatrixLookup& test,
     202                                             utility::matrix* distances) const
     203  {
     204    for(size_t i=0; i<training.columns(); i++) {
     205      classifier::DataLookup1D training1(training,i,false);
     206      for(size_t j=0; j<test.columns(); j++) {
     207        classifier::DataLookup1D test1(test,j,false);
     208        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
     209        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
     210      }
     211    }
     212  }
     213 
     214  template <typename Distance>
     215  void  KNN<Distance>:: calculate_weighted(const MatrixLookupWeighted& training,
     216                                           const MatrixLookupWeighted& test,
     217                                           utility::matrix* distances) const
     218  {
     219    for(size_t i=0; i<training.columns(); i++) {
     220      classifier::DataLookupWeighted1D training1(training,i,false);
     221      for(size_t j=0; j<test.columns(); j++) {
     222        classifier::DataLookupWeighted1D test1(test,j,false);
     223        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
     224        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
     225      }
     226    }
     227  }
     228
    191229 
    192230  template <typename Distance>
Note: See TracChangeset for help on using the changeset viewer.