source: trunk/yat/classifier/SVM.cc @ 1177

Last change on this file since 1177 was 1177, checked in by Peter, 14 years ago

minor fix - see [1175]

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.8 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "SVM.h"
27#include "KernelLookup.h"
28#include "Target.h"
29#include "yat/random/random.h"
30#include "yat/statistics/Averager.h"
31#include "yat/utility/Matrix.h"
32#include "yat/utility/Vector.h"
33
34#include <algorithm>
35#include <cassert>
36#include <cctype>
37#include <cmath>
38#include <limits>
39#include <sstream>
40#include <stdexcept>
41#include <string>
42#include <utility>
43#include <vector>
44
45namespace theplu {
46namespace yat {
47namespace classifier { 
48
49  SVM::SVM(void)
50    : bias_(0),
51      C_inverse_(0),
52      kernel_(NULL),
53      margin_(0),
54      max_epochs_(100000),
55      tolerance_(0.00000001),
56      trained_(false)
57  {
58  }
59
60
61  SVM::SVM(const SVM& other)
62    : bias_(other.bias_), C_inverse_(other.C_inverse_), kernel_(other.kernel_),
63      margin_(0), max_epochs_(other.max_epochs_), tolerance_(other.tolerance_),
64      trained_(other.trained_)
65  {
66  }
67
68
69  SVM::~SVM()
70  {
71  }
72
73
74  const utility::Vector& SVM::alpha(void) const
75  {
76    return alpha_;
77  }
78
79
80  double SVM::C(void) const
81  {
82    return 1.0/C_inverse_;
83  }
84
85
86  void SVM::calculate_margin(void)
87  {
88    margin_ = 0;
89    for(size_t i = 0; i<alpha_.size(); ++i){
90      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
91      for(size_t j = i+1; j<alpha_.size(); ++j)
92        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
93    }
94  }
95
96
97  /*
98  const DataLookup2D& SVM::data(void) const
99  {
100    return *kernel_;
101  }
102  */
103
104
105  double SVM::kernel_mod(const size_t i, const size_t j) const
106  {
107    assert(kernel_);
108    assert(i<kernel_->rows());
109    assert(i<kernel_->columns());
110    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
111  }
112
113
114  SVM* SVM::make_classifier(void) const
115  {
116    SVM* svm = new SVM(*this);
117    svm->trained_ = false;
118    return svm;
119  }
120
121
122  long int SVM::max_epochs(void) const
123  {
124    return max_epochs_;
125  }
126
127
128  void SVM::max_epochs(long int n)
129  {
130    max_epochs_=n;
131  }
132
133
134  const utility::Vector& SVM::output(void) const
135  {
136    return output_;
137  }
138
139  void SVM::predict(const KernelLookup& input, utility::Matrix& prediction) const
140  {
141    assert(input.rows()==alpha_.size());
142    prediction.resize(2,input.columns(),0);
143    for (size_t i = 0; i<input.columns(); i++){
144      for (size_t j = 0; j<input.rows(); j++){
145        prediction(0,i) += target(j)*alpha_(j)*input(j,i);
146        assert(target(j));
147      }
148      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
149    }
150   
151    for (size_t i = 0; i<prediction.columns(); i++)
152      prediction(1,i) = -prediction(0,i);
153  }
154
155  double SVM::predict(const DataLookup1D& x) const
156  {
157    double y=0;
158    for (size_t i=0; i<alpha_.size(); i++)
159      y += alpha_(i)*target_(i)*kernel_->element(x,i);
160
161    return margin_*(y+bias_);
162  }
163
164  double SVM::predict(const DataLookupWeighted1D& x) const
165  {
166    double y=0;
167    for (size_t i=0; i<alpha_.size(); i++)
168      y += alpha_(i)*target_(i)*kernel_->element(x,i);
169
170    return margin_*(y+bias_);
171  }
172
173  int SVM::target(size_t i) const
174  {
175    assert(i<target_.size());
176    return target_.binary(i) ? 1 : -1;
177  }
178
179  void SVM::train(const KernelLookup& kernel, const Target& targ) 
180  {
181    kernel_ = new KernelLookup(kernel);
182    target_ = targ;
183   
184    alpha_ = utility::Vector(targ.size(), 0.0);
185    output_ = utility::Vector(targ.size(), 0.0);
186    // initializing variables for optimization
187    assert(target_.size()==kernel_->rows());
188    assert(target_.size()==alpha_.size());
189
190    sample_.init(alpha_,tolerance_);
191    utility::Vector   E(target_.size(),0);
192    for (size_t i=0; i<E.size(); i++) {
193      for (size_t j=0; j<E.size(); j++) 
194        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
195      E(i)-=target(i);
196    }
197    assert(target_.size()==E.size());
198    assert(target_.size()==sample_.size());
199
200    unsigned long int epochs = 0;
201    double alpha_new2;
202    double alpha_new1;
203    double u;
204    double v;
205
206    // Training loop
207    while(choose(E)) {
208      bounds(u,v);       
209      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
210                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
211                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
212     
213      double alpha_old1=alpha_(sample_.value_first());
214      double alpha_old2=alpha_(sample_.value_second());
215      alpha_new2 = ( alpha_(sample_.value_second()) + 
216                     target(sample_.value_second())*
217                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
218     
219      if (alpha_new2 > v)
220        alpha_new2 = v;
221      else if (alpha_new2<u)
222        alpha_new2 = u;
223     
224      // Updating the alphas
225      // if alpha is 'zero' make the sample a non-support vector
226      if (alpha_new2 < tolerance_){
227        sample_.nsv_second();
228      }
229      else{
230        sample_.sv_second();
231      }
232     
233     
234      alpha_new1 = (alpha_(sample_.value_first()) + 
235                    (target(sample_.value_first()) * 
236                     target(sample_.value_second()) * 
237                     (alpha_(sample_.value_second()) - alpha_new2) ));
238           
239      // if alpha is 'zero' make the sample a non-support vector
240      if (alpha_new1 < tolerance_){
241        sample_.nsv_first();
242      }
243      else
244        sample_.sv_first();
245     
246      alpha_(sample_.value_first()) = alpha_new1;
247      alpha_(sample_.value_second()) = alpha_new2;
248     
249      // update E vector
250      // Peter, perhaps one should only update SVs, but what happens in choose?
251      for (size_t i=0; i<E.size(); i++) {
252        E(i)+=( kernel_mod(i,sample_.value_first())*
253                target(sample_.value_first()) *
254                (alpha_new1-alpha_old1) );
255        E(i)+=( kernel_mod(i,sample_.value_second())*
256                target(sample_.value_second()) *
257                (alpha_new2-alpha_old2) );
258      }
259           
260      epochs++; 
261      if (epochs>max_epochs_){
262        throw std::runtime_error("SVM: maximal number of epochs reached.");
263      }
264    }
265    calculate_margin();
266    calculate_bias();
267    trained_ = true;
268  }
269
270
271  bool SVM::choose(const theplu::yat::utility::Vector& E)
272  {
273    // First check for violation among SVs
274    // E should be the same for all SVs
275    // Choose that pair having largest violation/difference.
276    sample_.update_second(0);
277    sample_.update_first(0);
278    if (sample_.nof_sv()>1){
279
280      double max = E(sample_(0));
281      double min = max;
282      for (size_t i=1; i<sample_.nof_sv(); i++){ 
283        assert(alpha_(sample_(i))>tolerance_);
284        if (E(sample_(i)) > max){
285          max = E(sample_(i));
286          sample_.update_second(i);
287        }
288        else if (E(sample_(i))<min){
289          min = E(sample_(i));
290          sample_.update_first(i);
291        }
292      }
293      assert(alpha_(sample_.value_first())>tolerance_);
294      assert(alpha_(sample_.value_second())>tolerance_);
295
296      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
297        return true;
298      }
299     
300      // If no violation check among non-support vectors
301      sample_.shuffle();
302      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
303        if (target_.binary(sample_(i))){
304          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
305            sample_.update_second(i);
306            return true;
307          }
308        }
309        else{
310          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
311            sample_.update_first(i);
312            return true;
313          }
314        }
315      }
316    }
317
318    // if no support vectors - special case
319    else{
320      // to avoid getting stuck we shuffle
321      sample_.shuffle();
322      for (size_t i=0; i<sample_.size(); i++) {
323        if (target(sample_(i))==1){
324          for (size_t j=0; j<sample_.size(); j++) {
325            if ( target(sample_(j))==-1 && 
326                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
327              sample_.update_first(i);
328              sample_.update_second(j);
329              return true;
330            }
331          }
332        }
333      }
334    }
335   
336    // If there is no violation then we should stop training
337    return false;
338
339  }
340 
341 
342  void SVM::bounds( double& u, double& v) const
343  {
344    if (target(sample_.value_first())!=target(sample_.value_second())) {
345      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
346        v = std::numeric_limits<double>::max();
347        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
348      }
349      else {
350        v = (std::numeric_limits<double>::max() - 
351             alpha_(sample_.value_first()) + 
352             alpha_(sample_.value_second()));
353        u = 0;
354      }
355    }
356    else {       
357      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
358           std::numeric_limits<double>::max()) {
359        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
360              std::numeric_limits<double>::max());
361        v =  std::numeric_limits<double>::max();   
362      }
363      else {
364        u = 0;
365        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
366      }
367    }
368  }
369 
370  void SVM::calculate_bias(void)
371  {
372
373    // calculating output without bias
374    for (size_t i=0; i<output_.size(); i++) {
375      output_(i)=0;
376      for (size_t j=0; j<output_.size(); j++) 
377        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
378    }
379
380    if (!sample_.nof_sv()){
381      std::stringstream ss;
382      ss << "yat::classifier::SVM::train() error: " 
383         << "Cannot calculate bias because there is no support vector"; 
384      throw std::runtime_error(ss.str());
385    }
386
387    // For samples with alpha>0, we have: target*output=1-alpha/C
388    bias_=0;
389    for (size_t i=0; i<sample_.nof_sv(); i++) 
390      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
391                output_(sample_(i)) );
392    bias_=bias_/sample_.nof_sv();
393    for (size_t i=0; i<output_.size(); i++) 
394      output_(i) += bias_;
395  }
396
397  void SVM::set_C(const double C)
398  {
399    C_inverse_ = 1/C;
400  }
401
402}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.