source: trunk/yat/classifier/SVM.cc @ 1175

Last change on this file since 1175 was 1175, checked in by Peter, 14 years ago

fixed bug in SVM. SVM does not own the Kernel and should therefore never delete it. Cpy and assignment can simply copy the pointer without problem. Yet if the Kernel is deallocated outside, behavior of SVM is undefined.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.8 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "SVM.h"
27#include "KernelLookup.h"
28#include "Target.h"
29#include "yat/random/random.h"
30#include "yat/statistics/Averager.h"
31#include "yat/utility/Matrix.h"
32#include "yat/utility/Vector.h"
33
34#include <algorithm>
35#include <cassert>
36#include <cctype>
37#include <cmath>
38#include <limits>
39#include <sstream>
40#include <stdexcept>
41#include <string>
42#include <utility>
43#include <vector>
44
45namespace theplu {
46namespace yat {
47namespace classifier { 
48
49  SVM::SVM(void)
50    : bias_(0),
51      C_inverse_(0),
52      kernel_(NULL),
53      margin_(0),
54      max_epochs_(100000),
55      tolerance_(0.00000001),
56      trained_(false)
57  {
58  }
59
60
61  SVM::SVM(const SVM& other)
62    : bias_(other.bias_), C_inverse_(other.C_inverse_), kernel_(kernel_),
63      margin_(0), max_epochs_(other.max_epochs_), tolerance_(other.tolerance_),
64      trained_(other.trained_)
65  {
66  }
67
68
69  SVM::~SVM()
70  {
71  }
72
73
74  const utility::Vector& SVM::alpha(void) const
75  {
76    return alpha_;
77  }
78
79
80  double SVM::C(void) const
81  {
82    return 1.0/C_inverse_;
83  }
84
85
86  void SVM::calculate_margin(void)
87  {
88    margin_ = 0;
89    for(size_t i = 0; i<alpha_.size(); ++i){
90      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
91      for(size_t j = i+1; j<alpha_.size(); ++j)
92        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
93    }
94  }
95
96
97  /*
98  const DataLookup2D& SVM::data(void) const
99  {
100    return *kernel_;
101  }
102  */
103
104
105  double SVM::kernel_mod(const size_t i, const size_t j) const
106  {
107    assert(kernel_);
108    assert(i<kernel_->rows());
109    assert(i<kernel_->columns());
110    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
111  }
112
113
114  SVM* SVM::make_classifier(void) const
115  {
116    SVM* svm = new SVM(*this);
117    svm->trained_ = false;
118    return svm;
119  }
120
121
122  long int SVM::max_epochs(void) const
123  {
124    return max_epochs_;
125  }
126
127
128  void SVM::max_epochs(long int n)
129  {
130    max_epochs_=n;
131  }
132
133
134  const utility::Vector& SVM::output(void) const
135  {
136    return output_;
137  }
138
139  void SVM::predict(const KernelLookup& input, utility::Matrix& prediction) const
140  {
141    assert(input.rows()==alpha_.size());
142    prediction.resize(2,input.columns(),0);
143    for (size_t i = 0; i<input.columns(); i++){
144      for (size_t j = 0; j<input.rows(); j++){
145        prediction(0,i) += target(j)*alpha_(j)*input(j,i);
146        assert(target(j));
147      }
148      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
149    }
150   
151    for (size_t i = 0; i<prediction.columns(); i++)
152      prediction(1,i) = -prediction(0,i);
153  }
154
155  double SVM::predict(const DataLookup1D& x) const
156  {
157    double y=0;
158    for (size_t i=0; i<alpha_.size(); i++)
159      y += alpha_(i)*target_(i)*kernel_->element(x,i);
160
161    return margin_*(y+bias_);
162  }
163
164  double SVM::predict(const DataLookupWeighted1D& x) const
165  {
166    double y=0;
167    for (size_t i=0; i<alpha_.size(); i++)
168      y += alpha_(i)*target_(i)*kernel_->element(x,i);
169
170    return margin_*(y+bias_);
171  }
172
173  int SVM::target(size_t i) const
174  {
175    assert(i<target_.size());
176    return target_.binary(i) ? 1 : -1;
177  }
178
179  void SVM::train(const KernelLookup& kernel, const Target& targ) 
180  {
181    kernel_ = new KernelLookup(kernel);
182    target_ = targ;
183   
184    alpha_ = utility::Vector(targ.size(), 0.0);
185    output_ = utility::Vector(targ.size(), 0.0);
186    // initializing variables for optimization
187    assert(target_.size()==kernel_->rows());
188    assert(target_.size()==alpha_.size());
189
190    sample_.init(alpha_,tolerance_);
191    utility::Vector   E(target_.size(),0);
192    for (size_t i=0; i<E.size(); i++) {
193      for (size_t j=0; j<E.size(); j++) 
194        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
195      E(i)-=target(i);
196    }
197    assert(target_.size()==E.size());
198    assert(target_.size()==sample_.size());
199
200    unsigned long int epochs = 0;
201    double alpha_new2;
202    double alpha_new1;
203    double u;
204    double v;
205
206    // Training loop
207    while(choose(E)) {
208      bounds(u,v);       
209      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
210                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
211                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
212     
213      double alpha_old1=alpha_(sample_.value_first());
214      double alpha_old2=alpha_(sample_.value_second());
215      alpha_new2 = ( alpha_(sample_.value_second()) + 
216                     target(sample_.value_second())*
217                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
218     
219      if (alpha_new2 > v)
220        alpha_new2 = v;
221      else if (alpha_new2<u)
222        alpha_new2 = u;
223     
224      // Updating the alphas
225      // if alpha is 'zero' make the sample a non-support vector
226      if (alpha_new2 < tolerance_){
227        sample_.nsv_second();
228      }
229      else{
230        sample_.sv_second();
231      }
232     
233     
234      alpha_new1 = (alpha_(sample_.value_first()) + 
235                    (target(sample_.value_first()) * 
236                     target(sample_.value_second()) * 
237                     (alpha_(sample_.value_second()) - alpha_new2) ));
238           
239      // if alpha is 'zero' make the sample a non-support vector
240      if (alpha_new1 < tolerance_){
241        sample_.nsv_first();
242      }
243      else
244        sample_.sv_first();
245     
246      alpha_(sample_.value_first()) = alpha_new1;
247      alpha_(sample_.value_second()) = alpha_new2;
248     
249      // update E vector
250      // Peter, perhaps one should only update SVs, but what happens in choose?
251      for (size_t i=0; i<E.size(); i++) {
252        E(i)+=( kernel_mod(i,sample_.value_first())*
253                target(sample_.value_first()) *
254                (alpha_new1-alpha_old1) );
255        E(i)+=( kernel_mod(i,sample_.value_second())*
256                target(sample_.value_second()) *
257                (alpha_new2-alpha_old2) );
258      }
259           
260      epochs++; 
261      if (epochs>max_epochs_){
262        throw std::runtime_error("SVM: maximal number of epochs reached.");
263      }
264    }
265    calculate_margin();
266    calculate_bias();
267    trained_ = true;
268  }
269
270
271  bool SVM::choose(const theplu::yat::utility::Vector& E)
272  {
273    // First check for violation among SVs
274    // E should be the same for all SVs
275    // Choose that pair having largest violation/difference.
276    sample_.update_second(0);
277    sample_.update_first(0);
278    if (sample_.nof_sv()>1){
279
280      double max = E(sample_(0));
281      double min = max;
282      for (size_t i=1; i<sample_.nof_sv(); i++){ 
283        assert(alpha_(sample_(i))>tolerance_);
284        if (E(sample_(i)) > max){
285          max = E(sample_(i));
286          sample_.update_second(i);
287        }
288        else if (E(sample_(i))<min){
289          min = E(sample_(i));
290          sample_.update_first(i);
291        }
292      }
293      assert(alpha_(sample_.value_first())>tolerance_);
294      assert(alpha_(sample_.value_second())>tolerance_);
295
296      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
297        return true;
298      }
299     
300      // If no violation check among non-support vectors
301      sample_.shuffle();
302      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
303        if (target_.binary(sample_(i))){
304          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
305            sample_.update_second(i);
306            return true;
307          }
308        }
309        else{
310          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
311            sample_.update_first(i);
312            return true;
313          }
314        }
315      }
316    }
317
318    // if no support vectors - special case
319    else{
320      // to avoid getting stuck we shuffle
321      sample_.shuffle();
322      for (size_t i=0; i<sample_.size(); i++) {
323        if (target(sample_(i))==1){
324          for (size_t j=0; j<sample_.size(); j++) {
325            if ( target(sample_(j))==-1 && 
326                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
327              sample_.update_first(i);
328              sample_.update_second(j);
329              return true;
330            }
331          }
332        }
333      }
334    }
335   
336    // If there is no violation then we should stop training
337    return false;
338
339  }
340 
341 
342  void SVM::bounds( double& u, double& v) const
343  {
344    if (target(sample_.value_first())!=target(sample_.value_second())) {
345      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
346        v = std::numeric_limits<double>::max();
347        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
348      }
349      else {
350        v = (std::numeric_limits<double>::max() - 
351             alpha_(sample_.value_first()) + 
352             alpha_(sample_.value_second()));
353        u = 0;
354      }
355    }
356    else {       
357      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
358           std::numeric_limits<double>::max()) {
359        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
360              std::numeric_limits<double>::max());
361        v =  std::numeric_limits<double>::max();   
362      }
363      else {
364        u = 0;
365        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
366      }
367    }
368  }
369 
370  void SVM::calculate_bias(void)
371  {
372
373    // calculating output without bias
374    for (size_t i=0; i<output_.size(); i++) {
375      output_(i)=0;
376      for (size_t j=0; j<output_.size(); j++) 
377        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
378    }
379
380    if (!sample_.nof_sv()){
381      std::stringstream ss;
382      ss << "yat::classifier::SVM::train() error: " 
383         << "Cannot calculate bias because there is no support vector"; 
384      throw std::runtime_error(ss.str());
385    }
386
387    // For samples with alpha>0, we have: target*output=1-alpha/C
388    bias_=0;
389    for (size_t i=0; i<sample_.nof_sv(); i++) 
390      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
391                output_(sample_(i)) );
392    bias_=bias_/sample_.nof_sv();
393    for (size_t i=0; i<output_.size(); i++) 
394      output_(i) += bias_;
395  }
396
397  void SVM::set_C(const double C)
398  {
399    C_inverse_ = 1/C;
400  }
401
402}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.