source: trunk/test/svm_test.cc @ 329

Last change on this file since 329 was 329, checked in by Peter, 18 years ago

corrected includes

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.5 KB
Line 
1// $Id: svm_test.cc 329 2005-06-01 21:28:43Z peter $
2
#include <c++_tools/gslapi/matrix.h>
#include <c++_tools/gslapi/vector.h>
#include <c++_tools/svm/SVM.h>
#include <c++_tools/svm/Kernel_SEV.h>
#include <c++_tools/svm/Kernel_MEV.h>
#include <c++_tools/svm/PolynomialKernelFunction.h>

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits.h>
#include <string>
14
15using namespace theplu;
16
17int main(const int argc,const char* argv[])
18{ 
19
20  bool print = (argc>1 && argv[1]==std::string("-p"));
21  bool ok = true;
22
23  gslapi::matrix data2(2,3);
24  data2(0,0)=0;
25  data2(1,0)=0;
26  data2(0,1)=0;
27  data2(1,1)=1;
28  data2(0,2)=1;
29  data2(1,2)=0;
30  gslapi::vector target2(3);
31  target2(0)=-1;
32  target2(1)=1;
33  target2(2)=1;
34  svm::KernelFunction* kf2 = new svm::PolynomialKernelFunction(); 
35  svm::Kernel_MEV kernel2(data2,*kf2);
36  assert(kernel2.size()==3);
37  svm::SVM svm2(kernel2, target2);
38  svm2.train();
39
40  if (svm2.alpha()*target2){
41    std::cerr << "condition not fullfilled" << std::endl;
42    return -1;
43  }
44
45  if (svm2.alpha()(1)!=2 || svm2.alpha()(2)!=2){
46    std::cerr << "wrong alpha" << std::endl;
47    std::cerr << "alpha: " << svm2.alpha() <<  std::endl;
48    std::cerr << "expected: 4 2 2" <<  std::endl;
49
50    return -1;
51  }
52
53 
54
55  std::ifstream is("data/nm_data_centralized.txt");
56  gslapi::matrix transposed_data(is);
57  is.close();
58  // Because how the kernel is treated is changed, data must be transposed.
59  gslapi::matrix data=transposed_data;
60
61  svm::KernelFunction* kf = new svm::PolynomialKernelFunction(); 
62  svm::Kernel_SEV kernel(data,*kf);
63
64
65  is.open("data/nm_target_bin.txt");
66  theplu::gslapi::vector target(is);
67  is.close();
68
69  is.open("data/nm_alpha_linear_matlab.txt");
70  theplu::gslapi::vector alpha_matlab(is);
71  is.close();
72
73  theplu::svm::SVM svm(kernel, target);
74  if (!svm.train()){
75    ok=false;
76    if (print)
77      std::cerr << "Training failured" << std::endl;
78  }
79
80  theplu::gslapi::vector alpha = svm.alpha();
81     
82  // Comparing alpha to alpha_matlab
83  theplu::gslapi::vector diff_alpha = alpha - alpha_matlab;
84  if (diff_alpha*diff_alpha> 1e-10 ){
85    if (print) 
86      std::cerr << "Difference to matlab alphas too large\n";
87    ok=false;
88  }
89
90  // Comparing output to target
91  theplu::gslapi::vector output = svm.output();
92  double slack = 0;
93  for (unsigned int i=0; i<target.size(); i++){
94    if (output[i]*target[i] < 1){
95      slack += 1 - output[i]*target[i];
96    }
97  }
98  if (slack > 1e-7){
99    if (print){
100      std::cerr << "Slack too large. Is the bias correct?\n";
101      std::cerr << "slack: " << slack << std::endl;
102    }
103    ok = false;
104  }
105 
106 
107  if(ok)
108    return 0;
109  return -1;
110 
111}
Note: See TracBrowser for help on using the repository browser.