source: trunk/yat/classifier/SVM.cc @ 1108

Last change on this file since 1108 was 1108, checked in by Peter, 14 years ago

adding svm copy constructor

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.9 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "SVM.h"
27#include "KernelLookup.h"
28#include "Target.h"
29#include "yat/random/random.h"
30#include "yat/statistics/Averager.h"
31#include "yat/utility/matrix.h"
32#include "yat/utility/vector.h"
33
34#include <algorithm>
35#include <cassert>
36#include <cctype>
37#include <cmath>
38#include <limits>
39#include <sstream>
40#include <stdexcept>
41#include <string>
42#include <utility>
43#include <vector>
44
45namespace theplu {
46namespace yat {
47namespace classifier { 
48
49  SVM::SVM(void)
50    : bias_(0),
51      C_inverse_(0),
52      kernel_(NULL),
53      margin_(0),
54      max_epochs_(100000),
55      tolerance_(0.00000001),
56      trained_(false)
57  {
58  }
59
60
61  SVM::SVM(const SVM& other)
62    : bias_(other.bias_), C_inverse_(other.C_inverse_), kernel_(NULL),
63      margin_(0), max_epochs_(other.max_epochs_), tolerance_(other.tolerance_),
64      trained_(other.trained_)
65  {
66    if (other.kernel_)
67      kernel_ = new KernelLookup(*other.kernel_);
68  }
69
70
71  SVM::~SVM()
72  {
73    if (kernel_)
74      delete kernel_;
75  }
76
77
78  const utility::vector& SVM::alpha(void) const
79  {
80    return alpha_;
81  }
82
83
84  double SVM::C(void) const
85  {
86    return 1.0/C_inverse_;
87  }
88
89
90  void SVM::calculate_margin(void)
91  {
92    margin_ = 0;
93    for(size_t i = 0; i<alpha_.size(); ++i){
94      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
95      for(size_t j = i+1; j<alpha_.size(); ++j)
96        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
97    }
98  }
99
100
101  /*
102  const DataLookup2D& SVM::data(void) const
103  {
104    return *kernel_;
105  }
106  */
107
108
109  double SVM::kernel_mod(const size_t i, const size_t j) const
110  {
111    assert(kernel_);
112    assert(i<kernel_->rows());
113    assert(i<kernel_->columns());
114    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
115  }
116
117
118  SVM* SVM::make_classifier(void) const
119  {
120    return new SVM(*this);
121  }
122
123
124  long int SVM::max_epochs(void) const
125  {
126    return max_epochs_;
127  }
128
129
130  void SVM::max_epochs(long int n)
131  {
132    max_epochs_=n;
133  }
134
135
136  const theplu::yat::utility::vector& SVM::output(void) const
137  {
138    return output_;
139  }
140
141  void SVM::predict(const KernelLookup& input, utility::matrix& prediction) const
142  {
143    assert(input.rows()==alpha_.size());
144    prediction.resize(2,input.columns(),0);
145    for (size_t i = 0; i<input.columns(); i++){
146      for (size_t j = 0; j<input.rows(); j++){
147        prediction(0,i) += target(j)*alpha_(j)*input(j,i);
148        assert(target(j));
149      }
150      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
151    }
152   
153    for (size_t i = 0; i<prediction.columns(); i++)
154      prediction(1,i) = -prediction(0,i);
155  }
156
157  double SVM::predict(const DataLookup1D& x) const
158  {
159    double y=0;
160    for (size_t i=0; i<alpha_.size(); i++)
161      y += alpha_(i)*target_(i)*kernel_->element(x,i);
162
163    return margin_*(y+bias_);
164  }
165
166  double SVM::predict(const DataLookupWeighted1D& x) const
167  {
168    double y=0;
169    for (size_t i=0; i<alpha_.size(); i++)
170      y += alpha_(i)*target_(i)*kernel_->element(x,i);
171
172    return margin_*(y+bias_);
173  }
174
175  int SVM::target(size_t i) const
176  {
177    assert(i<target_.size());
178    return target_.binary(i) ? 1 : -1;
179  }
180
181  void SVM::train(const KernelLookup& kernel, const Target& targ) 
182  {
183    if (kernel_)
184      delete kernel_;
185    kernel_ = new KernelLookup(kernel);
186    target_ = targ;
187   
188    alpha_ = utility::vector(targ.size(), 0.0);
189    output_ = utility::vector(targ.size(), 0.0);
190    // initializing variables for optimization
191    assert(target_.size()==kernel_->rows());
192    assert(target_.size()==alpha_.size());
193
194    sample_.init(alpha_,tolerance_);
195    utility::vector   E(target_.size(),0);
196    for (size_t i=0; i<E.size(); i++) {
197      for (size_t j=0; j<E.size(); j++) 
198        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
199      E(i)-=target(i);
200    }
201    assert(target_.size()==E.size());
202    assert(target_.size()==sample_.size());
203
204    unsigned long int epochs = 0;
205    double alpha_new2;
206    double alpha_new1;
207    double u;
208    double v;
209
210    // Training loop
211    while(choose(E)) {
212      bounds(u,v);       
213      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
214                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
215                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
216     
217      double alpha_old1=alpha_(sample_.value_first());
218      double alpha_old2=alpha_(sample_.value_second());
219      alpha_new2 = ( alpha_(sample_.value_second()) + 
220                     target(sample_.value_second())*
221                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
222     
223      if (alpha_new2 > v)
224        alpha_new2 = v;
225      else if (alpha_new2<u)
226        alpha_new2 = u;
227     
228      // Updating the alphas
229      // if alpha is 'zero' make the sample a non-support vector
230      if (alpha_new2 < tolerance_){
231        sample_.nsv_second();
232      }
233      else{
234        sample_.sv_second();
235      }
236     
237     
238      alpha_new1 = (alpha_(sample_.value_first()) + 
239                    (target(sample_.value_first()) * 
240                     target(sample_.value_second()) * 
241                     (alpha_(sample_.value_second()) - alpha_new2) ));
242           
243      // if alpha is 'zero' make the sample a non-support vector
244      if (alpha_new1 < tolerance_){
245        sample_.nsv_first();
246      }
247      else
248        sample_.sv_first();
249     
250      alpha_(sample_.value_first()) = alpha_new1;
251      alpha_(sample_.value_second()) = alpha_new2;
252     
253      // update E vector
254      // Peter, perhaps one should only update SVs, but what happens in choose?
255      for (size_t i=0; i<E.size(); i++) {
256        E(i)+=( kernel_mod(i,sample_.value_first())*
257                target(sample_.value_first()) *
258                (alpha_new1-alpha_old1) );
259        E(i)+=( kernel_mod(i,sample_.value_second())*
260                target(sample_.value_second()) *
261                (alpha_new2-alpha_old2) );
262      }
263           
264      epochs++; 
265      if (epochs>max_epochs_){
266        throw std::runtime_error("SVM: maximal number of epochs reached.");
267      }
268    }
269    calculate_margin();
270    calculate_bias();
271    trained_ = true;
272  }
273
274
275  bool SVM::choose(const theplu::yat::utility::vector& E)
276  {
277    // First check for violation among SVs
278    // E should be the same for all SVs
279    // Choose that pair having largest violation/difference.
280    sample_.update_second(0);
281    sample_.update_first(0);
282    if (sample_.nof_sv()>1){
283
284      double max = E(sample_(0));
285      double min = max;
286      for (size_t i=1; i<sample_.nof_sv(); i++){ 
287        assert(alpha_(sample_(i))>tolerance_);
288        if (E(sample_(i)) > max){
289          max = E(sample_(i));
290          sample_.update_second(i);
291        }
292        else if (E(sample_(i))<min){
293          min = E(sample_(i));
294          sample_.update_first(i);
295        }
296      }
297      assert(alpha_(sample_.value_first())>tolerance_);
298      assert(alpha_(sample_.value_second())>tolerance_);
299
300      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
301        return true;
302      }
303     
304      // If no violation check among non-support vectors
305      sample_.shuffle();
306      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
307        if (target_.binary(sample_(i))){
308          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
309            sample_.update_second(i);
310            return true;
311          }
312        }
313        else{
314          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
315            sample_.update_first(i);
316            return true;
317          }
318        }
319      }
320    }
321
322    // if no support vectors - special case
323    else{
324      // to avoid getting stuck we shuffle
325      sample_.shuffle();
326      for (size_t i=0; i<sample_.size(); i++) {
327        if (target(sample_(i))==1){
328          for (size_t j=0; j<sample_.size(); j++) {
329            if ( target(sample_(j))==-1 && 
330                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
331              sample_.update_first(i);
332              sample_.update_second(j);
333              return true;
334            }
335          }
336        }
337      }
338    }
339   
340    // If there is no violation then we should stop training
341    return false;
342
343  }
344 
345 
346  void SVM::bounds( double& u, double& v) const
347  {
348    if (target(sample_.value_first())!=target(sample_.value_second())) {
349      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
350        v = std::numeric_limits<double>::max();
351        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
352      }
353      else {
354        v = (std::numeric_limits<double>::max() - 
355             alpha_(sample_.value_first()) + 
356             alpha_(sample_.value_second()));
357        u = 0;
358      }
359    }
360    else {       
361      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
362           std::numeric_limits<double>::max()) {
363        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
364              std::numeric_limits<double>::max());
365        v =  std::numeric_limits<double>::max();   
366      }
367      else {
368        u = 0;
369        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
370      }
371    }
372  }
373 
374  void SVM::calculate_bias(void)
375  {
376
377    // calculating output without bias
378    for (size_t i=0; i<output_.size(); i++) {
379      output_(i)=0;
380      for (size_t j=0; j<output_.size(); j++) 
381        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
382    }
383
384    if (!sample_.nof_sv()){
385      std::stringstream ss;
386      ss << "yat::classifier::SVM::train() error: " 
387         << "Cannot calculate bias because there is no support vector"; 
388      throw std::runtime_error(ss.str());
389    }
390
391    // For samples with alpha>0, we have: target*output=1-alpha/C
392    bias_=0;
393    for (size_t i=0; i<sample_.nof_sv(); i++) 
394      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
395                output_(sample_(i)) );
396    bias_=bias_/sample_.nof_sv();
397    for (size_t i=0; i<output_.size(); i++) 
398      output_(i) += bias_;
399  }
400
401  void SVM::set_C(const double C)
402  {
403    C_inverse_ = 1/C;
404  }
405
406}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.