1 | // $Id: svm_test.cc 1672 2008-12-22 13:27:20Z peter $ |
---|
2 | |
---|
3 | /* |
---|
4 | Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson |
---|
5 | Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér |
---|
6 | Copyright (C) 2007 Jari Häkkinen, Peter Johansson |
---|
7 | Copyright (C) 2008 Peter Johansson |
---|
8 | |
---|
9 | This file is part of the yat library, http://dev.thep.lu.se/yat |
---|
10 | |
---|
11 | The yat library is free software; you can redistribute it and/or |
---|
12 | modify it under the terms of the GNU General Public License as |
---|
13 | published by the Free Software Foundation; either version 3 of the |
---|
14 | License, or (at your option) any later version. |
---|
15 | |
---|
16 | The yat library is distributed in the hope that it will be useful, |
---|
17 | but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
18 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU |
---|
19 | General Public License for more details. |
---|
20 | |
---|
21 | You should have received a copy of the GNU General Public License |
---|
22 | along with yat. If not, see <http://www.gnu.org/licenses/>. |
---|
23 | */ |
---|
24 | |
---|
#include "Suite.h"

#include "yat/classifier/SVM.h"
#include "yat/classifier/Kernel.h"
#include "yat/classifier/KernelLookup.h"
#include "yat/classifier/Kernel_SEV.h"
#include "yat/classifier/Kernel_MEV.h"
#include "yat/classifier/MatrixLookup.h"
#include "yat/classifier/PolynomialKernelFunction.h"
#include "yat/classifier/Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/Vector.h"

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
---|
43 | |
---|
44 | using namespace theplu::yat; |
---|
45 | |
---|
46 | int main( int argc, char* argv[]) |
---|
47 | { |
---|
48 | test::Suite suite(argc, argv); |
---|
49 | suite.err() << "testing svm" << std::endl; |
---|
50 | |
---|
51 | utility::Matrix data2_core(2,3); |
---|
52 | data2_core(0,0)=0; |
---|
53 | data2_core(1,0)=0; |
---|
54 | data2_core(0,1)=0; |
---|
55 | data2_core(1,1)=1; |
---|
56 | data2_core(0,2)=1; |
---|
57 | data2_core(1,2)=0; |
---|
58 | classifier::MatrixLookup data2(data2_core); |
---|
59 | std::vector<std::string> label; |
---|
60 | label.reserve(3); |
---|
61 | label.push_back("-1"); |
---|
62 | label.push_back("1"); |
---|
63 | label.push_back("1"); |
---|
64 | classifier::Target target2(label); |
---|
65 | classifier::KernelFunction* kf2 = new classifier::PolynomialKernelFunction(); |
---|
66 | classifier::Kernel_MEV kernel2(data2,*kf2); |
---|
67 | assert(kernel2.size()==3); |
---|
68 | assert(target2.size()==3); |
---|
69 | for (size_t i=0; i<3; i++){ |
---|
70 | for (size_t j=0; j<3; j++) |
---|
71 | suite.err() << kernel2(i,j) << " "; |
---|
72 | suite.err() << std::endl; |
---|
73 | } |
---|
74 | classifier::KernelLookup kv2(kernel2); |
---|
75 | suite.err() << "testing with linear kernel" << std::endl; |
---|
76 | assert(kv2.rows()==target2.size()); |
---|
77 | classifier::SVM classifier2; |
---|
78 | suite.err() << "training..."; |
---|
79 | classifier2.train(kv2, target2); |
---|
80 | suite.err() << " done!" << std::endl; |
---|
81 | |
---|
82 | double tmp=0; |
---|
83 | for (size_t i=0; i<target2.size(); i++) |
---|
84 | if (target2.binary(i)) |
---|
85 | tmp += classifier2.alpha()(i); |
---|
86 | else |
---|
87 | tmp -= classifier2.alpha()(i); |
---|
88 | |
---|
89 | if (tmp){ |
---|
90 | suite.err() << "ERROR: found " << tmp << " expected zero" << std::endl; |
---|
91 | return -1; |
---|
92 | } |
---|
93 | |
---|
94 | // tol defined on learning precision |
---|
95 | double tol=1e-6; |
---|
96 | if (!suite.equal_fix(classifier2.alpha()(1), 2.0, tol) || |
---|
97 | !suite.equal_fix(classifier2.alpha()(2), 2.0, tol) ) { |
---|
98 | suite.err() << "wrong alpha" << std::endl; |
---|
99 | suite.err() << "alpha: " << classifier2.alpha() << std::endl; |
---|
100 | suite.err() << "expected: 4 2 2" << std::endl; |
---|
101 | |
---|
102 | return -1; |
---|
103 | } |
---|
104 | |
---|
105 | |
---|
106 | |
---|
107 | std::ifstream is(test::filename("data/nm_data_centralized.txt").c_str()); |
---|
108 | utility::Matrix data_core(is); |
---|
109 | is.close(); |
---|
110 | |
---|
111 | classifier::MatrixLookup data(data_core); |
---|
112 | |
---|
113 | classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); |
---|
114 | classifier::Kernel_SEV kernel(data,*kf); |
---|
115 | |
---|
116 | |
---|
117 | is.open(test::filename("data/nm_target_bin.txt").c_str()); |
---|
118 | classifier::Target target(is); |
---|
119 | is.close(); |
---|
120 | |
---|
121 | is.open(test::filename("data/nm_alpha_linear_matlab.txt").c_str()); |
---|
122 | theplu::yat::utility::Vector alpha_matlab(is); |
---|
123 | is.close(); |
---|
124 | |
---|
125 | classifier::KernelLookup kv(kernel); |
---|
126 | theplu::yat::classifier::SVM svm; |
---|
127 | svm.train(kv, target); |
---|
128 | |
---|
129 | theplu::yat::utility::Vector alpha = svm.alpha(); |
---|
130 | |
---|
131 | // Comparing alpha to alpha_matlab |
---|
132 | if (!suite.equal_range_fix(alpha.begin(), alpha.end(), |
---|
133 | alpha_matlab.begin(), 1e-6) ) { |
---|
134 | suite.err() << "Difference to matlab alphas too large\n"; |
---|
135 | suite.add(false); |
---|
136 | } |
---|
137 | |
---|
138 | |
---|
139 | // Comparing output to target |
---|
140 | theplu::yat::utility::Vector output(svm.output()); |
---|
141 | double slack = 0; |
---|
142 | for (unsigned int i=0; i<target.size(); i++){ |
---|
143 | if (output(i)*target(i) < 1){ |
---|
144 | if (target.binary(i)) |
---|
145 | slack = 1 - output(i); |
---|
146 | else |
---|
147 | slack = 1 + output(i); |
---|
148 | double slack_bound=2e-7; |
---|
149 | if (slack > slack_bound || std::isnan(slack)){ |
---|
150 | suite.err() << "Slack too large. Is the bias correct?\n"; |
---|
151 | suite.err() << "slack: " << slack << std::endl; |
---|
152 | suite.err() << "expected less than " << slack_bound << std::endl; |
---|
153 | suite.add(false); |
---|
154 | } |
---|
155 | } |
---|
156 | } |
---|
157 | |
---|
158 | delete kf; |
---|
159 | delete kf2; |
---|
160 | |
---|
161 | return suite.return_value(); |
---|
162 | } |
---|