source: trunk/test/svm_test.cc @ 442

Last change on this file since 442 was 442, checked in by Jari Häkkinen, 17 years ago

Added include of needed header files.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.6 KB
Line 
1// $Id: svm_test.cc 442 2005-12-15 14:17:53Z jari $
2
#include <c++_tools/gslapi/matrix.h>
#include <c++_tools/gslapi/vector.h>
#include <c++_tools/svm/SVM.h>
#include <c++_tools/svm/Kernel.h>
#include <c++_tools/svm/Kernel_SEV.h>
#include <c++_tools/svm/Kernel_MEV.h>
#include <c++_tools/svm/PolynomialKernelFunction.h>

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
16
17using namespace theplu;
18
19int main(const int argc,const char* argv[])
20{ 
21
22  bool print = (argc>1 && argv[1]==std::string("-p"));
23  bool ok = true;
24
25  gslapi::matrix data2(2,3);
26  data2(0,0)=0;
27  data2(1,0)=0;
28  data2(0,1)=0;
29  data2(1,1)=1;
30  data2(0,2)=1;
31  data2(1,2)=0;
32  gslapi::vector target2(3);
33  target2(0)=-1;
34  target2(1)=1;
35  target2(2)=1;
36  svm::KernelFunction* kf2 = new svm::PolynomialKernelFunction(); 
37  svm::Kernel_MEV kernel2(data2,*kf2);
38  assert(kernel2.size()==3);
39  svm::SVM svm2(kernel2, target2);
40  svm2.train();
41
42  if (svm2.alpha()*target2){
43    std::cerr << "condition not fullfilled" << std::endl;
44    return -1;
45  }
46
47  if (svm2.alpha()(1)!=2 || svm2.alpha()(2)!=2){
48    std::cerr << "wrong alpha" << std::endl;
49    std::cerr << "alpha: " << svm2.alpha() <<  std::endl;
50    std::cerr << "expected: 4 2 2" <<  std::endl;
51
52    return -1;
53  }
54
55 
56
57  std::ifstream is("data/nm_data_centralized.txt");
58  gslapi::matrix transposed_data(is);
59  is.close();
60  // Because how the kernel is treated is changed, data must be transposed.
61  gslapi::matrix data=transposed_data;
62
63  svm::KernelFunction* kf = new svm::PolynomialKernelFunction(); 
64  svm::Kernel_SEV kernel(data,*kf);
65
66
67  is.open("data/nm_target_bin.txt");
68  theplu::gslapi::vector target(is);
69  is.close();
70
71  is.open("data/nm_alpha_linear_matlab.txt");
72  theplu::gslapi::vector alpha_matlab(is);
73  is.close();
74
75  theplu::svm::SVM svm(kernel, target);
76  if (!svm.train()){
77    ok=false;
78    if (print)
79      std::cerr << "Training failured" << std::endl;
80  }
81
82  theplu::gslapi::vector alpha = svm.alpha();
83     
84  // Comparing alpha to alpha_matlab
85  theplu::gslapi::vector diff_alpha(alpha);
86  diff_alpha-=alpha_matlab;
87  if (diff_alpha*diff_alpha> 1e-10 ){
88    if (print) 
89      std::cerr << "Difference to matlab alphas too large\n";
90    ok=false;
91  }
92
93  // Comparing output to target
94  theplu::gslapi::vector output(svm.output());
95  double slack = 0;
96  for (unsigned int i=0; i<target.size(); i++){
97    if (output[i]*target[i] < 1){
98      slack += 1 - output[i]*target[i];
99    }
100  }
101  double slack_bound=2e-7;
102  if (slack > slack_bound){
103    if (print){
104      std::cerr << "Slack too large. Is the bias correct?\n";
105      std::cerr << "slack: " << slack << std::endl;
106      std::cerr << "expected less than " << slack_bound << std::endl;
107    }
108    ok = false;
109  }
110 
111  delete kf;
112  delete kf2;
113
114  if(ok)
115    return 0;
116  return -1;
117 
118}
Note: See TracBrowser for help on using the repository browser.