source: trunk/test/svm_test.cc @ 1248

Last change on this file since 1248 was 1248, checked in by Peter, 14 years ago

fixes tests - ticket:223

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 4.4 KB
Line 
1// $Id: svm_test.cc 1248 2008-03-19 23:07:30Z peter $
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "Suite.h"
27
28#include "yat/classifier/SVM.h"
29#include "yat/classifier/Kernel.h"
30#include "yat/classifier/KernelLookup.h"
31#include "yat/classifier/Kernel_SEV.h"
32#include "yat/classifier/Kernel_MEV.h"
33#include "yat/classifier/MatrixLookup.h"
34#include "yat/classifier/PolynomialKernelFunction.h"
35#include "yat/classifier/Target.h"
36#include "yat/utility/Matrix.h"
37#include "yat/utility/Vector.h"
38
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
44
45using namespace theplu::yat;
46
47int main( int argc, char* argv[])
48{ 
49  test::Suite suite(argc, argv);
50  suite.err() << "testing svm" << std::endl;
51
52  utility::Matrix data2_core(2,3);
53  data2_core(0,0)=0;
54  data2_core(1,0)=0;
55  data2_core(0,1)=0;
56  data2_core(1,1)=1;
57  data2_core(0,2)=1;
58  data2_core(1,2)=0;
59  classifier::MatrixLookup data2(data2_core);
60  std::vector<std::string> label;
61  label.reserve(3);
62  label.push_back("-1");
63  label.push_back("1");
64  label.push_back("1");
65  classifier::Target target2(label);
66  classifier::KernelFunction* kf2 = new classifier::PolynomialKernelFunction(); 
67  classifier::Kernel_MEV kernel2(data2,*kf2);
68  assert(kernel2.size()==3);
69  assert(target2.size()==3);
70  for (size_t i=0; i<3; i++){
71    for (size_t j=0; j<3; j++)
72      suite.err() << kernel2(i,j) << " ";
73    suite.err() << std::endl;
74  }
75  classifier::KernelLookup kv2(kernel2);
76  suite.err() << "testing with linear kernel" << std::endl;
77  assert(kv2.rows()==target2.size());
78  classifier::SVM classifier2;
79  suite.err() << "training...";
80  classifier2.train(kv2, target2);
81  suite.err() << " done!" << std::endl;
82
83  double tmp=0;
84  for (size_t i=0; i<target2.size(); i++) 
85    if (target2.binary(i))
86      tmp += classifier2.alpha()(i);
87    else
88      tmp -= classifier2.alpha()(i);
89
90  if (tmp){
91    suite.err() << "ERROR: found " << tmp << " expected zero" << std::endl;
92    return -1;
93  }
94
95  double tol=1e-6;
96  if (std::abs(classifier2.alpha()(1)-2)>tol || 
97      std::abs(classifier2.alpha()(2)-2)>tol){
98    suite.err() << "wrong alpha" << std::endl;
99    suite.err() << "alpha: " << classifier2.alpha() <<  std::endl;
100    suite.err() << "expected: 4 2 2" <<  std::endl;
101
102    return -1;
103  }
104
105 
106
107  std::ifstream is("data/nm_data_centralized.txt");
108  utility::Matrix data_core(is);
109  is.close();
110
111  classifier::MatrixLookup data(data_core);
112
113  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
114  classifier::Kernel_SEV kernel(data,*kf);
115
116
117  is.open("data/nm_target_bin.txt");
118  classifier::Target target(is);
119  is.close();
120
121  is.open("data/nm_alpha_linear_matlab.txt");
122  theplu::yat::utility::Vector alpha_matlab(is);
123  is.close();
124
125  classifier::KernelLookup kv(kernel);
126  theplu::yat::classifier::SVM svm;
127  svm.train(kv, target);
128
129  theplu::yat::utility::Vector alpha = svm.alpha();
130     
131  // Comparing alpha to alpha_matlab
132  theplu::yat::utility::Vector diff_alpha(alpha);
133  diff_alpha-=alpha_matlab;
134  if (!(diff_alpha*diff_alpha<1e-10) ){
135    suite.err() << "Difference to matlab alphas too large\n";
136    suite.add(false);
137  }
138
139  // Comparing output to target
140  theplu::yat::utility::Vector output(svm.output());
141  double slack = 0;
142  for (unsigned int i=0; i<target.size(); i++){
143    if (output(i)*target(i) < 1){
144      if (target.binary(i))
145        slack += 1 - output(i);
146      else
147        slack += 1 + output(i);
148    }
149  }
150  double slack_bound=2e-7;
151  if (slack > slack_bound || std::isnan(slack)){
152    suite.err() << "Slack too large. Is the bias correct?\n";
153    suite.err() << "slack: " << slack << std::endl;
154    suite.err() << "expected less than " << slack_bound << std::endl;
155    suite.add(false);
156  }
157 
158  delete kf;
159  delete kf2;
160
161  return suite.return_value();
162}
Note: See TracBrowser for help on using the repository browser.