source: trunk/test/svm_test.cc @ 1121

Last change on this file since 1121 was 1121, checked in by Peter, 16 years ago

fixes #308

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 4.5 KB
// $Id: svm_test.cc 1121 2008-02-22 15:29:56Z peter $

/*
  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
  Copyright (C) 2007 Peter Johansson

  This file is part of the yat library, http://trac.thep.lu.se/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 2 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
  02111-1307, USA.
*/

#include "yat/classifier/SVM.h"
#include "yat/classifier/Kernel.h"
#include "yat/classifier/KernelLookup.h"
#include "yat/classifier/Kernel_SEV.h"
#include "yat/classifier/Kernel_MEV.h"
#include "yat/classifier/PolynomialKernelFunction.h"
#include "yat/classifier/Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/Vector.h"

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

using namespace theplu::yat;

int main(const int argc, const char* argv[])
{

  std::ostream* error;
  if (argc>1 && argv[1]==std::string("-v"))
    error = &std::cerr;
  else {
    error = new std::ofstream("/dev/null");
    if (argc>1)
      std::cout << "svm_test -v : for printing extra information\n";
  }
  *error << "testing svm" << std::endl;
  bool ok = true;

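  // First test: a tiny hand-crafted problem with three samples (one
  // labelled "-1", two labelled "1") and a default polynomial kernel,
  // small enough that the trained alpha values can be checked against
  // the hard-coded expectations further down.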
  utility::Matrix data2_core(2,3);
  data2_core(0,0)=0;
  data2_core(1,0)=0;
  data2_core(0,1)=0;
  data2_core(1,1)=1;
  data2_core(0,2)=1;
  data2_core(1,2)=0;
  classifier::MatrixLookup data2(data2_core);
  std::vector<std::string> label;
  label.reserve(3);
  label.push_back("-1");
  label.push_back("1");
  label.push_back("1");
  classifier::Target target2(label);
  classifier::KernelFunction* kf2 = new classifier::PolynomialKernelFunction();
  classifier::Kernel_MEV kernel2(data2,*kf2);
  assert(kernel2.size()==3);
  assert(target2.size()==3);
  for (size_t i=0; i<3; i++){
    for (size_t j=0; j<3; j++)
      *error << kernel2(i,j) << " ";
    *error << std::endl;
  }
  classifier::KernelLookup kv2(kernel2);
  *error << "testing with linear kernel" << std::endl;
  assert(kv2.rows()==target2.size());
  classifier::SVM classifier2;
  *error << "training...";
  classifier2.train(kv2, target2);
  *error << " done!" << std::endl;

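  // The SVM dual problem is solved subject to the equality constraint
  // sum_i alpha_i * y_i = 0, so the signed sum of the trained alpha
  // values should vanish.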
  double tmp=0;
  for (size_t i=0; i<target2.size(); i++)
    if (target2.binary(i))
      tmp += classifier2.alpha()(i);
    else
      tmp -= classifier2.alpha()(i);

  if (tmp){
    *error << "ERROR: found " << tmp << " expected zero" << std::endl;
    return -1;
  }

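  // For this toy problem the expected alphas are 4, 2, and 2; only the two
  // positive-class values are checked explicitly, since the first follows
  // from the equality constraint above.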
  double tol=1e-6;
  if (std::fabs(classifier2.alpha()(1)-2)>tol ||
      std::fabs(classifier2.alpha()(2)-2)>tol){
    *error << "wrong alpha" << std::endl;
    *error << "alpha: " << classifier2.alpha() << std::endl;
    *error << "expected: 4 2 2" << std::endl;
    return -1;
  }

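  // Second test: train on a real data set and compare the resulting alpha
  // values with a reference solution computed in Matlab.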
  std::ifstream is("data/nm_data_centralized.txt");
  utility::Matrix data_core(is);
  is.close();

  classifier::MatrixLookup data(data_core);

  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction();
  classifier::Kernel_SEV kernel(data,*kf);

  is.open("data/nm_target_bin.txt");
  classifier::Target target(is);
  is.close();

  is.open("data/nm_alpha_linear_matlab.txt");
  theplu::yat::utility::Vector alpha_matlab(is);
  is.close();

  classifier::KernelLookup kv(kernel);
  theplu::yat::classifier::SVM svm;
  svm.train(kv, target);

  theplu::yat::utility::Vector alpha = svm.alpha();

  // Comparing alpha to alpha_matlab
  theplu::yat::utility::Vector diff_alpha(alpha);
  diff_alpha-=alpha_matlab;
  if (diff_alpha*diff_alpha > 1e-10){
    *error << "Difference to matlab alphas too large\n";
    ok=false;
  }

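  // A sample violates the margin when output*target < 1; the size of the
  // violation (the hinge loss) is accumulated as slack and should stay
  // below the small bound checked below.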
  // Comparing output to target
  theplu::yat::utility::Vector output(svm.output());
  double slack = 0;
  for (unsigned int i=0; i<target.size(); i++){
    if (output(i)*target(i) < 1){
      if (target.binary(i))
        slack += 1 - output(i);
      else
        slack += 1 + output(i);
    }
  }
  double slack_bound=2e-7;
  if (slack > slack_bound || std::isnan(slack)){
    *error << "Slack too large. Is the bias correct?\n";
    *error << "slack: " << slack << std::endl;
    *error << "expected less than " << slack_bound << std::endl;
    ok = false;
  }

  delete kf;
  delete kf2;

  if (error!=&std::cerr)
    delete error;

  if (ok)
    return 0;
  return -1;
}