source: trunk/yat/classifier/SVM.cc @ 1102

Last change on this file since 1102 was 1102, checked in by Peter, 14 years ago

fixes #314

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.7 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "SVM.h"
27#include "KernelLookup.h"
28#include "Target.h"
29#include "yat/random/random.h"
30#include "yat/statistics/Averager.h"
31#include "yat/utility/matrix.h"
32#include "yat/utility/vector.h"
33
34#include <algorithm>
35#include <cassert>
36#include <cctype>
37#include <cmath>
38#include <limits>
39#include <sstream>
40#include <stdexcept>
41#include <string>
42#include <utility>
43#include <vector>
44
45namespace theplu {
46namespace yat {
47namespace classifier { 
48
49  SVM::SVM(void)
50    : bias_(0),
51      C_inverse_(0),
52      kernel_(NULL),
53      margin_(0),
54      max_epochs_(100000),
55      tolerance_(0.00000001),
56      trained_(false)
57  {
58  }
59
60
61  SVM::~SVM()
62  {
63    if (kernel_)
64      delete kernel_;
65  }
66
67
68  const utility::vector& SVM::alpha(void) const
69  {
70    return alpha_;
71  }
72
73
74  double SVM::C(void) const
75  {
76    return 1.0/C_inverse_;
77  }
78
79
80  void SVM::calculate_margin(void)
81  {
82    margin_ = 0;
83    for(size_t i = 0; i<alpha_.size(); ++i){
84      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
85      for(size_t j = i+1; j<alpha_.size(); ++j)
86        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
87    }
88  }
89
90
91  /*
92  const DataLookup2D& SVM::data(void) const
93  {
94    return *kernel_;
95  }
96  */
97
98
99  double SVM::kernel_mod(const size_t i, const size_t j) const
100  {
101    assert(kernel_);
102    assert(i<kernel_->rows());
103    assert(i<kernel_->columns());
104    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
105  }
106
107
108  SVM* SVM::make_classifier(void) const
109  {
110    return new SVM(*this);
111  }
112
113
114  long int SVM::max_epochs(void) const
115  {
116    return max_epochs_;
117  }
118
119
120  void SVM::max_epochs(long int n)
121  {
122    max_epochs_=n;
123  }
124
125
126  const theplu::yat::utility::vector& SVM::output(void) const
127  {
128    return output_;
129  }
130
131  void SVM::predict(const KernelLookup& input, utility::matrix& prediction) const
132  {
133    assert(input.rows()==alpha_.size());
134    prediction.resize(2,input.columns(),0);
135    for (size_t i = 0; i<input.columns(); i++){
136      for (size_t j = 0; j<input.rows(); j++){
137        prediction(0,i) += target(j)*alpha_(j)*input(j,i);
138        assert(target(j));
139      }
140      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
141    }
142   
143    for (size_t i = 0; i<prediction.columns(); i++)
144      prediction(1,i) = -prediction(0,i);
145  }
146
147  double SVM::predict(const DataLookup1D& x) const
148  {
149    double y=0;
150    for (size_t i=0; i<alpha_.size(); i++)
151      y += alpha_(i)*target_(i)*kernel_->element(x,i);
152
153    return margin_*(y+bias_);
154  }
155
156  double SVM::predict(const DataLookupWeighted1D& x) const
157  {
158    double y=0;
159    for (size_t i=0; i<alpha_.size(); i++)
160      y += alpha_(i)*target_(i)*kernel_->element(x,i);
161
162    return margin_*(y+bias_);
163  }
164
165  int SVM::target(size_t i) const
166  {
167    assert(i<target_.size());
168    return target_.binary(i) ? 1 : -1;
169  }
170
171  void SVM::train(const KernelLookup& kernel, const Target& targ) 
172  {
173    if (kernel_)
174      delete kernel_;
175    kernel_ = new KernelLookup(kernel);
176    target_ = targ;
177   
178    alpha_ = utility::vector(targ.size(), 0.0);
179    output_ = utility::vector(targ.size(), 0.0);
180    // initializing variables for optimization
181    assert(target_.size()==kernel_->rows());
182    assert(target_.size()==alpha_.size());
183
184    sample_.init(alpha_,tolerance_);
185    utility::vector   E(target_.size(),0);
186    for (size_t i=0; i<E.size(); i++) {
187      for (size_t j=0; j<E.size(); j++) 
188        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
189      E(i)-=target(i);
190    }
191    assert(target_.size()==E.size());
192    assert(target_.size()==sample_.size());
193
194    unsigned long int epochs = 0;
195    double alpha_new2;
196    double alpha_new1;
197    double u;
198    double v;
199
200    // Training loop
201    while(choose(E)) {
202      bounds(u,v);       
203      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
204                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
205                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
206     
207      double alpha_old1=alpha_(sample_.value_first());
208      double alpha_old2=alpha_(sample_.value_second());
209      alpha_new2 = ( alpha_(sample_.value_second()) + 
210                     target(sample_.value_second())*
211                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
212     
213      if (alpha_new2 > v)
214        alpha_new2 = v;
215      else if (alpha_new2<u)
216        alpha_new2 = u;
217     
218      // Updating the alphas
219      // if alpha is 'zero' make the sample a non-support vector
220      if (alpha_new2 < tolerance_){
221        sample_.nsv_second();
222      }
223      else{
224        sample_.sv_second();
225      }
226     
227     
228      alpha_new1 = (alpha_(sample_.value_first()) + 
229                    (target(sample_.value_first()) * 
230                     target(sample_.value_second()) * 
231                     (alpha_(sample_.value_second()) - alpha_new2) ));
232           
233      // if alpha is 'zero' make the sample a non-support vector
234      if (alpha_new1 < tolerance_){
235        sample_.nsv_first();
236      }
237      else
238        sample_.sv_first();
239     
240      alpha_(sample_.value_first()) = alpha_new1;
241      alpha_(sample_.value_second()) = alpha_new2;
242     
243      // update E vector
244      // Peter, perhaps one should only update SVs, but what happens in choose?
245      for (size_t i=0; i<E.size(); i++) {
246        E(i)+=( kernel_mod(i,sample_.value_first())*
247                target(sample_.value_first()) *
248                (alpha_new1-alpha_old1) );
249        E(i)+=( kernel_mod(i,sample_.value_second())*
250                target(sample_.value_second()) *
251                (alpha_new2-alpha_old2) );
252      }
253           
254      epochs++; 
255      if (epochs>max_epochs_){
256        throw std::runtime_error("SVM: maximal number of epochs reached.");
257      }
258    }
259    calculate_margin();
260    calculate_bias();
261    trained_ = true;
262  }
263
264
265  bool SVM::choose(const theplu::yat::utility::vector& E)
266  {
267    // First check for violation among SVs
268    // E should be the same for all SVs
269    // Choose that pair having largest violation/difference.
270    sample_.update_second(0);
271    sample_.update_first(0);
272    if (sample_.nof_sv()>1){
273
274      double max = E(sample_(0));
275      double min = max;
276      for (size_t i=1; i<sample_.nof_sv(); i++){ 
277        assert(alpha_(sample_(i))>tolerance_);
278        if (E(sample_(i)) > max){
279          max = E(sample_(i));
280          sample_.update_second(i);
281        }
282        else if (E(sample_(i))<min){
283          min = E(sample_(i));
284          sample_.update_first(i);
285        }
286      }
287      assert(alpha_(sample_.value_first())>tolerance_);
288      assert(alpha_(sample_.value_second())>tolerance_);
289
290      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
291        return true;
292      }
293     
294      // If no violation check among non-support vectors
295      sample_.shuffle();
296      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
297        if (target_.binary(sample_(i))){
298          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
299            sample_.update_second(i);
300            return true;
301          }
302        }
303        else{
304          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
305            sample_.update_first(i);
306            return true;
307          }
308        }
309      }
310    }
311
312    // if no support vectors - special case
313    else{
314      // to avoid getting stuck we shuffle
315      sample_.shuffle();
316      for (size_t i=0; i<sample_.size(); i++) {
317        if (target(sample_(i))==1){
318          for (size_t j=0; j<sample_.size(); j++) {
319            if ( target(sample_(j))==-1 && 
320                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
321              sample_.update_first(i);
322              sample_.update_second(j);
323              return true;
324            }
325          }
326        }
327      }
328    }
329   
330    // If there is no violation then we should stop training
331    return false;
332
333  }
334 
335 
336  void SVM::bounds( double& u, double& v) const
337  {
338    if (target(sample_.value_first())!=target(sample_.value_second())) {
339      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
340        v = std::numeric_limits<double>::max();
341        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
342      }
343      else {
344        v = (std::numeric_limits<double>::max() - 
345             alpha_(sample_.value_first()) + 
346             alpha_(sample_.value_second()));
347        u = 0;
348      }
349    }
350    else {       
351      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
352           std::numeric_limits<double>::max()) {
353        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
354              std::numeric_limits<double>::max());
355        v =  std::numeric_limits<double>::max();   
356      }
357      else {
358        u = 0;
359        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
360      }
361    }
362  }
363 
364  void SVM::calculate_bias(void)
365  {
366
367    // calculating output without bias
368    for (size_t i=0; i<output_.size(); i++) {
369      output_(i)=0;
370      for (size_t j=0; j<output_.size(); j++) 
371        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
372    }
373
374    if (!sample_.nof_sv()){
375      std::stringstream ss;
376      ss << "yat::classifier::SVM::train() error: " 
377         << "Cannot calculate bias because there is no support vector"; 
378      throw std::runtime_error(ss.str());
379    }
380
381    // For samples with alpha>0, we have: target*output=1-alpha/C
382    bias_=0;
383    for (size_t i=0; i<sample_.nof_sv(); i++) 
384      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
385                output_(sample_(i)) );
386    bias_=bias_/sample_.nof_sv();
387    for (size_t i=0; i<output_.size(); i++) 
388      output_(i) += bias_;
389  }
390
391  void SVM::set_C(const double C)
392  {
393    C_inverse_ = 1/C;
394  }
395
396}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.