source: trunk/yat/classifier/SVM.cc @ 1100

Last change on this file since 1100 was 1100, checked in by Peter, 15 years ago

fixes #313 - SVM constructor is void and passing kernel and target in train function instead

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.9 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "SVM.h"
27#include "DataLookup2D.h"
28#include "Target.h"
29#include "yat/random/random.h"
30#include "yat/statistics/Averager.h"
31#include "yat/utility/matrix.h"
32#include "yat/utility/vector.h"
33
34#include <algorithm>
35#include <cassert>
36#include <cctype>
37#include <cmath>
38#include <limits>
39#include <sstream>
40#include <stdexcept>
41#include <string>
42#include <utility>
43#include <vector>
44
45namespace theplu {
46namespace yat {
47namespace classifier { 
48
49  SVM::SVM(void)
50    : bias_(0),
51      C_inverse_(0),
52      kernel_(NULL),
53      margin_(0),
54      max_epochs_(100000),
55      tolerance_(0.00000001),
56      trained_(false)
57  {
58  }
59
60
61  SVM::~SVM()
62  {
63    if (kernel_)
64      delete kernel_;
65  }
66
67
68  const utility::vector& SVM::alpha(void) const
69  {
70    return alpha_;
71  }
72
73
74  double SVM::C(void) const
75  {
76    return 1.0/C_inverse_;
77  }
78
79
80  void SVM::calculate_margin(void)
81  {
82    margin_ = 0;
83    for(size_t i = 0; i<alpha_.size(); ++i){
84      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
85      for(size_t j = i+1; j<alpha_.size(); ++j)
86        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
87    }
88  }
89
90
91  const DataLookup2D& SVM::data(void) const
92  {
93    return *kernel_;
94  }
95
96
97  double SVM::kernel_mod(const size_t i, const size_t j) const
98  {
99    assert(kernel_);
100    assert(i<kernel_->rows());
101    assert(i<kernel_->columns());
102    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
103  }
104
105
106  SVM* SVM::make_classifier(void) const
107  {
108    return new SVM(*this);
109  }
110
111
112  long int SVM::max_epochs(void) const
113  {
114    return max_epochs_;
115  }
116
117
118  void SVM::max_epochs(long int n)
119  {
120    max_epochs_=n;
121  }
122
123
124  const theplu::yat::utility::vector& SVM::output(void) const
125  {
126    return output_;
127  }
128
129  void SVM::predict(const DataLookup2D& input, utility::matrix& prediction) const
130  {
131    try {
132      const KernelLookup& input_kernel =dynamic_cast<const KernelLookup&>(input);
133      assert(input.rows()==alpha_.size());
134      prediction.resize(2,input.columns(),0);
135      for (size_t i = 0; i<input.columns(); i++){
136        for (size_t j = 0; j<input.rows(); j++){
137          prediction(0,i) += target(j)*alpha_(j)*input_kernel(j,i);
138          assert(target(j));
139        }
140        prediction(0,i) = margin_ * (prediction(0,i) + bias_);
141      }
142     
143      for (size_t i = 0; i<prediction.columns(); i++)
144        prediction(1,i) = -prediction(0,i);
145    }
146    catch (std::bad_cast) {
147      std::string str = 
148        "Error in SVM::predict: DataLookup2D of unexpected class.";
149      throw std::runtime_error(str);
150    }
151   
152  }
153
154  double SVM::predict(const DataLookup1D& x) const
155  {
156    double y=0;
157    for (size_t i=0; i<alpha_.size(); i++)
158      y += alpha_(i)*target_(i)*kernel_->element(x,i);
159
160    return margin_*(y+bias_);
161  }
162
163  double SVM::predict(const DataLookupWeighted1D& x) const
164  {
165    double y=0;
166    for (size_t i=0; i<alpha_.size(); i++)
167      y += alpha_(i)*target_(i)*kernel_->element(x,i);
168
169    return margin_*(y+bias_);
170  }
171
172  int SVM::target(size_t i) const
173  {
174    assert(i<target_.size());
175    return target_.binary(i) ? 1 : -1;
176  }
177
178  void SVM::train(const KernelLookup& kernel, const Target& targ) 
179  {
180    if (kernel_)
181      delete kernel_;
182    kernel_ = new KernelLookup(kernel);
183    target_ = targ;
184   
185    alpha_ = utility::vector(targ.size(), 0.0);
186    output_ = utility::vector(targ.size(), 0.0);
187    // initializing variables for optimization
188    assert(target_.size()==kernel_->rows());
189    assert(target_.size()==alpha_.size());
190
191    sample_.init(alpha_,tolerance_);
192    utility::vector   E(target_.size(),0);
193    for (size_t i=0; i<E.size(); i++) {
194      for (size_t j=0; j<E.size(); j++) 
195        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
196      E(i)-=target(i);
197    }
198    assert(target_.size()==E.size());
199    assert(target_.size()==sample_.size());
200
201    unsigned long int epochs = 0;
202    double alpha_new2;
203    double alpha_new1;
204    double u;
205    double v;
206
207    // Training loop
208    while(choose(E)) {
209      bounds(u,v);       
210      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
211                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
212                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
213     
214      double alpha_old1=alpha_(sample_.value_first());
215      double alpha_old2=alpha_(sample_.value_second());
216      alpha_new2 = ( alpha_(sample_.value_second()) + 
217                     target(sample_.value_second())*
218                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
219     
220      if (alpha_new2 > v)
221        alpha_new2 = v;
222      else if (alpha_new2<u)
223        alpha_new2 = u;
224     
225      // Updating the alphas
226      // if alpha is 'zero' make the sample a non-support vector
227      if (alpha_new2 < tolerance_){
228        sample_.nsv_second();
229      }
230      else{
231        sample_.sv_second();
232      }
233     
234     
235      alpha_new1 = (alpha_(sample_.value_first()) + 
236                    (target(sample_.value_first()) * 
237                     target(sample_.value_second()) * 
238                     (alpha_(sample_.value_second()) - alpha_new2) ));
239           
240      // if alpha is 'zero' make the sample a non-support vector
241      if (alpha_new1 < tolerance_){
242        sample_.nsv_first();
243      }
244      else
245        sample_.sv_first();
246     
247      alpha_(sample_.value_first()) = alpha_new1;
248      alpha_(sample_.value_second()) = alpha_new2;
249     
250      // update E vector
251      // Peter, perhaps one should only update SVs, but what happens in choose?
252      for (size_t i=0; i<E.size(); i++) {
253        E(i)+=( kernel_mod(i,sample_.value_first())*
254                target(sample_.value_first()) *
255                (alpha_new1-alpha_old1) );
256        E(i)+=( kernel_mod(i,sample_.value_second())*
257                target(sample_.value_second()) *
258                (alpha_new2-alpha_old2) );
259      }
260           
261      epochs++; 
262      if (epochs>max_epochs_){
263        throw std::runtime_error("SVM: maximal number of epochs reached.");
264      }
265    }
266    calculate_margin();
267    calculate_bias();
268    trained_ = true;
269  }
270
271
272  bool SVM::choose(const theplu::yat::utility::vector& E)
273  {
274    // First check for violation among SVs
275    // E should be the same for all SVs
276    // Choose that pair having largest violation/difference.
277    sample_.update_second(0);
278    sample_.update_first(0);
279    if (sample_.nof_sv()>1){
280
281      double max = E(sample_(0));
282      double min = max;
283      for (size_t i=1; i<sample_.nof_sv(); i++){ 
284        assert(alpha_(sample_(i))>tolerance_);
285        if (E(sample_(i)) > max){
286          max = E(sample_(i));
287          sample_.update_second(i);
288        }
289        else if (E(sample_(i))<min){
290          min = E(sample_(i));
291          sample_.update_first(i);
292        }
293      }
294      assert(alpha_(sample_.value_first())>tolerance_);
295      assert(alpha_(sample_.value_second())>tolerance_);
296
297      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
298        return true;
299      }
300     
301      // If no violation check among non-support vectors
302      sample_.shuffle();
303      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
304        if (target_.binary(sample_(i))){
305          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
306            sample_.update_second(i);
307            return true;
308          }
309        }
310        else{
311          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
312            sample_.update_first(i);
313            return true;
314          }
315        }
316      }
317    }
318
319    // if no support vectors - special case
320    else{
321      // to avoid getting stuck we shuffle
322      sample_.shuffle();
323      for (size_t i=0; i<sample_.size(); i++) {
324        if (target(sample_(i))==1){
325          for (size_t j=0; j<sample_.size(); j++) {
326            if ( target(sample_(j))==-1 && 
327                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
328              sample_.update_first(i);
329              sample_.update_second(j);
330              return true;
331            }
332          }
333        }
334      }
335    }
336   
337    // If there is no violation then we should stop training
338    return false;
339
340  }
341 
342 
343  void SVM::bounds( double& u, double& v) const
344  {
345    if (target(sample_.value_first())!=target(sample_.value_second())) {
346      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
347        v = std::numeric_limits<double>::max();
348        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
349      }
350      else {
351        v = (std::numeric_limits<double>::max() - 
352             alpha_(sample_.value_first()) + 
353             alpha_(sample_.value_second()));
354        u = 0;
355      }
356    }
357    else {       
358      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
359           std::numeric_limits<double>::max()) {
360        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
361              std::numeric_limits<double>::max());
362        v =  std::numeric_limits<double>::max();   
363      }
364      else {
365        u = 0;
366        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
367      }
368    }
369  }
370 
371  void SVM::calculate_bias(void)
372  {
373
374    // calculating output without bias
375    for (size_t i=0; i<output_.size(); i++) {
376      output_(i)=0;
377      for (size_t j=0; j<output_.size(); j++) 
378        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
379    }
380
381    if (!sample_.nof_sv()){
382      std::stringstream ss;
383      ss << "yat::classifier::SVM::train() error: " 
384         << "Cannot calculate bias because there is no support vector"; 
385      throw std::runtime_error(ss.str());
386    }
387
388    // For samples with alpha>0, we have: target*output=1-alpha/C
389    bias_=0;
390    for (size_t i=0; i<sample_.nof_sv(); i++) 
391      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
392                output_(sample_(i)) );
393    bias_=bias_/sample_.nof_sv();
394    for (size_t i=0; i<output_.size(); i++) 
395      output_(i) += bias_;
396  }
397
398  void SVM::set_C(const double C)
399  {
400    C_inverse_ = 1/C;
401  }
402
403}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.