source: trunk/test/svm_test.cc @ 337

Last change on this file since 337 was 337, checked in by Peter, 18 years ago

changed bound of slack

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.6 KB
Line 
1// $Id: svm_test.cc 337 2005-06-03 14:17:17Z peter $
2
#include <c++_tools/gslapi/matrix.h>
#include <c++_tools/gslapi/vector.h>
#include <c++_tools/svm/SVM.h>
#include <c++_tools/svm/Kernel.h>
#include <c++_tools/svm/Kernel_SEV.h>
#include <c++_tools/svm/Kernel_MEV.h>
#include <c++_tools/svm/PolynomialKernelFunction.h>

#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits.h>
#include <string>
15
16using namespace theplu;
17
18int main(const int argc,const char* argv[])
19{ 
20
21  bool print = (argc>1 && argv[1]==std::string("-p"));
22  bool ok = true;
23
24  gslapi::matrix data2(2,3);
25  data2(0,0)=0;
26  data2(1,0)=0;
27  data2(0,1)=0;
28  data2(1,1)=1;
29  data2(0,2)=1;
30  data2(1,2)=0;
31  gslapi::vector target2(3);
32  target2(0)=-1;
33  target2(1)=1;
34  target2(2)=1;
35  svm::KernelFunction* kf2 = new svm::PolynomialKernelFunction(); 
36  svm::Kernel_MEV kernel2(data2,*kf2);
37  assert(kernel2.size()==3);
38  svm::SVM svm2(kernel2, target2);
39  svm2.train();
40
41  if (svm2.alpha()*target2){
42    std::cerr << "condition not fullfilled" << std::endl;
43    return -1;
44  }
45
46  if (svm2.alpha()(1)!=2 || svm2.alpha()(2)!=2){
47    std::cerr << "wrong alpha" << std::endl;
48    std::cerr << "alpha: " << svm2.alpha() <<  std::endl;
49    std::cerr << "expected: 4 2 2" <<  std::endl;
50
51    return -1;
52  }
53
54 
55
56  std::ifstream is("data/nm_data_centralized.txt");
57  gslapi::matrix transposed_data(is);
58  is.close();
59  // Because how the kernel is treated is changed, data must be transposed.
60  gslapi::matrix data=transposed_data;
61
62  svm::KernelFunction* kf = new svm::PolynomialKernelFunction(); 
63  svm::Kernel_SEV kernel(data,*kf);
64
65
66  is.open("data/nm_target_bin.txt");
67  theplu::gslapi::vector target(is);
68  is.close();
69
70  is.open("data/nm_alpha_linear_matlab.txt");
71  theplu::gslapi::vector alpha_matlab(is);
72  is.close();
73
74  theplu::svm::SVM svm(kernel, target);
75  if (!svm.train()){
76    ok=false;
77    if (print)
78      std::cerr << "Training failured" << std::endl;
79  }
80
81  theplu::gslapi::vector alpha = svm.alpha();
82     
83  // Comparing alpha to alpha_matlab
84  theplu::gslapi::vector diff_alpha = alpha - alpha_matlab;
85  if (diff_alpha*diff_alpha> 1e-10 ){
86    if (print) 
87      std::cerr << "Difference to matlab alphas too large\n";
88    ok=false;
89  }
90
91  // Comparing output to target
92  theplu::gslapi::vector output = svm.output();
93  double slack = 0;
94  for (unsigned int i=0; i<target.size(); i++){
95    if (output[i]*target[i] < 1){
96      slack += 1 - output[i]*target[i];
97    }
98  }
99  double slack_bound=2e-7;
100  if (slack > slack_bound){
101    if (print){
102      std::cerr << "Slack too large. Is the bias correct?\n";
103      std::cerr << "slack: " << slack << std::endl;
104      std::cerr << "expected less than " << slack_bound << std::endl;
105    }
106    ok = false;
107  }
108 
109  delete kf;
110  delete kf2;
111
112  if(ok)
113    return 0;
114  return -1;
115 
116}
Note: See TracBrowser for help on using the repository browser.