source: trunk/test/vector_distance_test.cc @ 901

Last change on this file since 901 was 901, checked in by Markus Ringnér, 15 years ago

Added a template KNN classifier where the distance measure is the template. Refs #250 and #182

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 3.8 KB
Line 
1// $Id: vector_distance_test.cc 901 2007-09-27 09:00:05Z markus $
2
#include "yat/classifier/DataLookupWeighted1D.h"
#include "yat/classifier/MatrixLookupWeighted.h"
#include "yat/statistics/euclidean_vector_distance.h"
#include "yat/statistics/pearson_vector_distance.h"
#include "yat/statistics/vector_distance_ptr.h"
#include "yat/utility/matrix.h"
#include "yat/utility/vector.h"

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <list>
#include <string>
#include <vector>
16
17
18using namespace theplu::yat;
19
20
21// Function to test pointers to distance specialized for DataLookup1D::iterator
22double f(statistics::vector_distance_lookup_weighted_ptr distance) {
23  utility::matrix m(2,3,1);
24  m(0,1)=2;
25  m(1,0)=0;
26  m(1,1)=0;
27  utility::matrix w(2,3,1);
28  w(0,0)=0;
29  classifier::MatrixLookupWeighted mw(m,w);
30  classifier::DataLookupWeighted1D aw(mw,0,true);
31  classifier::DataLookupWeighted1D bw(mw,1,true);
32 
33  double dist=(*distance)(aw.begin(),aw.end(),bw.begin());
34  return dist; 
35}
36
37int main(const int argc,const char* argv[])
38
39{ 
40  std::ostream* error;
41  if (argc>1 && argv[1]==std::string("-v"))
42    error = &std::cerr;
43  else {
44    error = new std::ofstream("/dev/null");
45    if (argc>1)
46      std::cout << "vector_distance_test -v : for printing extra information\n";
47  }
48  *error << "testing vector_distance" << std::endl;
49  bool ok = true;
50 
51  utility::vector a(3,1);
52  a(1) = 2;
53  utility::vector b(3,0);
54  b(2) = 1;
55 
56  double tolerance=1e-4;
57 
58  double dist=statistics::vector_distance(a.begin(),a.end(),b.begin(),
59                                          statistics::euclidean_vector_distance_tag());
60  if(fabs(dist-2.23607)>tolerance) {
61    *error << "Error in unweighted Euclidean vector_distance " << std::endl;
62    ok=false;
63  }
64 
65  dist=statistics::vector_distance(a.begin(),a.end(),b.begin(),
66                                   statistics::pearson_vector_distance_tag()); 
67  if(fabs(dist-1.5)>tolerance) {
68    *error << "Error in unweighted Pearson vector_distance " << std::endl;
69    ok=false;
70  }
71 
72 
73  // Testing weighted versions
74  utility::matrix m(2,3,1);
75  m(0,1)=2;
76  m(1,0)=0;
77  m(1,1)=0;
78  utility::matrix w(2,3,1);
79  w(0,0)=0;
80  classifier::MatrixLookupWeighted mw(m,w);
81  classifier::DataLookupWeighted1D aw(mw,0,true);
82  classifier::DataLookupWeighted1D bw(mw,1,true);
83 
84  dist=statistics::vector_distance(aw.begin(),aw.end(),bw.begin(),
85                                   statistics::euclidean_vector_distance_tag());
86 
87  if(fabs(dist-2)>tolerance) {
88    *error << "Error in weighted Euclidean vector_distance " << std::endl;
89    ok=false;
90  }
91 
92  dist=statistics::vector_distance(aw.begin(),aw.end(),bw.begin(),
93                                   statistics::pearson_vector_distance_tag());
94 
95  if(fabs(dist-2)>tolerance) {
96    *error << "Error in weighted Pearson vector_distance " << std::endl;
97    ok=false;
98  }
99 
100 
101  // Test with pointer to a vector_distance
102  statistics::vector_distance_lookup_weighted_ptr test_ptr=
103    statistics::vector_distance<statistics::euclidean_vector_distance_tag>;
104  dist=(*test_ptr)(aw.begin(),aw.end(),bw.begin());
105  if(fabs(dist-2)>tolerance) {
106    *error << "Error when using pointer to vector_distance" << std::endl;
107    ok=false;
108  }
109 
110  // Test with std::vectors
111  std::vector<double> sa(3,1);
112  sa[1] = 2;
113  std::vector<double> sb(3,0);
114  sb[2] = 1;
115 
116  dist=statistics::vector_distance(sa.begin(),sa.end(),sb.begin(),
117                                   statistics::euclidean_vector_distance_tag()); 
118  if(fabs(dist-2.23607)>tolerance) {
119    *error << "Error in vector_distance for std::vector " << std::endl;
120    ok=false;
121  }
122 
123  // Test for a std::list and a std::vector
124  std::list<double> la;
125  std::copy(sa.begin(),sa.end(),std::back_inserter<std::list<double> >(la));
126  dist=statistics::vector_distance(la.begin(),la.end(),sb.begin(),
127                                   statistics::euclidean_vector_distance_tag()); 
128  if(fabs(dist-2.23607)>tolerance) {
129    *error << "Error in vector_distance for std::list " << std::endl;
130    ok=false;
131  }
132 
133  if(!ok) {
134    *error << "vector_distance_test failed" << std::endl;
135  }
136  if (error!=&std::cerr)
137    delete error;
138  if (ok=true) 
139    return 0;
140  return -1;
141}
142
143
Note: See TracBrowser for help on using the repository browser.