source: trunk/test/svm_test.cc @ 514

Last change on this file since 514 was 514, checked in by Peter, 17 years ago

generalised binary functionality in Target

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 3.5 KB
Line 
1// $Id: svm_test.cc 514 2006-02-20 09:45:34Z peter $
2
#include <c++_tools/gslapi/matrix.h>
#include <c++_tools/gslapi/vector.h>
#include <c++_tools/classifier/SVM.h>
#include <c++_tools/classifier/Kernel.h>
#include <c++_tools/classifier/KernelLookup.h>
#include <c++_tools/classifier/Kernel_SEV.h>
#include <c++_tools/classifier/Kernel_MEV.h>
#include <c++_tools/classifier/PolynomialKernelFunction.h>

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
17
18using namespace theplu;
19
20int main(const int argc,const char* argv[])
21{ 
22
23  std::ostream* error;
24  if (argc>1 && argv[1]==std::string("-v"))
25    error = &std::cerr;
26  else {
27    error = new std::ofstream("/dev/null");
28    if (argc>1)
29      std::cout << "svm_test -v : for printing extra information\n";
30  }
31  *error << "testing svm" << std::endl;
32  bool ok = true;
33
34  gslapi::matrix data2(2,3);
35  data2(0,0)=0;
36  data2(1,0)=0;
37  data2(0,1)=0;
38  data2(1,1)=1;
39  data2(0,2)=1;
40  data2(1,2)=0;
41  std::vector<std::string> label;
42  label.reserve(3);
43  label.push_back("-1");
44  label.push_back("1");
45  label.push_back("1");
46  classifier::Target target2(label);
47  classifier::KernelFunction* kf2 = new classifier::PolynomialKernelFunction(); 
48  classifier::Kernel_MEV kernel2(data2,*kf2);
49  assert(kernel2.size()==3);
50  assert(target2.size()==3);
51  classifier::KernelLookup kv2(kernel2);
52  *error << "testing with linear kernel" << std::endl;
53  assert(kv2.rows()==target2.size());
54  classifier::SVM classifier2(kv2, target2);
55  *error << "training...";
56  classifier2.train();
57  *error << " done!" << std::endl;
58
59  double tmp=0;
60  for (size_t i=0; i<target2.size(); i++) 
61    if (target2.binary(i))
62      tmp += classifier2.alpha()(i);
63    else
64      tmp -= classifier2.alpha()(i);
65
66  if (tmp){
67    *error << "ERROR: found " << tmp << " expected zero" << std::endl;
68    return -1;
69  }
70
71  double tol=1e-6;
72  if (fabs(classifier2.alpha()(1)-2)>tol || 
73      fabs(classifier2.alpha()(2)-2)>tol){
74    *error << "wrong alpha" << std::endl;
75    *error << "alpha: " << classifier2.alpha() <<  std::endl;
76    *error << "expected: 4 2 2" <<  std::endl;
77
78    return -1;
79  }
80
81 
82
83  std::ifstream is("data/nm_data_centralized.txt");
84  gslapi::matrix transposed_data(is);
85  is.close();
86  // Because how the kernel is treated is changed, data must be transposed.
87  gslapi::matrix data=transposed_data;
88
89  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
90  classifier::Kernel_SEV kernel(data,*kf);
91
92
93  is.open("data/nm_target_bin.txt");
94  classifier::Target target(is);
95  is.close();
96
97  is.open("data/nm_alpha_linear_matlab.txt");
98  theplu::gslapi::vector alpha_matlab(is);
99  is.close();
100
101  classifier::KernelLookup kv(kernel);
102  theplu::classifier::SVM svm(kv, target);
103  if (!svm.train()){
104    ok=false;
105    *error << "Training failured" << std::endl;
106  }
107
108  theplu::gslapi::vector alpha = svm.alpha();
109     
110  // Comparing alpha to alpha_matlab
111  theplu::gslapi::vector diff_alpha(alpha);
112  diff_alpha-=alpha_matlab;
113  if (diff_alpha*diff_alpha> 1e-10 ){
114    *error << "Difference to matlab alphas too large\n";
115    ok=false;
116  }
117
118  // Comparing output to target
119  theplu::gslapi::vector output(svm.output());
120  double slack = 0;
121  for (unsigned int i=0; i<target.size(); i++){
122    if (output(i)*target(i) < 1){
123      if (target.binary(i))
124        slack += 1 - output(i);
125      else
126        slack += 1 + output(i);
127    }
128  }
129  double slack_bound=2e-7;
130  if (slack > slack_bound){
131    *error << "Slack too large. Is the bias correct?\n";
132    *error << "slack: " << slack << std::endl;
133    *error << "expected less than " << slack_bound << std::endl;
134    ok = false;
135  }
136 
137  delete kf;
138  delete kf2;
139
140  if (error!=&std::cerr)
141    delete error;
142
143  if(ok)
144    return 0;
145  return -1;
146 
147}
Note: See TracBrowser for help on using the repository browser.