source: trunk/test/svm_test.cc @ 1210

Last change on this file since 1210 was 1210, checked in by Peter, 14 years ago

refs #223 change fabs to std::abs

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 4.6 KB
Line 
1// $Id: svm_test.cc 1210 2008-03-06 16:00:44Z peter $
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "yat/classifier/SVM.h"
27#include "yat/classifier/Kernel.h"
28#include "yat/classifier/KernelLookup.h"
29#include "yat/classifier/Kernel_SEV.h"
30#include "yat/classifier/Kernel_MEV.h"
31#include "yat/classifier/MatrixLookup.h"
32#include "yat/classifier/PolynomialKernelFunction.h"
33#include "yat/classifier/Target.h"
34#include "yat/utility/Matrix.h"
35#include "yat/utility/Vector.h"
36
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <ostream>
#include <string>
#include <vector>
42
43using namespace theplu::yat;
44
45int main(const int argc,const char* argv[])
46{ 
47
48  std::ostream* error;
49  if (argc>1 && argv[1]==std::string("-v"))
50    error = &std::cerr;
51  else {
52    error = new std::ofstream("/dev/null");
53    if (argc>1)
54      std::cout << "svm_test -v : for printing extra information\n";
55  }
56  *error << "testing svm" << std::endl;
57  bool ok = true;
58
59  utility::Matrix data2_core(2,3);
60  data2_core(0,0)=0;
61  data2_core(1,0)=0;
62  data2_core(0,1)=0;
63  data2_core(1,1)=1;
64  data2_core(0,2)=1;
65  data2_core(1,2)=0;
66  classifier::MatrixLookup data2(data2_core);
67  std::vector<std::string> label;
68  label.reserve(3);
69  label.push_back("-1");
70  label.push_back("1");
71  label.push_back("1");
72  classifier::Target target2(label);
73  classifier::KernelFunction* kf2 = new classifier::PolynomialKernelFunction(); 
74  classifier::Kernel_MEV kernel2(data2,*kf2);
75  assert(kernel2.size()==3);
76  assert(target2.size()==3);
77  for (size_t i=0; i<3; i++){
78    for (size_t j=0; j<3; j++)
79      *error << kernel2(i,j) << " ";
80    *error << std::endl;
81  }
82  classifier::KernelLookup kv2(kernel2);
83  *error << "testing with linear kernel" << std::endl;
84  assert(kv2.rows()==target2.size());
85  classifier::SVM classifier2;
86  *error << "training...";
87  classifier2.train(kv2, target2);
88  *error << " done!" << std::endl;
89
90  double tmp=0;
91  for (size_t i=0; i<target2.size(); i++) 
92    if (target2.binary(i))
93      tmp += classifier2.alpha()(i);
94    else
95      tmp -= classifier2.alpha()(i);
96
97  if (tmp){
98    *error << "ERROR: found " << tmp << " expected zero" << std::endl;
99    return -1;
100  }
101
102  double tol=1e-6;
103  if (std::abs(classifier2.alpha()(1)-2)>tol || 
104      std::abs(classifier2.alpha()(2)-2)>tol){
105    *error << "wrong alpha" << std::endl;
106    *error << "alpha: " << classifier2.alpha() <<  std::endl;
107    *error << "expected: 4 2 2" <<  std::endl;
108
109    return -1;
110  }
111
112 
113
114  std::ifstream is("data/nm_data_centralized.txt");
115  utility::Matrix data_core(is);
116  is.close();
117
118  classifier::MatrixLookup data(data_core);
119
120  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
121  classifier::Kernel_SEV kernel(data,*kf);
122
123
124  is.open("data/nm_target_bin.txt");
125  classifier::Target target(is);
126  is.close();
127
128  is.open("data/nm_alpha_linear_matlab.txt");
129  theplu::yat::utility::Vector alpha_matlab(is);
130  is.close();
131
132  classifier::KernelLookup kv(kernel);
133  theplu::yat::classifier::SVM svm;
134  svm.train(kv, target);
135
136  theplu::yat::utility::Vector alpha = svm.alpha();
137     
138  // Comparing alpha to alpha_matlab
139  theplu::yat::utility::Vector diff_alpha(alpha);
140  diff_alpha-=alpha_matlab;
141  if (diff_alpha*diff_alpha> 1e-10 ){
142    *error << "Difference to matlab alphas too large\n";
143    ok=false;
144  }
145
146  // Comparing output to target
147  theplu::yat::utility::Vector output(svm.output());
148  double slack = 0;
149  for (unsigned int i=0; i<target.size(); i++){
150    if (output(i)*target(i) < 1){
151      if (target.binary(i))
152        slack += 1 - output(i);
153      else
154        slack += 1 + output(i);
155    }
156  }
157  double slack_bound=2e-7;
158  if (slack > slack_bound || std::isnan(slack)){
159    *error << "Slack too large. Is the bias correct?\n";
160    *error << "slack: " << slack << std::endl;
161    *error << "expected less than " << slack_bound << std::endl;
162    ok = false;
163  }
164 
165  delete kf;
166  delete kf2;
167
168  if (error!=&std::cerr)
169    delete error;
170
171  if(ok)
172    return 0;
173  return -1;
174 
175}
Note: See TracBrowser for help on using the repository browser.