Changeset 1586


Ignore:
Timestamp:
Oct 16, 2008, 11:06:26 PM (13 years ago)
Author:
Peter
Message:

fixing knn_test - refs #396

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/knn_test.cc

    r1487 r1586  
    2929#include "yat/classifier/MatrixLookupWeighted.h"
    3030#include "yat/statistics/EuclideanDistance.h"
     31#include "yat/utility/DataIterator.h"
    3132#include "yat/utility/Matrix.h"
     33#include "yat/utility/MatrixWeighted.h"
    3234
    3335
     
    4244using namespace theplu::yat;
    4345
     46utility::Matrix data(void);
     47utility::MatrixWeighted data_weighted(void);
     48double deviation(const utility::Matrix& a, const utility::Matrix& b);
     49void test_unweighted(test::Suite&);
     50void test_unweighted_weighted(test::Suite&);
     51void test_weighted(test::Suite&);
     52void test_reciprocal_ranks(test::Suite&);
     53void test_reciprocal_distance(test::Suite&);
     54void test_no_samples(test::Suite&);
     55void test_no_features(test::Suite&);
     56std::vector<std::string> vec_target(void);
     57
     58int main(int argc, char* argv[])
     59
     60  test::Suite suite(argc, argv);
     61  suite.err() << "testing knn" << std::endl;
     62  test_unweighted(suite);
     63  test_unweighted_weighted(suite);
     64  test_weighted(suite);
     65  test_reciprocal_ranks(suite);
     66  test_reciprocal_distance(suite);
     67  test_no_samples(suite);
     68  test_no_features(suite);
     69  return suite.return_value();
     70}
     71
     72
     73utility::Matrix data(void)
     74{
     75  utility::Matrix data1(3,4);
     76  for(size_t i=0;i<3;i++) {
     77    data1(i,0)=3-i;
     78    data1(i,1)=5-i;
     79    data1(i,2)=i+1;
     80    data1(i,3)=i+3;
     81  }
     82  return data1;
     83}
     84
     85
     86utility::MatrixWeighted data_weighted(void)
     87{
     88  utility::Matrix x = data();
     89  utility::MatrixWeighted result(x.rows(), x.columns());
     90  std::copy(x.begin(), x.end(), utility::data_iterator(result.begin()));
     91  return result;
     92}
     93
     94
    4495double deviation(const utility::Matrix& a, const utility::Matrix& b) {
     96  assert(a.rows()==b.rows());
     97  assert(b.columns()==b.columns());
    4598  double sl=0;
    4699  for (size_t i=0; i<a.rows(); i++){
     
    53106}
    54107
    55 int main(int argc, char* argv[])
    56 
    57   test::Suite suite(argc, argv);
    58   suite.err() << "testing knn" << std::endl;
    59 
     108void test_unweighted(test::Suite& suite)
     109{
    60110  ////////////////////////////////////////////////////////////////
    61111  // A test of training and predictions using unweighted data
     
    63113  suite.err() << "test of predictions using unweighted training "
    64114              << "and test data\n";
    65   utility::Matrix data1(3,4);
    66   for(size_t i=0;i<3;i++) {
    67     data1(i,0)=3-i;
    68     data1(i,1)=5-i;
    69     data1(i,2)=i+1;
    70     data1(i,3)=i+3;
    71   }
    72   std::vector<std::string> vec1(4, "pos");
    73   vec1[0]="neg";
    74   vec1[1]="neg";
    75  
     115  utility::Matrix data1 = data();
    76116  classifier::MatrixLookup ml1(data1);
    77   classifier::Target target1(vec1);
     117  classifier::Target target1(vec_target());
    78118 
    79119  classifier::KNN<statistics::EuclideanDistance> knn1;
     
    87127  suite.add(suite.equal_range(result1.begin(), result1.end(),
    88128                              prediction1.begin(), 1));
     129}
     130
     131void test_unweighted_weighted(test::Suite& suite)
     132{
     133  suite.err() << "test of predictions using unweighted training "
     134              << "and weighted test data\n";
     135  utility::MatrixWeighted xw = data_weighted();
     136  xw(2,0).weight()=0;
    89137 
    90 
    91   ////////////////////////////////////////////////////////////////
    92   // A test of training unweighted and test weighted
    93   ////////////////////////////////////////////////////////////////
    94   suite.err() << "test of predictions using unweighted training and weighted test data\n";
    95   utility::Matrix weights1(3,4,1.0);
    96   weights1(2,0)=0;
    97   classifier::MatrixLookupWeighted mlw1(data1,weights1);
     138  classifier::MatrixLookupWeighted mlw1(xw);
     139  classifier::KNN<statistics::EuclideanDistance> knn1;
     140  knn1.k(3);
     141  utility::Matrix data1 = data();
     142  classifier::MatrixLookup ml1(data1);
     143  classifier::Target target1(vec_target());
     144  knn1.train(ml1,target1);
     145  utility::Matrix prediction1;
    98146  knn1.predict(mlw1,prediction1);
     147  utility::Matrix result1(2,4);
     148  result1(0,0)=result1(0,1)=result1(1,2)=result1(1,3)=2.0;
     149  result1(0,2)=result1(0,3)=result1(1,0)=result1(1,1)=1.0;
    99150  result1(0,0)=1.0;
    100151  result1(1,0)=2.0;
    101152  suite.add(suite.equal_range(result1.begin(), result1.end(),
    102153                              prediction1.begin(), 1));
    103 
     154}
     155
     156void test_weighted(test::Suite& suite)
     157{
    104158  ////////////////////////////////////////////////////////////////
    105159  // A test of training and test both weighted
     
    107161  suite.err() << "test of predictions using weighted training and test data\n";
    108162  suite.err() << "... uniform neighbor weighting" << std::endl;
    109   weights1(0,1)=0;
    110   utility::Matrix weights2(3,4,1.0);
    111   weights2(2,3)=0;
    112   classifier::MatrixLookupWeighted mlw2(data1,weights2);
     163  utility::MatrixWeighted xw = data_weighted();
     164  xw(2,0).weight()=0;
     165  xw(0,1).weight()=0;
     166  classifier::MatrixLookupWeighted mlw1(xw);
     167   
     168  utility::MatrixWeighted xw2 = data_weighted();
     169  xw2(2,3).weight()=0;
     170  classifier::MatrixLookupWeighted mlw2(xw2);
    113171  classifier::KNN<statistics::EuclideanDistance> knn2;
    114172  knn2.k(3);
     173  classifier::Target target1(vec_target());
    115174  knn2.train(mlw2,target1);
     175  utility::Matrix prediction1;
    116176  knn2.predict(mlw1,prediction1);
     177  utility::Matrix result1(2,4);
     178  result1(0,0)=result1(0,1)=result1(1,2)=result1(1,3)=2.0;
     179  result1(0,2)=result1(0,3)=result1(1,0)=result1(1,1)=1.0;
     180  result1(0,0)=1.0;
     181  result1(1,0)=2.0;
    117182  result1(0,1)=1.0;
    118183  result1(1,1)=2.0;
    119184  suite.add(suite.equal_range(result1.begin(), result1.end(),
    120185                              prediction1.begin(), 1));
    121 
    122 
     186}
     187
     188
     189void test_reciprocal_ranks(test::Suite& suite)
     190{
    123191  ////////////////////////////////////////////////////////////////
    124192  // A test of reciprocal ranks weighting with training and test both weighted
    125193  ////////////////////////////////////////////////////////////////
    126194  suite.err() << "... reciprokal rank neighbor weighting" << std::endl;
    127   utility::Matrix data2(data1);
    128   data2(1,3)=7;
    129   classifier::MatrixLookupWeighted mlw3(data2,weights2);
    130   classifier::KNN<statistics::EuclideanDistance,classifier::KNN_ReciprocalRank> knn3;
     195  utility::MatrixWeighted xw2 = data_weighted();
     196  xw2(2,3).weight()=0;
     197  classifier::MatrixLookupWeighted mlw2(xw2);
     198  utility::MatrixWeighted xw3 = data_weighted();
     199  xw3(1,3).data()=7;
     200  xw3(2,3).weight()=0;
     201  classifier::MatrixLookupWeighted mlw3(xw3);
     202  classifier::KNN<statistics::EuclideanDistance
     203    ,classifier::KNN_ReciprocalRank> knn3;
    131204  knn3.k(3);
     205  classifier::Target target1(vec_target());
    132206  knn3.train(mlw2,target1);
     207  utility::Matrix prediction1;
    133208  knn3.predict(mlw3,prediction1);
     209  utility::Matrix result1(2,4);
    134210  result1(0,0)=result1(1,3)=1.0;
    135211  result1(0,3)=result1(1,0)=5.0/6.0;
     
    138214  suite.add(suite.equal_range(result1.begin(), result1.end(),
    139215                              prediction1.begin(), 1));
    140 
    141  
     216}
     217
     218void test_reciprocal_distance(test::Suite& suite)
     219{
    142220  ////////////////////////////////////////////////////////////////
    143221  // A test of reciprocal distance weighting with training and test both weighted
     
    147225    knn4;
    148226  knn4.k(3);
     227  utility::MatrixWeighted xw2 = data_weighted();
     228  xw2(2,3).weight()=0;
     229  classifier::MatrixLookupWeighted mlw2(xw2);
     230  utility::MatrixWeighted xw3 = data_weighted();
     231  xw3(1,3).data()=7;
     232  xw3(2,3).weight()=0;
     233  classifier::MatrixLookupWeighted mlw3(xw3);
     234  classifier::Target target1(vec_target());
    149235  knn4.train(mlw2,target1);
     236  utility::Matrix prediction1;
    150237  knn4.predict(mlw3,prediction1);
    151238  if (!(std::isinf(prediction1(0,0)) && std::isinf(prediction1(0,1)) &&
     
    158245    suite.add(false);
    159246  }
    160 
    161 
     247}
     248
     249
     250void test_no_samples(test::Suite& suite)
     251{
    162252  ////////////////////////////////////////////////////////////////
    163253  // A test of when a class has no training samples, should give nan
     
    168258  std::vector<size_t> ind(2,2);
    169259  ind[1]=3;
     260  classifier::Target target1(vec_target());
    170261  classifier::Target target2(target1,utility::Index(ind));
    171   classifier::MatrixLookupWeighted mlw4(data1,weights2,
    172                                         utility::Index(data1.rows()),
     262
     263  utility::MatrixWeighted xw = data_weighted();
     264  xw(2,3).weight()=0.0;
     265
     266  classifier::MatrixLookupWeighted mlw4(xw, utility::Index(xw.rows()),
    173267                                        utility::Index(ind));
    174268  classifier::KNN<statistics::EuclideanDistance> knn5;
    175269  knn5.k(3);
    176270  knn5.train(mlw4,target2);
     271  utility::MatrixWeighted xw3 = data_weighted();
     272  xw3(1,3).data()=7;
     273  xw3(2,3).weight()=0;
     274  classifier::MatrixLookupWeighted mlw3(xw3);
     275  utility::Matrix prediction1;
    177276  knn5.predict(mlw3,prediction1);
    178277  if (!(std::isnan(prediction1(0,0)) && std::isnan(prediction1(0,1)) &&
     
    185284    suite.add(false);
    186285  }
    187 
     286}
     287
     288void test_no_features(test::Suite& suite)
     289{
    188290  ////////////////////////////////////////////////////////////////
    189291  // A test of when a test sample has no variables with non-zero
     
    191293  ////////////////////////////////////////////////////////////////
    192294  suite.err() << "test of predictions with nan distances (set to infinity in KNN)\n";
    193   weights1.all(1);
    194   weights1(1,0)=weights1(1,1)=weights1(2,0)=weights1(2,1)=0.0;
    195   weights2.all(1);
    196   weights2(0,0)=0.0;
     295  utility::MatrixWeighted xw1 = data_weighted();
     296  xw1(1,0).weight()=xw1(1,1).weight()=xw1(2,0).weight()=xw1(2,1).weight()=0.0;
     297  classifier::MatrixLookupWeighted mlw1(xw1);
     298
    197299  classifier::KNN<statistics::EuclideanDistance> knn6;
    198300  knn6.k(3);
     301  classifier::Target target1(vec_target());
    199302  knn6.train(mlw1,target1);
     303
     304  utility::MatrixWeighted xw3 = data_weighted();
     305  xw3(1,3).data()=7;
     306  xw3(0,0).weight()=0;
     307  classifier::MatrixLookupWeighted mlw3(xw3);
     308  utility::Matrix prediction1;
    200309  knn6.predict(mlw3,prediction1);
     310  utility::Matrix result1(2,4);
    201311  result1(0,0)=0;
    202312  result1(0,2)=result1(1,1)=result1(1,3)=1.0;
     
    204314  suite.add(suite.equal_range(result1.begin(), result1.end(),
    205315                              prediction1.begin(), 1));
    206   return suite.return_value();
    207 }
    208 
    209 
     316}
     317
     318std::vector<std::string> vec_target(void)
     319{
     320  std::vector<std::string> vec1(4, "pos");
     321  vec1[0]="neg";
     322  vec1[1]="neg";
     323  return vec1;
     324}
     325
Note: See TracChangeset for help on using the changeset viewer.