source: trunk/yat/classifier/SVM.cc @ 1487

Last change on this file since 1487 was 1487, checked in by Jari Häkkinen, 13 years ago

Addresses #436. GPL license copy reference should also be updated.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.8 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7  Copyright (C) 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://dev.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 3 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with yat. If not, see <http://www.gnu.org/licenses/>.
23*/
24
25#include "SVM.h"
26#include "KernelLookup.h"
27#include "Target.h"
28#include "yat/random/random.h"
29#include "yat/statistics/Averager.h"
30#include "yat/utility/Matrix.h"
31#include "yat/utility/Vector.h"
32
33#include <algorithm>
34#include <cassert>
35#include <cctype>
36#include <cmath>
37#include <limits>
38#include <sstream>
39#include <stdexcept>
40#include <string>
41#include <utility>
42#include <vector>
43
44namespace theplu {
45namespace yat {
46namespace classifier { 
47
48  SVM::SVM(void)
49    : bias_(0),
50      C_inverse_(0),
51      kernel_(NULL),
52      margin_(0),
53      max_epochs_(100000),
54      tolerance_(0.00000001),
55      trained_(false)
56  {
57  }
58
59
60  SVM::SVM(const SVM& other)
61    : bias_(other.bias_), C_inverse_(other.C_inverse_), kernel_(other.kernel_),
62      margin_(0), max_epochs_(other.max_epochs_), tolerance_(other.tolerance_),
63      trained_(other.trained_)
64  {
65  }
66
67
68  SVM::~SVM()
69  {
70  }
71
72
73  const utility::Vector& SVM::alpha(void) const
74  {
75    return alpha_;
76  }
77
78
79  double SVM::C(void) const
80  {
81    return 1.0/C_inverse_;
82  }
83
84
85  void SVM::calculate_margin(void)
86  {
87    margin_ = 0;
88    for(size_t i = 0; i<alpha_.size(); ++i){
89      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
90      for(size_t j = i+1; j<alpha_.size(); ++j)
91        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
92    }
93  }
94
95
96  /*
97  const DataLookup2D& SVM::data(void) const
98  {
99    return *kernel_;
100  }
101  */
102
103
104  double SVM::kernel_mod(const size_t i, const size_t j) const
105  {
106    assert(kernel_);
107    assert(i<kernel_->rows());
108    assert(i<kernel_->columns());
109    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
110  }
111
112
113  SVM* SVM::make_classifier(void) const
114  {
115    SVM* svm = new SVM(*this);
116    svm->trained_ = false;
117    return svm;
118  }
119
120
121  long int SVM::max_epochs(void) const
122  {
123    return max_epochs_;
124  }
125
126
127  void SVM::max_epochs(long int n)
128  {
129    max_epochs_=n;
130  }
131
132
133  const utility::Vector& SVM::output(void) const
134  {
135    return output_;
136  }
137
138  void SVM::predict(const KernelLookup& input, utility::Matrix& prediction) const
139  {
140    assert(input.rows()==alpha_.size());
141    prediction.resize(2,input.columns(),0);
142    for (size_t i = 0; i<input.columns(); i++){
143      for (size_t j = 0; j<input.rows(); j++){
144        prediction(0,i) += target(j)*alpha_(j)*input(j,i);
145        assert(target(j));
146      }
147      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
148    }
149   
150    for (size_t i = 0; i<prediction.columns(); i++)
151      prediction(1,i) = -prediction(0,i);
152  }
153
154  /*
155  double SVM::predict(const DataLookup1D& x) const
156  {
157    double y=0;
158    for (size_t i=0; i<alpha_.size(); i++)
159      y += alpha_(i)*target_(i)*kernel_->element(x,i);
160
161    return margin_*(y+bias_);
162  }
163
164  double SVM::predict(const DataLookupWeighted1D& x) const
165  {
166    double y=0;
167    for (size_t i=0; i<alpha_.size(); i++)
168      y += alpha_(i)*target_(i)*kernel_->element(x,i);
169
170    return margin_*(y+bias_);
171  }
172  */
173
174  int SVM::target(size_t i) const
175  {
176    assert(i<target_.size());
177    return target_.binary(i) ? 1 : -1;
178  }
179
180  void SVM::train(const KernelLookup& kernel, const Target& targ) 
181  {
182    kernel_ = new KernelLookup(kernel);
183    target_ = targ;
184   
185    alpha_ = utility::Vector(targ.size(), 0.0);
186    output_ = utility::Vector(targ.size(), 0.0);
187    // initializing variables for optimization
188    assert(target_.size()==kernel_->rows());
189    assert(target_.size()==alpha_.size());
190
191    sample_.init(alpha_,tolerance_);
192    utility::Vector   E(target_.size(),0);
193    for (size_t i=0; i<E.size(); i++) {
194      for (size_t j=0; j<E.size(); j++) 
195        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
196      E(i)-=target(i);
197    }
198    assert(target_.size()==E.size());
199    assert(target_.size()==sample_.size());
200
201    unsigned long int epochs = 0;
202    double alpha_new2;
203    double alpha_new1;
204    double u;
205    double v;
206
207    // Training loop
208    while(choose(E)) {
209      bounds(u,v);       
210      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
211                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
212                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
213     
214      double alpha_old1=alpha_(sample_.value_first());
215      double alpha_old2=alpha_(sample_.value_second());
216      alpha_new2 = ( alpha_(sample_.value_second()) + 
217                     target(sample_.value_second())*
218                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
219     
220      if (alpha_new2 > v)
221        alpha_new2 = v;
222      else if (alpha_new2<u)
223        alpha_new2 = u;
224     
225      // Updating the alphas
226      // if alpha is 'zero' make the sample a non-support vector
227      if (alpha_new2 < tolerance_){
228        sample_.nsv_second();
229      }
230      else{
231        sample_.sv_second();
232      }
233     
234     
235      alpha_new1 = (alpha_(sample_.value_first()) + 
236                    (target(sample_.value_first()) * 
237                     target(sample_.value_second()) * 
238                     (alpha_(sample_.value_second()) - alpha_new2) ));
239           
240      // if alpha is 'zero' make the sample a non-support vector
241      if (alpha_new1 < tolerance_){
242        sample_.nsv_first();
243      }
244      else
245        sample_.sv_first();
246     
247      alpha_(sample_.value_first()) = alpha_new1;
248      alpha_(sample_.value_second()) = alpha_new2;
249     
250      // update E vector
251      // Peter, perhaps one should only update SVs, but what happens in choose?
252      for (size_t i=0; i<E.size(); i++) {
253        E(i)+=( kernel_mod(i,sample_.value_first())*
254                target(sample_.value_first()) *
255                (alpha_new1-alpha_old1) );
256        E(i)+=( kernel_mod(i,sample_.value_second())*
257                target(sample_.value_second()) *
258                (alpha_new2-alpha_old2) );
259      }
260           
261      epochs++; 
262      if (epochs>max_epochs_){
263        throw std::runtime_error("SVM: maximal number of epochs reached.");
264      }
265    }
266    calculate_margin();
267    calculate_bias();
268    trained_ = true;
269  }
270
271
272  bool SVM::choose(const theplu::yat::utility::Vector& E)
273  {
274    // First check for violation among SVs
275    // E should be the same for all SVs
276    // Choose that pair having largest violation/difference.
277    sample_.update_second(0);
278    sample_.update_first(0);
279    if (sample_.nof_sv()>1){
280
281      double max = E(sample_(0));
282      double min = max;
283      for (size_t i=1; i<sample_.nof_sv(); i++){ 
284        assert(alpha_(sample_(i))>tolerance_);
285        if (E(sample_(i)) > max){
286          max = E(sample_(i));
287          sample_.update_second(i);
288        }
289        else if (E(sample_(i))<min){
290          min = E(sample_(i));
291          sample_.update_first(i);
292        }
293      }
294      assert(alpha_(sample_.value_first())>tolerance_);
295      assert(alpha_(sample_.value_second())>tolerance_);
296
297      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
298        return true;
299      }
300     
301      // If no violation check among non-support vectors
302      sample_.shuffle();
303      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
304        if (target_.binary(sample_(i))){
305          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
306            sample_.update_second(i);
307            return true;
308          }
309        }
310        else{
311          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
312            sample_.update_first(i);
313            return true;
314          }
315        }
316      }
317    }
318
319    // if no support vectors - special case
320    else{
321      // to avoid getting stuck we shuffle
322      sample_.shuffle();
323      for (size_t i=0; i<sample_.size(); i++) {
324        if (target(sample_(i))==1){
325          for (size_t j=0; j<sample_.size(); j++) {
326            if ( target(sample_(j))==-1 && 
327                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
328              sample_.update_first(i);
329              sample_.update_second(j);
330              return true;
331            }
332          }
333        }
334      }
335    }
336   
337    // If there is no violation then we should stop training
338    return false;
339
340  }
341 
342 
343  void SVM::bounds( double& u, double& v) const
344  {
345    if (target(sample_.value_first())!=target(sample_.value_second())) {
346      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
347        v = std::numeric_limits<double>::max();
348        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
349      }
350      else {
351        v = (std::numeric_limits<double>::max() - 
352             alpha_(sample_.value_first()) + 
353             alpha_(sample_.value_second()));
354        u = 0;
355      }
356    }
357    else {       
358      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
359           std::numeric_limits<double>::max()) {
360        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
361              std::numeric_limits<double>::max());
362        v =  std::numeric_limits<double>::max();   
363      }
364      else {
365        u = 0;
366        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
367      }
368    }
369  }
370 
371  void SVM::calculate_bias(void)
372  {
373
374    // calculating output without bias
375    for (size_t i=0; i<output_.size(); i++) {
376      output_(i)=0;
377      for (size_t j=0; j<output_.size(); j++) 
378        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
379    }
380
381    if (!sample_.nof_sv()){
382      std::stringstream ss;
383      ss << "yat::classifier::SVM::train() error: " 
384         << "Cannot calculate bias because there is no support vector"; 
385      throw std::runtime_error(ss.str());
386    }
387
388    // For samples with alpha>0, we have: target*output=1-alpha/C
389    bias_=0;
390    for (size_t i=0; i<sample_.nof_sv(); i++) 
391      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
392                output_(sample_(i)) );
393    bias_=bias_/sample_.nof_sv();
394    for (size_t i=0; i<output_.size(); i++) 
395      output_(i) += bias_;
396  }
397
398  void SVM::set_C(const double C)
399  {
400    C_inverse_ = 1/C;
401  }
402
403}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.