source: trunk/test/svm_test.cc @ 475

Last change on this file since 475 was 475, checked in by Peter, 16 years ago

I don't know what happened, but everything is changed...

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 3.3 KB
Line 
1// $Id: svm_test.cc 475 2005-12-22 15:03:51Z peter $
2
3#include <c++_tools/gslapi/matrix.h>
4#include <c++_tools/gslapi/vector.h>
5#include <c++_tools/classifier/SVM.h>
6#include <c++_tools/classifier/Kernel.h>
7#include <c++_tools/classifier/KernelLookup.h>
8#include <c++_tools/classifier/Kernel_SEV.h>
9#include <c++_tools/classifier/Kernel_MEV.h>
10#include <c++_tools/classifier/PolynomialKernelFunction.h>
11
12#include <cassert>
13#include <fstream>
14#include <iostream>
15#include <cstdlib>
16#include <limits>
17
18using namespace theplu;
19
20int main(const int argc,const char* argv[])
21{ 
22
23  std::ostream* error;
24  if (argc>1 && argv[1]==std::string("-v"))
25    error = &std::cerr;
26  else {
27    error = new std::ofstream("/dev/null");
28    if (argc>1)
29      std::cout << "svm_test -v : for printing extra information\n";
30  }
31  *error << "testing svm" << std::endl;
32  bool ok = true;
33
34  gslapi::matrix data2(2,3);
35  data2(0,0)=0;
36  data2(1,0)=0;
37  data2(0,1)=0;
38  data2(1,1)=1;
39  data2(0,2)=1;
40  data2(1,2)=0;
41  classifier::Target target2(3);
42  target2(0)=-1;
43  target2(1)=1;
44  target2(2)=1;
45  classifier::KernelFunction* kf2 = new classifier::PolynomialKernelFunction(); 
46  classifier::Kernel_MEV kernel2(data2,*kf2);
47  assert(kernel2.size()==3);
48  assert(target2.size()==3);
49  classifier::KernelLookup kv2(kernel2);
50  *error << "testing with linear kernel" << std::endl;
51  assert(kv2.rows()==target2.size());
52  classifier::SVM classifier2(kv2, target2);
53  *error << "training...";
54  classifier2.train();
55  *error << " done." << std::endl;
56
57  double tmp=0;
58  for (size_t i=0; i<target2.size(); i++) 
59    tmp += classifier2.alpha()(i)*target2(i);
60 
61  if (tmp){
62    *error << "condition not fullfilled" << std::endl;
63    return -1;
64  }
65
66  if (classifier2.alpha()(1)!=2 || classifier2.alpha()(2)!=2){
67    *error << "wrong alpha" << std::endl;
68    *error << "alpha: " << classifier2.alpha() <<  std::endl;
69    *error << "expected: 4 2 2" <<  std::endl;
70
71    return -1;
72  }
73
74 
75
76  std::ifstream is("data/nm_data_centralized.txt");
77  gslapi::matrix transposed_data(is);
78  is.close();
79  // Because how the kernel is treated is changed, data must be transposed.
80  gslapi::matrix data=transposed_data;
81
82  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
83  classifier::Kernel_SEV kernel(data,*kf);
84
85
86  is.open("data/nm_target_bin.txt");
87  classifier::Target target(is);
88  is.close();
89
90  is.open("data/nm_alpha_linear_matlab.txt");
91  theplu::gslapi::vector alpha_matlab(is);
92  is.close();
93
94  classifier::KernelLookup kv(kernel);
95  theplu::classifier::SVM svm(kv, target);
96  if (!svm.train()){
97    ok=false;
98    *error << "Training failured" << std::endl;
99  }
100
101  theplu::gslapi::vector alpha = svm.alpha();
102     
103  // Comparing alpha to alpha_matlab
104  theplu::gslapi::vector diff_alpha(alpha);
105  diff_alpha-=alpha_matlab;
106  if (diff_alpha*diff_alpha> 1e-10 ){
107    *error << "Difference to matlab alphas too large\n";
108    ok=false;
109  }
110
111  // Comparing output to target
112  theplu::gslapi::vector output(svm.output());
113  double slack = 0;
114  for (unsigned int i=0; i<target.size(); i++){
115    if (output(i)*target(i) < 1){
116      slack += 1 - output(i)*target(i);
117    }
118  }
119  double slack_bound=2e-7;
120  if (slack > slack_bound){
121    *error << "Slack too large. Is the bias correct?\n";
122    *error << "slack: " << slack << std::endl;
123    *error << "expected less than " << slack_bound << std::endl;
124    ok = false;
125  }
126 
127  delete kf;
128  delete kf2;
129
130  if (error!=&std::cerr)
131    delete error;
132
133  if(ok)
134    return 0;
135  return -1;
136 
137}
Note: See TracBrowser for help on using the repository browser.