1 | // $Id: SVM.cc 544 2006-03-05 17:33:33Z peter $ |
---|
2 | |
---|
3 | #include <c++_tools/classifier/SVM.h> |
---|
4 | |
---|
5 | #include <c++_tools/classifier/DataLookup2D.h> |
---|
6 | #include <c++_tools/gslapi/matrix.h> |
---|
7 | #include <c++_tools/gslapi/vector.h> |
---|
8 | #include <c++_tools/statistics/Averager.h> |
---|
9 | #include <c++_tools/statistics/Score.h> |
---|
10 | #include <c++_tools/random/random.h> |
---|
11 | |
---|
12 | #include <algorithm> |
---|
13 | #include <cassert> |
---|
14 | #include <cmath> |
---|
15 | #include <limits> |
---|
16 | #include <utility> |
---|
17 | #include <vector> |
---|
18 | |
---|
19 | |
---|
20 | namespace theplu { |
---|
21 | namespace classifier { |
---|
22 | |
---|
// Construct an untrained SVM for the given kernel and target labels.
// The kernel is only referenced, not copied; it must outlive this object.
SVM::SVM(const KernelLookup& kernel, const Target& target)
  : SupervisedClassifier(target),
    alpha_(target.size(),0),   // one Lagrange multiplier per training sample
    bias_(0),
    C_inverse_(0),             // used as alpha/C in calculate_bias(); 0 disables the term
    kernel_(&kernel),          // non-owning pointer
    max_epochs_(10000000),     // safety cap on iterations of the training loop
    output_(target.size(),0),  // cached decision values, filled by calculate_bias()
    sample_(target.size()),    // index bookkeeping: which samples are support vectors
    trained_(false),
    tolerance_(0.00000001)     // numerical threshold for "alpha is zero" / KKT violation
{
}
---|
36 | |
---|
// As above, but additionally pass a score and a number of inputs to the
// SupervisedClassifier base — presumably for feature (input) selection;
// NOTE(review): the base class is not visible here, confirm its use of
// score/nof_inputs.  All optimization state is initialized exactly as in
// the two-argument constructor.
SVM::SVM(const KernelLookup& kernel, const Target& target,
         statistics::Score& score, const size_t nof_inputs)
  : SupervisedClassifier(target, &score, nof_inputs),
    alpha_(target.size(),0),   // one Lagrange multiplier per training sample
    bias_(0),
    C_inverse_(0),             // used as alpha/C in calculate_bias(); 0 disables the term
    kernel_(&kernel),          // non-owning pointer
    max_epochs_(10000000),     // safety cap on iterations of the training loop
    output_(target.size(),0),
    sample_(target.size()),
    trained_(false),
    tolerance_(0.00000001)     // numerical threshold for "alpha is zero" / KKT violation
{
}
---|
51 | |
---|
52 | |
---|
53 | |
---|
54 | SupervisedClassifier* SVM::make_classifier(const DataLookup2D& data, |
---|
55 | const Target& target) const |
---|
56 | { |
---|
57 | // Peter, should check success of dynamic_cast |
---|
58 | const KernelLookup& tmp = dynamic_cast<const KernelLookup&>(data); |
---|
59 | SVM* sc; |
---|
60 | if (score_) |
---|
61 | sc = new SVM(tmp,target,*score_,nof_inputs_); |
---|
62 | else |
---|
63 | sc = new SVM(tmp,target); |
---|
64 | |
---|
65 | //Copy those variables possible to modify from outside |
---|
66 | return sc; |
---|
67 | } |
---|
68 | |
---|
69 | void SVM::predict(const DataLookup2D& input, gslapi::matrix& prediction) const |
---|
70 | { |
---|
71 | assert(input.rows()==alpha_.size()); |
---|
72 | prediction=gslapi::matrix(2,input.columns(),0); |
---|
73 | for (size_t i = 0; i<input.columns(); i++){ |
---|
74 | for (size_t j = 0; i<input.rows(); i++) |
---|
75 | prediction(0,i) += target(j)*alpha_(i)*input(j,i); |
---|
76 | prediction(0,i) += bias_; |
---|
77 | } |
---|
78 | |
---|
79 | for (size_t i = 0; i<prediction.columns(); i++) |
---|
80 | prediction(1,i) = -prediction(0,i); |
---|
81 | } |
---|
82 | |
---|
83 | double SVM::predict(const DataLookup1D& x) const |
---|
84 | { |
---|
85 | double y=0; |
---|
86 | for (size_t i=0; i<alpha_.size(); i++) |
---|
87 | y += alpha_(i)*target_(i)*kernel_->element(x,i); |
---|
88 | |
---|
89 | return y+bias_; |
---|
90 | } |
---|
91 | |
---|
92 | double SVM::predict(const DataLookup1D& x, const DataLookup1D& w) const |
---|
93 | { |
---|
94 | double y=0; |
---|
95 | for (size_t i=0; i<alpha_.size(); i++) |
---|
96 | y += alpha_(i)*target_(i)*kernel_->element(x,w,i); |
---|
97 | |
---|
98 | return y+bias_; |
---|
99 | } |
---|
100 | |
---|
// Optimize the dual SVM problem with an SMO-style algorithm: repeatedly
// pick a pair of samples violating the stopping criterion (choose()),
// solve the resulting two-variable sub-problem analytically, clip to the
// feasible box (bounds()), and incrementally maintain the error cache E.
// Returns true when training converged and the bias could be calculated;
// false if the epoch cap is hit or no support vector exists.
bool SVM::train(void)
{
  // initializing variables for optimization
  assert(target_.size()==kernel_->rows());
  assert(target_.size()==alpha_.size());

  sample_.init(alpha_,tolerance_);
  // E(i) = (bias-free decision value for sample i) - target(i),
  // i.e. the prediction error used for working-pair selection.
  gslapi::vector E(target_.size(),0);
  for (size_t i=0; i<E.size(); i++) {
    for (size_t j=0; j<E.size(); j++)
      E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
    E(i)=E(i)-target(i);
  }
  assert(target_.size()==E.size());
  assert(target_.size()==sample_.n());

  unsigned long int epochs = 0;
  double alpha_new2;
  double alpha_new1;
  double u;  // lower clip bound for the second multiplier
  double v;  // upper clip bound for the second multiplier

  // Training loop
  while(choose(E)) {
    bounds(u,v);
    // curvature of the objective along the constraint line:
    // K(1,1) + K(2,2) - 2*K(1,2) for the chosen pair
    double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) +
                 kernel_mod(sample_.value_second(), sample_.value_second()) -
                 2*kernel_mod(sample_.value_first(), sample_.value_second()));

    double alpha_old1=alpha_(sample_.value_first());
    double alpha_old2=alpha_(sample_.value_second());

    // unconstrained optimum for the second multiplier ...
    alpha_new2 = ( alpha_(sample_.value_second()) +
                   target(sample_.value_second())*
                   ( E(sample_.value_first())-E(sample_.value_second()) )/k );

    // ... clipped to the feasible interval [u,v]
    if (alpha_new2 > v)
      alpha_new2 = v;
    else if (alpha_new2<u)
      alpha_new2 = u;


    // Updating the alphas
    // if alpha is 'zero' make the sample a non-support vector
    if (alpha_new2 < tolerance_){
      sample_.nsv_second();
    }
    else{
      sample_.sv_second();
    }


    // first multiplier follows from the equality constraint
    // (the pair's y-weighted sum of alphas is conserved)
    alpha_new1 = (alpha_(sample_.value_first()) +
                  (target(sample_.value_first()) *
                   target(sample_.value_second()) *
                   (alpha_(sample_.value_second()) - alpha_new2) ));

    // if alpha is 'zero' make the sample a non-support vector
    if (alpha_new1 < tolerance_){
      sample_.nsv_first();
    }
    else
      sample_.sv_first();

    alpha_(sample_.value_first()) = alpha_new1;
    alpha_(sample_.value_second()) = alpha_new2;

    // update E vector with the delta caused by the two changed alphas
    // Peter, perhaps one should only update SVs, but what happens in choose?
    for (size_t i=0; i<E.size(); i++) {
      E(i)+=( kernel_mod(i,sample_.value_first())*
              target(sample_.value_first()) *
              (alpha_new1-alpha_old1) );
      E(i)+=( kernel_mod(i,sample_.value_second())*
              target(sample_.value_second()) *
              (alpha_new2-alpha_old2) );
    }

    epochs++;
    if (epochs>max_epochs_){
      std::cerr << "WARNING: SVM: maximal number of epochs reached.\n";
      return false;
    }
  }

  trained_ = calculate_bias();
  return trained_;
}
---|
189 | |
---|
190 | |
---|
// Select the working pair for the next optimization step and record it in
// sample_ (update_first/update_second).  Returns true if a pair violating
// the stopping criterion (by more than 2*tolerance_) was found, false when
// training should stop.
bool SVM::choose(const theplu::gslapi::vector& E)
{
  // First check for violation among SVs
  // E should be the same for all SVs
  // Choose that pair having largest violation/difference.
  sample_.update_second(0);
  sample_.update_first(0);
  if (sample_.nof_sv()>1){

    // first <- SV with minimal E, second <- SV with maximal E
    double max = E(sample_(0));
    double min = max;
    for (size_t i=1; i<sample_.nof_sv(); i++){
      assert(alpha_(sample_(i))>tolerance_);
      if (E(sample_(i)) > max){
        max = E(sample_(i));
        sample_.update_second(i);
      }
      else if (E(sample_(i))<min){
        min = E(sample_(i));
        sample_.update_first(i);
      }
    }
    assert(alpha_(sample_.value_first())>tolerance_);
    assert(alpha_(sample_.value_second())>tolerance_);


    if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
      return true;
    }

    // If no violation check among non-support vectors
    // (shuffled so no fixed sample is systematically preferred)
    sample_.shuffle();
    for (size_t i=sample_.nof_sv(); i<sample_.n();i++){
      // NOTE(review): for a positive-class sample the candidate is stored
      // as 'second' and vice versa — looks swapped relative to the SV
      // branch above; confirm against how train() uses the pair.
      if (target_.binary(sample_(i))){
        if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
          sample_.update_second(i);
          return true;
        }
      }
      else{
        if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
          sample_.update_first(i);
          return true;
        }
      }
    }
  }

  // if no support vectors - special case
  else{
    // pick the first (positive, negative) pair violating the criterion
    for (size_t i=0; i<sample_.n(); i++) {
      if (target_.binary(sample_(i))){
        for (size_t j=0; j<sample_.n(); j++) {
          if ( !target_.binary(sample_(j)) &&
               E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
            sample_.update_first(i);
            sample_.update_second(j);
            return true;
          }
        }
      }
    }
  }

  // If there is no violation then we should stop training
  return false;

}
---|
259 | |
---|
260 | |
---|
// Compute the clip interval [u,v] for the second Lagrange multiplier of
// the currently selected pair, derived from the equality constraint and
// 0 <= alpha <= C.  NOTE(review): where SMO texts use the box constant C,
// this code uses std::numeric_limits<double>::max() — apparently treating
// C as unbounded (consistent with C_inverse_ possibly being 0); confirm
// against the header's documentation of C.
void SVM::bounds( double& u, double& v) const
{
  if (target(sample_.value_first())!=target(sample_.value_second())) {
    // opposite labels: alpha2 - alpha1 is conserved
    if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
      v = std::numeric_limits<double>::max();
      u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
    }
    else {
      v = (std::numeric_limits<double>::max() -
           alpha_(sample_.value_first()) +
           alpha_(sample_.value_second()));
      u = 0;
    }
  }
  else {
    // equal labels: alpha2 + alpha1 is conserved
    if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) >
        std::numeric_limits<double>::max()) {
      u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) -
           std::numeric_limits<double>::max());
      v = std::numeric_limits<double>::max();
    }
    else {
      u = 0;
      v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
    }
  }
}
---|
288 | |
---|
289 | bool SVM::calculate_bias(void) |
---|
290 | { |
---|
291 | |
---|
292 | // calculating output without bias |
---|
293 | for (size_t i=0; i<output_.size(); i++) { |
---|
294 | output_(i)=0; |
---|
295 | for (size_t j=0; j<output_.size(); j++) |
---|
296 | output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j); |
---|
297 | } |
---|
298 | |
---|
299 | if (!sample_.nof_sv()){ |
---|
300 | std::cerr << "SVM::train() error: " |
---|
301 | << "Cannot calculate bias because there is no support vector" |
---|
302 | << std::endl; |
---|
303 | return false; |
---|
304 | } |
---|
305 | |
---|
306 | // For samples with alpha>0, we have: target*output=1-alpha/C |
---|
307 | bias_=0; |
---|
308 | for (size_t i=0; i<sample_.nof_sv(); i++) |
---|
309 | bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - |
---|
310 | output_(sample_(i)) ); |
---|
311 | bias_=bias_/sample_.nof_sv(); |
---|
312 | for (size_t i=0; i<output_.size(); i++) |
---|
313 | output_(i) += bias_; |
---|
314 | |
---|
315 | return true; |
---|
316 | } |
---|
317 | |
---|
318 | Index::Index(void) |
---|
319 | : nof_sv_(0), vec_(std::vector<size_t>(0)) |
---|
320 | { |
---|
321 | } |
---|
322 | |
---|
323 | Index::Index(const size_t n) |
---|
324 | : nof_sv_(0), vec_(std::vector<size_t>(n)) |
---|
325 | { |
---|
326 | for (size_t i=0; i<vec_.size(); i++) |
---|
327 | vec_[i]=i; |
---|
328 | } |
---|
329 | |
---|
// Partition the sample indices by their alpha value: indices with
// alpha(i)>=tol (support vectors) are packed at the front of vec_ in
// increasing order and counted in nof_sv_; the rest are packed from the
// back of vec_.  Assumes vec_.size()==alpha.size().
void Index::init(const gslapi::vector& alpha, const double tol)
{
  nof_sv_=0;
  size_t nof_nsv=0;
  for (size_t i=0; i<alpha.size(); i++)
    if (alpha(i)<tol){
      // non-support vector: fill from the end of vec_
      nof_nsv++;
      vec_[vec_.size()-nof_nsv]=i;
    }
    else{
      // support vector: fill from the front of vec_
      vec_[nof_sv_]=i;
      nof_sv_++;
    }
  assert(nof_sv_+nof_nsv==vec_.size());

}
---|
346 | |
---|
// Promote the currently selected first sample to a support vector.
// vec_ keeps support vectors in [0, nof_sv_); the promotion swaps the
// sample into position nof_sv_ and grows the SV region by one.
void Index::sv_first(void)
{
  // if already sv, do nothing
  if (index_first_<nof_sv())
    return;

  // swap elements
  // if the second selection occupies the swap slot, re-point it to where
  // the displaced element will land
  if(index_second_==nof_sv_){
    index_second_=index_first_;
  }
  vec_[index_first_]=vec_[nof_sv_];
  vec_[nof_sv_]=value_first_;
  index_first_ = nof_sv_;

  nof_sv_++;

}
---|
364 | |
---|
// Promote the currently selected second sample to a support vector.
// Mirror image of sv_first(): swap into position nof_sv_ and grow the
// SV region, keeping the first selection's index consistent.
void Index::sv_second(void)
{
  // if already sv, do nothing
  if (index_second_<nof_sv())
    return;

  // swap elements
  // if the first selection occupies the swap slot, re-point it to where
  // the displaced element will land
  if(index_first_==nof_sv_){
    index_first_=index_second_;
  }

  vec_[index_second_]=vec_[nof_sv_];
  vec_[nof_sv_]=value_second_;
  index_second_=nof_sv_;

  nof_sv_++;
}
---|
382 | |
---|
// Demote the currently selected first sample to a non-support vector:
// swap it with the last support vector (position nof_sv_-1) and shrink
// the SV region by one.
void Index::nsv_first(void)
{
  // if already nsv, do nothing
  if ( !(index_first_<nof_sv()) )
    return;

  // keep the second selection consistent if it occupies the swap slot
  if(index_second_==nof_sv_-1)
    index_second_=index_first_;
  vec_[index_first_]=vec_[nof_sv_-1];
  vec_[nof_sv_-1]=value_first_;
  index_first_=nof_sv_-1;

  nof_sv_--;
}
---|
397 | |
---|
// Demote the currently selected second sample to a non-support vector.
// Mirror image of nsv_first().
void Index::nsv_second(void)
{
  // if already nsv, do nothing
  if ( !(index_second_<nof_sv()) )
    return;

  // keep the first selection consistent if it occupies the swap slot
  if(index_first_==nof_sv_-1)
    index_first_=index_second_;
  vec_[index_second_]=vec_[nof_sv_-1];
  vec_[nof_sv_-1]=value_second_;
  index_second_ = nof_sv_-1;

  nof_sv_--;
}
---|
412 | |
---|
413 | |
---|
// Randomly permute the non-support-vector part of vec_ (positions
// nof_sv_ .. end); the support-vector region is left untouched.
void Index::shuffle(void)
{
  random::DiscreteUniform a;  // project RNG used as the shuffle generator
  // NOTE(review): unqualified call resolving to std::random_shuffle
  // (removed in C++17) — consider std::shuffle when modernizing.
  random_shuffle(vec_.begin()+nof_sv_, vec_.end(), a);
}
---|
419 | |
---|
420 | void Index::update_first(const size_t i) |
---|
421 | { |
---|
422 | assert(i<n()); |
---|
423 | index_first_=i; |
---|
424 | value_first_=vec_[i]; |
---|
425 | } |
---|
426 | |
---|
427 | void Index::update_second(const size_t i) |
---|
428 | { |
---|
429 | assert(i<n()); |
---|
430 | index_second_=i; |
---|
431 | value_second_=vec_[i]; |
---|
432 | } |
---|
433 | |
---|
434 | }} // of namespace classifier and namespace theplu |
---|