source: branches/0.4-stable/yat/classifier/SVM.cc @ 1392

Last change on this file since 1392 was 1392, checked in by Peter, 13 years ago

trac has moved

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.9 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7  Copyright (C) 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8
9  This file is part of the yat library, http://dev.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "SVM.h"
28#include "KernelLookup.h"
29#include "Target.h"
30#include "yat/random/random.h"
31#include "yat/statistics/Averager.h"
32#include "yat/utility/Matrix.h"
33#include "yat/utility/Vector.h"
34
35#include <algorithm>
36#include <cassert>
37#include <cctype>
38#include <cmath>
39#include <limits>
40#include <sstream>
41#include <stdexcept>
42#include <string>
43#include <utility>
44#include <vector>
45
46namespace theplu {
47namespace yat {
48namespace classifier { 
49
50  SVM::SVM(void)
51    : bias_(0),
52      C_inverse_(0),
53      kernel_(NULL),
54      margin_(0),
55      max_epochs_(100000),
56      tolerance_(0.00000001),
57      trained_(false)
58  {
59  }
60
61
62  SVM::SVM(const SVM& other)
63    : bias_(other.bias_), C_inverse_(other.C_inverse_), kernel_(other.kernel_),
64      margin_(0), max_epochs_(other.max_epochs_), tolerance_(other.tolerance_),
65      trained_(other.trained_)
66  {
67  }
68
69
70  SVM::~SVM()
71  {
72  }
73
74
75  const utility::Vector& SVM::alpha(void) const
76  {
77    return alpha_;
78  }
79
80
81  double SVM::C(void) const
82  {
83    return 1.0/C_inverse_;
84  }
85
86
87  void SVM::calculate_margin(void)
88  {
89    margin_ = 0;
90    for(size_t i = 0; i<alpha_.size(); ++i){
91      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
92      for(size_t j = i+1; j<alpha_.size(); ++j)
93        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
94    }
95  }
96
97
98  /*
99  const DataLookup2D& SVM::data(void) const
100  {
101    return *kernel_;
102  }
103  */
104
105
106  double SVM::kernel_mod(const size_t i, const size_t j) const
107  {
108    assert(kernel_);
109    assert(i<kernel_->rows());
110    assert(i<kernel_->columns());
111    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
112  }
113
114
115  SVM* SVM::make_classifier(void) const
116  {
117    SVM* svm = new SVM(*this);
118    svm->trained_ = false;
119    return svm;
120  }
121
122
123  long int SVM::max_epochs(void) const
124  {
125    return max_epochs_;
126  }
127
128
129  void SVM::max_epochs(long int n)
130  {
131    max_epochs_=n;
132  }
133
134
135  const utility::Vector& SVM::output(void) const
136  {
137    return output_;
138  }
139
140  void SVM::predict(const KernelLookup& input, utility::Matrix& prediction) const
141  {
142    assert(input.rows()==alpha_.size());
143    prediction.resize(2,input.columns(),0);
144    for (size_t i = 0; i<input.columns(); i++){
145      for (size_t j = 0; j<input.rows(); j++){
146        prediction(0,i) += target(j)*alpha_(j)*input(j,i);
147        assert(target(j));
148      }
149      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
150    }
151   
152    for (size_t i = 0; i<prediction.columns(); i++)
153      prediction(1,i) = -prediction(0,i);
154  }
155
156  /*
157  double SVM::predict(const DataLookup1D& x) const
158  {
159    double y=0;
160    for (size_t i=0; i<alpha_.size(); i++)
161      y += alpha_(i)*target_(i)*kernel_->element(x,i);
162
163    return margin_*(y+bias_);
164  }
165
166  double SVM::predict(const DataLookupWeighted1D& x) const
167  {
168    double y=0;
169    for (size_t i=0; i<alpha_.size(); i++)
170      y += alpha_(i)*target_(i)*kernel_->element(x,i);
171
172    return margin_*(y+bias_);
173  }
174  */
175
176  int SVM::target(size_t i) const
177  {
178    assert(i<target_.size());
179    return target_.binary(i) ? 1 : -1;
180  }
181
182  void SVM::train(const KernelLookup& kernel, const Target& targ) 
183  {
184    kernel_ = new KernelLookup(kernel);
185    target_ = targ;
186   
187    alpha_ = utility::Vector(targ.size(), 0.0);
188    output_ = utility::Vector(targ.size(), 0.0);
189    // initializing variables for optimization
190    assert(target_.size()==kernel_->rows());
191    assert(target_.size()==alpha_.size());
192
193    sample_.init(alpha_,tolerance_);
194    utility::Vector   E(target_.size(),0);
195    for (size_t i=0; i<E.size(); i++) {
196      for (size_t j=0; j<E.size(); j++) 
197        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
198      E(i)-=target(i);
199    }
200    assert(target_.size()==E.size());
201    assert(target_.size()==sample_.size());
202
203    unsigned long int epochs = 0;
204    double alpha_new2;
205    double alpha_new1;
206    double u;
207    double v;
208
209    // Training loop
210    while(choose(E)) {
211      bounds(u,v);       
212      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
213                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
214                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
215     
216      double alpha_old1=alpha_(sample_.value_first());
217      double alpha_old2=alpha_(sample_.value_second());
218      alpha_new2 = ( alpha_(sample_.value_second()) + 
219                     target(sample_.value_second())*
220                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
221     
222      if (alpha_new2 > v)
223        alpha_new2 = v;
224      else if (alpha_new2<u)
225        alpha_new2 = u;
226     
227      // Updating the alphas
228      // if alpha is 'zero' make the sample a non-support vector
229      if (alpha_new2 < tolerance_){
230        sample_.nsv_second();
231      }
232      else{
233        sample_.sv_second();
234      }
235     
236     
237      alpha_new1 = (alpha_(sample_.value_first()) + 
238                    (target(sample_.value_first()) * 
239                     target(sample_.value_second()) * 
240                     (alpha_(sample_.value_second()) - alpha_new2) ));
241           
242      // if alpha is 'zero' make the sample a non-support vector
243      if (alpha_new1 < tolerance_){
244        sample_.nsv_first();
245      }
246      else
247        sample_.sv_first();
248     
249      alpha_(sample_.value_first()) = alpha_new1;
250      alpha_(sample_.value_second()) = alpha_new2;
251     
252      // update E vector
253      // Peter, perhaps one should only update SVs, but what happens in choose?
254      for (size_t i=0; i<E.size(); i++) {
255        E(i)+=( kernel_mod(i,sample_.value_first())*
256                target(sample_.value_first()) *
257                (alpha_new1-alpha_old1) );
258        E(i)+=( kernel_mod(i,sample_.value_second())*
259                target(sample_.value_second()) *
260                (alpha_new2-alpha_old2) );
261      }
262           
263      epochs++; 
264      if (epochs>max_epochs_){
265        throw std::runtime_error("SVM: maximal number of epochs reached.");
266      }
267    }
268    calculate_margin();
269    calculate_bias();
270    trained_ = true;
271  }
272
273
274  bool SVM::choose(const theplu::yat::utility::Vector& E)
275  {
276    // First check for violation among SVs
277    // E should be the same for all SVs
278    // Choose that pair having largest violation/difference.
279    sample_.update_second(0);
280    sample_.update_first(0);
281    if (sample_.nof_sv()>1){
282
283      double max = E(sample_(0));
284      double min = max;
285      for (size_t i=1; i<sample_.nof_sv(); i++){ 
286        assert(alpha_(sample_(i))>tolerance_);
287        if (E(sample_(i)) > max){
288          max = E(sample_(i));
289          sample_.update_second(i);
290        }
291        else if (E(sample_(i))<min){
292          min = E(sample_(i));
293          sample_.update_first(i);
294        }
295      }
296      assert(alpha_(sample_.value_first())>tolerance_);
297      assert(alpha_(sample_.value_second())>tolerance_);
298
299      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
300        return true;
301      }
302     
303      // If no violation check among non-support vectors
304      sample_.shuffle();
305      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
306        if (target_.binary(sample_(i))){
307          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
308            sample_.update_second(i);
309            return true;
310          }
311        }
312        else{
313          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
314            sample_.update_first(i);
315            return true;
316          }
317        }
318      }
319    }
320
321    // if no support vectors - special case
322    else{
323      // to avoid getting stuck we shuffle
324      sample_.shuffle();
325      for (size_t i=0; i<sample_.size(); i++) {
326        if (target(sample_(i))==1){
327          for (size_t j=0; j<sample_.size(); j++) {
328            if ( target(sample_(j))==-1 && 
329                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
330              sample_.update_first(i);
331              sample_.update_second(j);
332              return true;
333            }
334          }
335        }
336      }
337    }
338   
339    // If there is no violation then we should stop training
340    return false;
341
342  }
343 
344 
345  void SVM::bounds( double& u, double& v) const
346  {
347    if (target(sample_.value_first())!=target(sample_.value_second())) {
348      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
349        v = std::numeric_limits<double>::max();
350        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
351      }
352      else {
353        v = (std::numeric_limits<double>::max() - 
354             alpha_(sample_.value_first()) + 
355             alpha_(sample_.value_second()));
356        u = 0;
357      }
358    }
359    else {       
360      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
361           std::numeric_limits<double>::max()) {
362        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
363              std::numeric_limits<double>::max());
364        v =  std::numeric_limits<double>::max();   
365      }
366      else {
367        u = 0;
368        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
369      }
370    }
371  }
372 
373  void SVM::calculate_bias(void)
374  {
375
376    // calculating output without bias
377    for (size_t i=0; i<output_.size(); i++) {
378      output_(i)=0;
379      for (size_t j=0; j<output_.size(); j++) 
380        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
381    }
382
383    if (!sample_.nof_sv()){
384      std::stringstream ss;
385      ss << "yat::classifier::SVM::train() error: " 
386         << "Cannot calculate bias because there is no support vector"; 
387      throw std::runtime_error(ss.str());
388    }
389
390    // For samples with alpha>0, we have: target*output=1-alpha/C
391    bias_=0;
392    for (size_t i=0; i<sample_.nof_sv(); i++) 
393      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
394                output_(sample_(i)) );
395    bias_=bias_/sample_.nof_sv();
396    for (size_t i=0; i<output_.size(); i++) 
397      output_(i) += bias_;
398  }
399
400  void SVM::set_C(const double C)
401  {
402    C_inverse_ = 1/C;
403  }
404
405}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.