source: trunk/test/svm_test.cc @ 323

Last change on this file since 323 was 323, checked in by Peter, 18 years ago

Major modifications in SVM::train() in an attempt to speed it up. The interface has changed: from now on, if validation is needed, it should be handled by the Kernel object.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.4 KB
Line 
1// $Id: svm_test.cc 323 2005-05-26 20:54:30Z peter $
2
#include <c++_tools/gslapi/matrix.h>
#include <c++_tools/gslapi/vector.h>
#include <c++_tools/svm/SVM.h>
#include <c++_tools/svm/Kernel_MEV.h>
#include <c++_tools/svm/Kernel_SEV.h>
#include <c++_tools/svm/PolynomialKernelFunction.h>

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits.h>
#include <string>

using namespace theplu;
15
16int main(const int argc,const char* argv[])
17{ 
18
19  bool print = (argc>1 && argv[1]==std::string("-p"));
20  bool ok = true;
21
22  gslapi::matrix data2(2,3);
23  data2(0,0)=0;
24  data2(1,0)=0;
25  data2(0,1)=0;
26  data2(1,1)=1;
27  data2(0,2)=1;
28  data2(1,2)=0;
29  gslapi::vector target2(3);
30  target2(0)=-1;
31  target2(1)=1;
32  target2(2)=1;
33  svm::KernelFunction* kf2 = new svm::PolynomialKernelFunction(); 
34  svm::Kernel_MEV kernel2(data2,*kf2);
35  assert(kernel2.size()==3);
36  svm::SVM svm2(kernel2, target2);
37  svm2.train();
38
39  if (svm2.alpha()*target2){
40    std::cerr << "condition not fullfilled" << std::endl;
41    return -1;
42  }
43
44  if (svm2.alpha()(1)!=2 || svm2.alpha()(2)!=2){
45    std::cerr << "wrong alpha" << std::endl;
46    std::cerr << "alpha: " << svm2.alpha() <<  std::endl;
47    std::cerr << "expected: 4 2 2" <<  std::endl;
48
49    return -1;
50  }
51
52 
53
54  std::ifstream is("data/nm_data_centralized.txt");
55  gslapi::matrix transposed_data(is);
56  is.close();
57  // Because how the kernel is treated is changed, data must be transposed.
58  gslapi::matrix data=transposed_data;
59
60  svm::KernelFunction* kf = new svm::PolynomialKernelFunction(); 
61  svm::Kernel_SEV kernel(data,*kf);
62
63
64  is.open("data/nm_target_bin.txt");
65  theplu::gslapi::vector target(is);
66  is.close();
67
68  is.open("data/nm_alpha_linear_matlab.txt");
69  theplu::gslapi::vector alpha_matlab(is);
70  is.close();
71
72  theplu::svm::SVM svm(kernel, target);
73  if (!svm.train()){
74    ok=false;
75    if (print)
76      std::cerr << "Training failured" << std::endl;
77  }
78
79  theplu::gslapi::vector alpha = svm.alpha();
80     
81  // Comparing alpha to alpha_matlab
82  theplu::gslapi::vector diff_alpha = alpha - alpha_matlab;
83  if (diff_alpha*diff_alpha> 1e-10 ){
84    if (print) 
85      std::cerr << "Difference to matlab alphas too large\n";
86    ok=false;
87  }
88
89  // Comparing output to target
90  theplu::gslapi::vector output = svm.output();
91  double slack = 0;
92  for (unsigned int i=0; i<target.size(); i++){
93    if (output[i]*target[i] < 1){
94      slack += 1 - output[i]*target[i];
95    }
96  }
97  if (slack > 1e-7){
98    if (print){
99      std::cerr << "Slack too large. Is the bias correct?\n";
100      std::cerr << "slack: " << slack << std::endl;
101    }
102    ok = false;
103  }
104 
105 
106  if(ok)
107    return 0;
108  return -1;
109 
110}
Note: See TracBrowser for help on using the repository browser.