source: trunk/yat/classifier/SVM.cc @ 2103

Last change on this file since 2103 was 2103, checked in by Peter, 12 years ago

merging patch release 0.5.5 into trunk. Delta 0.5.5 - 0.5.4

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.0 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7  Copyright (C) 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8  Copyright (C) 2009 Peter Johansson
9
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24*/
25
26#include "SVM.h"
27#include "KernelLookup.h"
28#include "Target.h"
29#include "yat/utility/Matrix.h"
30#include "yat/utility/Vector.h"
31
32#include <algorithm>
33#include <cassert>
34#include <cctype>
35#include <cmath>
36#include <limits>
37#include <sstream>
38#include <stdexcept>
39#include <string>
40#include <utility>
41#include <vector>
42
43namespace theplu {
44namespace yat {
45namespace classifier { 
46
47  SVM::SVM(void)
48    : bias_(0),
49      C_inverse_(0),
50      kernel_(NULL),
51      margin_(0),
52      max_epochs_(100000),
53      tolerance_(0.00000001),
54      trained_(false)
55  {
56  }
57
58
59  SVM::SVM(const SVM& other)
60    : bias_(other.bias_), C_inverse_(other.C_inverse_), kernel_(other.kernel_),
61      margin_(0), max_epochs_(other.max_epochs_), tolerance_(other.tolerance_),
62      trained_(other.trained_)
63  {
64  }
65
66
67  SVM::~SVM()
68  {
69  }
70
71
72  const utility::Vector& SVM::alpha(void) const
73  {
74    return alpha_;
75  }
76
77
78  double SVM::C(void) const
79  {
80    return 1.0/C_inverse_;
81  }
82
83
84  void SVM::calculate_margin(void)
85  {
86    margin_ = 0;
87    for(size_t i = 0; i<alpha_.size(); ++i){
88      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
89      for(size_t j = i+1; j<alpha_.size(); ++j)
90        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
91    }
92    margin_ = std::sqrt(margin_);
93  }
94
95
96  /*
97  const DataLookup2D& SVM::data(void) const
98  {
99    return *kernel_;
100  }
101  */
102
103
104  double SVM::kernel_mod(const size_t i, const size_t j) const
105  {
106    assert(kernel_);
107    assert(i<kernel_->rows());
108    assert(i<kernel_->columns());
109    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
110  }
111
112
113  SVM* SVM::make_classifier(void) const
114  {
115    SVM* svm = new SVM(*this);
116    svm->trained_ = false;
117    return svm;
118  }
119
120
121  unsigned long int SVM::max_epochs(void) const
122  {
123    return max_epochs_;
124  }
125
126
127  void SVM::max_epochs(unsigned long int n)
128  {
129    max_epochs_=n;
130  }
131
132
133  const utility::Vector& SVM::output(void) const
134  {
135    return output_;
136  }
137
138  void SVM::predict(const KernelLookup& input, utility::Matrix& prediction) const
139  {
140    assert(input.rows()==alpha_.size());
141    prediction.resize(2,input.columns(),0);
142    for (size_t i = 0; i<input.columns(); i++){
143      for (size_t j = 0; j<input.rows(); j++){
144        prediction(0,i) += target(j)*alpha_(j)*input(j,i);
145        assert(target(j));
146      }
147      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
148    }
149   
150    for (size_t i = 0; i<prediction.columns(); i++)
151      prediction(1,i) = -prediction(0,i);
152  }
153
154  /*
155  double SVM::predict(const DataLookup1D& x) const
156  {
157    double y=0;
158    for (size_t i=0; i<alpha_.size(); i++)
159      y += alpha_(i)*target_(i)*kernel_->element(x,i);
160
161    return margin_*(y+bias_);
162  }
163
164  double SVM::predict(const DataLookupWeighted1D& x) const
165  {
166    double y=0;
167    for (size_t i=0; i<alpha_.size(); i++)
168      y += alpha_(i)*target_(i)*kernel_->element(x,i);
169
170    return margin_*(y+bias_);
171  }
172  */
173
174  int SVM::target(size_t i) const
175  {
176    assert(i<target_.size());
177    return target_.binary(i) ? 1 : -1;
178  }
179
180  void SVM::train(const KernelLookup& kernel, const Target& targ) 
181  {
182    trained_ = false;
183    kernel_ = new KernelLookup(kernel);
184    target_ = targ;
185   
186    alpha_ = utility::Vector(targ.size(), 0.0);
187    output_ = utility::Vector(targ.size(), 0.0);
188    // initializing variables for optimization
189    assert(target_.size()==kernel_->rows());
190    assert(target_.size()==alpha_.size());
191
192    sample_.init(alpha_,tolerance_);
193    utility::Vector   E(target_.size(),0);
194    for (size_t i=0; i<E.size(); i++) {
195      for (size_t j=0; j<E.size(); j++) 
196        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
197      E(i)-=target(i);
198    }
199    assert(target_.size()==E.size());
200    assert(target_.size()==sample_.size());
201
202    unsigned long int epochs = 0;
203    double alpha_new2;
204    double alpha_new1;
205    double u;
206    double v;
207
208    // Training loop
209    while(choose(E)) {
210      bounds(u,v);       
211      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
212                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
213                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
214     
215      double alpha_old1=alpha_(sample_.value_first());
216      double alpha_old2=alpha_(sample_.value_second());
217      alpha_new2 = ( alpha_(sample_.value_second()) + 
218                     target(sample_.value_second())*
219                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
220     
221      if (alpha_new2 > v)
222        alpha_new2 = v;
223      else if (alpha_new2<u)
224        alpha_new2 = u;
225     
226      // Updating the alphas
227      // if alpha is 'zero' make the sample a non-support vector
228      if (alpha_new2 < tolerance_){
229        sample_.nsv_second();
230      }
231      else{
232        sample_.sv_second();
233      }
234     
235     
236      alpha_new1 = (alpha_(sample_.value_first()) + 
237                    (target(sample_.value_first()) * 
238                     target(sample_.value_second()) * 
239                     (alpha_(sample_.value_second()) - alpha_new2) ));
240           
241      // if alpha is 'zero' make the sample a non-support vector
242      if (alpha_new1 < tolerance_){
243        sample_.nsv_first();
244      }
245      else
246        sample_.sv_first();
247     
248      alpha_(sample_.value_first()) = alpha_new1;
249      alpha_(sample_.value_second()) = alpha_new2;
250     
251      // update E vector
252      // Peter, perhaps one should only update SVs, but what happens in choose?
253      for (size_t i=0; i<E.size(); i++) {
254        E(i)+=( kernel_mod(i,sample_.value_first())*
255                target(sample_.value_first()) *
256                (alpha_new1-alpha_old1) );
257        E(i)+=( kernel_mod(i,sample_.value_second())*
258                target(sample_.value_second()) *
259                (alpha_new2-alpha_old2) );
260      }
261           
262      epochs++; 
263      if (epochs>max_epochs_){
264        throw std::runtime_error("SVM: maximal number of epochs reached.");
265      }
266    }
267    calculate_margin();
268    calculate_bias();
269    trained_ = true;
270  }
271
272
273  bool SVM::choose(const theplu::yat::utility::Vector& E)
274  {
275    // First check for violation among SVs
276    // E should be the same for all SVs
277    // Choose that pair having largest violation/difference.
278    sample_.update_second(0);
279    sample_.update_first(0);
280    if (sample_.nof_sv()>1){
281
282      double max = E(sample_(0));
283      double min = max;
284      for (size_t i=1; i<sample_.nof_sv(); i++){ 
285        assert(alpha_(sample_(i))>tolerance_);
286        if (E(sample_(i)) > max){
287          max = E(sample_(i));
288          sample_.update_second(i);
289        }
290        else if (E(sample_(i))<min){
291          min = E(sample_(i));
292          sample_.update_first(i);
293        }
294      }
295      assert(alpha_(sample_.value_first())>tolerance_);
296      assert(alpha_(sample_.value_second())>tolerance_);
297
298      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
299        return true;
300      }
301     
302      // If no violation check among non-support vectors
303      sample_.shuffle();
304      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
305        if (target_.binary(sample_(i))){
306          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
307            sample_.update_second(i);
308            return true;
309          }
310        }
311        else{
312          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
313            sample_.update_first(i);
314            return true;
315          }
316        }
317      }
318    }
319
320    // if no support vectors - special case
321    else{
322      // to avoid getting stuck we shuffle
323      sample_.shuffle();
324      for (size_t i=0; i<sample_.size(); i++) {
325        if (target(sample_(i))==1){
326          for (size_t j=0; j<sample_.size(); j++) {
327            if ( target(sample_(j))==-1 && 
328                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
329              sample_.update_first(i);
330              sample_.update_second(j);
331              return true;
332            }
333          }
334        }
335      }
336    }
337   
338    // If there is no violation then we should stop training
339    return false;
340
341  }
342 
343 
344  void SVM::bounds( double& u, double& v) const
345  {
346    if (target(sample_.value_first())!=target(sample_.value_second())) {
347      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
348        v = std::numeric_limits<double>::max();
349        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
350      }
351      else {
352        v = (std::numeric_limits<double>::max() - 
353             alpha_(sample_.value_first()) + 
354             alpha_(sample_.value_second()));
355        u = 0;
356      }
357    }
358    else {       
359      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
360           std::numeric_limits<double>::max()) {
361        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
362              std::numeric_limits<double>::max());
363        v =  std::numeric_limits<double>::max();   
364      }
365      else {
366        u = 0;
367        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
368      }
369    }
370  }
371 
372  void SVM::calculate_bias(void)
373  {
374
375    // calculating output without bias
376    for (size_t i=0; i<output_.size(); i++) {
377      output_(i)=0;
378      for (size_t j=0; j<output_.size(); j++) 
379        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
380    }
381
382    if (!sample_.nof_sv()){
383      std::stringstream ss;
384      ss << "yat::classifier::SVM::train() error: " 
385         << "Cannot calculate bias because there is no support vector"; 
386      throw std::runtime_error(ss.str());
387    }
388
389    // For samples with alpha>0, we have: target*output=1-alpha/C
390    bias_=0;
391    for (size_t i=0; i<sample_.nof_sv(); i++) 
392      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
393                output_(sample_(i)) );
394    bias_=bias_/sample_.nof_sv();
395    for (size_t i=0; i<output_.size(); i++) 
396      output_(i) += bias_;
397  }
398
399  void SVM::set_C(const double C)
400  {
401    C_inverse_ = 1/C;
402  }
403
404 
405  void SVM::reset(void)
406  {
407    trained_ = false;
408  }
409
410
411  bool SVM::trained(void) const
412  {
413    return trained_;
414  }
415
416}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.