source: trunk/test/svm_test.cc @ 616

Last change on this file since 616 was 616, checked in by Jari Häkkinen, 15 years ago

Removed gslapi namespace and put the code into utility namespace.
Moved #ifndef _header_ idiom to top of touched header files.
Removed unnecessary #includes, and added needed #includes.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 3.7 KB
Line 
1// $Id: svm_test.cc 616 2006-08-31 08:52:02Z jari $
2
#include <c++_tools/classifier/SVM.h>
#include <c++_tools/classifier/Kernel.h>
#include <c++_tools/classifier/KernelLookup.h>
#include <c++_tools/classifier/Kernel_SEV.h>
#include <c++_tools/classifier/Kernel_MEV.h>
#include <c++_tools/classifier/PolynomialKernelFunction.h>
#include <c++_tools/utility/matrix.h>
#include <c++_tools/utility/vector.h>

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
17
18using namespace theplu;
19
20int main(const int argc,const char* argv[])
21{ 
22
23  std::ostream* error;
24  if (argc>1 && argv[1]==std::string("-v"))
25    error = &std::cerr;
26  else {
27    error = new std::ofstream("/dev/null");
28    if (argc>1)
29      std::cout << "svm_test -v : for printing extra information\n";
30  }
31  *error << "testing svm" << std::endl;
32  bool ok = true;
33
34  utility::matrix data2_core(2,3);
35  data2_core(0,0)=0;
36  data2_core(1,0)=0;
37  data2_core(0,1)=0;
38  data2_core(1,1)=1;
39  data2_core(0,2)=1;
40  data2_core(1,2)=0;
41  classifier::MatrixLookup data2(data2_core);
42  std::vector<std::string> label;
43  label.reserve(3);
44  label.push_back("-1");
45  label.push_back("1");
46  label.push_back("1");
47  classifier::Target target2(label);
48  classifier::KernelFunction* kf2 = new classifier::PolynomialKernelFunction(); 
49  classifier::Kernel_MEV kernel2(data2,*kf2);
50  assert(kernel2.size()==3);
51  assert(target2.size()==3);
52  for (size_t i=0; i<3; i++){
53    for (size_t j=0; j<3; j++)
54      *error << kernel2(i,j) << " ";
55    *error << std::endl;
56  }
57  classifier::KernelLookup kv2(kernel2);
58  *error << "testing with linear kernel" << std::endl;
59  assert(kv2.rows()==target2.size());
60  classifier::SVM classifier2(kv2, target2);
61  *error << "training...";
62  classifier2.train();
63  *error << " done!" << std::endl;
64
65  double tmp=0;
66  for (size_t i=0; i<target2.size(); i++) 
67    if (target2.binary(i))
68      tmp += classifier2.alpha()(i);
69    else
70      tmp -= classifier2.alpha()(i);
71
72  if (tmp){
73    *error << "ERROR: found " << tmp << " expected zero" << std::endl;
74    return -1;
75  }
76
77  double tol=1e-6;
78  if (fabs(classifier2.alpha()(1)-2)>tol || 
79      fabs(classifier2.alpha()(2)-2)>tol){
80    *error << "wrong alpha" << std::endl;
81    *error << "alpha: " << classifier2.alpha() <<  std::endl;
82    *error << "expected: 4 2 2" <<  std::endl;
83
84    return -1;
85  }
86
87 
88
89  std::ifstream is("data/nm_data_centralized.txt");
90  utility::matrix data_core(is);
91  is.close();
92
93  classifier::MatrixLookup data(data_core);
94
95  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
96  classifier::Kernel_SEV kernel(data,*kf);
97
98
99  is.open("data/nm_target_bin.txt");
100  classifier::Target target(is);
101  is.close();
102
103  is.open("data/nm_alpha_linear_matlab.txt");
104  theplu::utility::vector alpha_matlab(is);
105  is.close();
106
107  classifier::KernelLookup kv(kernel);
108  theplu::classifier::SVM svm(kv, target);
109  if (!svm.train()){
110    ok=false;
111    *error << "Training failured" << std::endl;
112  }
113
114  theplu::utility::vector alpha = svm.alpha();
115     
116  // Comparing alpha to alpha_matlab
117  theplu::utility::vector diff_alpha(alpha);
118  diff_alpha-=alpha_matlab;
119  if (diff_alpha*diff_alpha> 1e-10 ){
120    *error << "Difference to matlab alphas too large\n";
121    ok=false;
122  }
123
124  // Comparing output to target
125  theplu::utility::vector output(svm.output());
126  double slack = 0;
127  for (unsigned int i=0; i<target.size(); i++){
128    if (output(i)*target(i) < 1){
129      if (target.binary(i))
130        slack += 1 - output(i);
131      else
132        slack += 1 + output(i);
133    }
134  }
135  double slack_bound=2e-7;
136  if (slack > slack_bound){
137    *error << "Slack too large. Is the bias correct?\n";
138    *error << "slack: " << slack << std::endl;
139    *error << "expected less than " << slack_bound << std::endl;
140    ok = false;
141  }
142 
143  delete kf;
144  delete kf2;
145
146  if (error!=&std::cerr)
147    delete error;
148
149  if(ok)
150    return 0;
151  return -1;
152 
153}
Note: See TracBrowser for help on using the repository browser.