source: trunk/test/svm_test.cc @ 1659

Last change on this file since 1659 was 1487, checked in by Jari Häkkinen, 13 years ago

Addresses #436. GPL license copy reference should also be updated.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 4.5 KB
Line 
// $Id: svm_test.cc 1487 2008-09-10 08:41:36Z jari $

/*
  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
  Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
  Copyright (C) 2008 Peter Johansson

  This file is part of the yat library, http://dev.thep.lu.se/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 3 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with yat. If not, see <http://www.gnu.org/licenses/>.
*/
24
#include "Suite.h"

#include "yat/classifier/SVM.h"
#include "yat/classifier/Kernel.h"
#include "yat/classifier/KernelLookup.h"
#include "yat/classifier/Kernel_SEV.h"
#include "yat/classifier/Kernel_MEV.h"
#include "yat/classifier/MatrixLookup.h"
#include "yat/classifier/PolynomialKernelFunction.h"
#include "yat/classifier/Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/Vector.h"

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

using namespace theplu::yat;

46int main( int argc, char* argv[])
47{ 
48  test::Suite suite(argc, argv);
49  suite.err() << "testing svm" << std::endl;
50
51  utility::Matrix data2_core(2,3);
52  data2_core(0,0)=0;
53  data2_core(1,0)=0;
54  data2_core(0,1)=0;
55  data2_core(1,1)=1;
56  data2_core(0,2)=1;
57  data2_core(1,2)=0;
58  classifier::MatrixLookup data2(data2_core);
59  std::vector<std::string> label;
60  label.reserve(3);
61  label.push_back("-1");
62  label.push_back("1");
63  label.push_back("1");
64  classifier::Target target2(label);
65  classifier::KernelFunction* kf2 = new classifier::PolynomialKernelFunction(); 
66  classifier::Kernel_MEV kernel2(data2,*kf2);
67  assert(kernel2.size()==3);
68  assert(target2.size()==3);
69  for (size_t i=0; i<3; i++){
70    for (size_t j=0; j<3; j++)
71      suite.err() << kernel2(i,j) << " ";
72    suite.err() << std::endl;
73  }
74  classifier::KernelLookup kv2(kernel2);
75  suite.err() << "testing with linear kernel" << std::endl;
76  assert(kv2.rows()==target2.size());
77  classifier::SVM classifier2;
78  suite.err() << "training...";
79  classifier2.train(kv2, target2);
80  suite.err() << " done!" << std::endl;
81
82  double tmp=0;
83  for (size_t i=0; i<target2.size(); i++) 
84    if (target2.binary(i))
85      tmp += classifier2.alpha()(i);
86    else
87      tmp -= classifier2.alpha()(i);
88
89  if (tmp){
90    suite.err() << "ERROR: found " << tmp << " expected zero" << std::endl;
91    return -1;
92  }
93
94  double tol=1e-6;
95  if (std::abs(classifier2.alpha()(1)-2)>tol || 
96      std::abs(classifier2.alpha()(2)-2)>tol){
97    suite.err() << "wrong alpha" << std::endl;
98    suite.err() << "alpha: " << classifier2.alpha() <<  std::endl;
99    suite.err() << "expected: 4 2 2" <<  std::endl;
100
101    return -1;
102  }
103
104 
105
106  std::ifstream is(test::filename("data/nm_data_centralized.txt").c_str());
107  utility::Matrix data_core(is);
108  is.close();
109
110  classifier::MatrixLookup data(data_core);
111
112  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
113  classifier::Kernel_SEV kernel(data,*kf);
114
115
116  is.open(test::filename("data/nm_target_bin.txt").c_str());
117  classifier::Target target(is);
118  is.close();
119
120  is.open(test::filename("data/nm_alpha_linear_matlab.txt").c_str());
121  theplu::yat::utility::Vector alpha_matlab(is);
122  is.close();
123
124  classifier::KernelLookup kv(kernel);
125  theplu::yat::classifier::SVM svm;
126  svm.train(kv, target);
127
128  theplu::yat::utility::Vector alpha = svm.alpha();
129     
130  // Comparing alpha to alpha_matlab
131  theplu::yat::utility::Vector diff_alpha(alpha);
132  diff_alpha-=alpha_matlab;
133  if (!(diff_alpha*diff_alpha<1e-10) ){
134    suite.err() << "Difference to matlab alphas too large\n";
135    suite.add(false);
136  }
137
138  // Comparing output to target
139  theplu::yat::utility::Vector output(svm.output());
140  double slack = 0;
141  for (unsigned int i=0; i<target.size(); i++){
142    if (output(i)*target(i) < 1){
143      if (target.binary(i))
144        slack += 1 - output(i);
145      else
146        slack += 1 + output(i);
147    }
148  }
149  double slack_bound=2e-7;
150  if (slack > slack_bound || std::isnan(slack)){
151    suite.err() << "Slack too large. Is the bias correct?\n";
152    suite.err() << "slack: " << slack << std::endl;
153    suite.err() << "expected less than " << slack_bound << std::endl;
154    suite.add(false);
155  }
156 
157  delete kf;
158  delete kf2;
159
160  return suite.return_value();
161}
Note: See TracBrowser for help on using the repository browser.