source: trunk/yat/classifier/SVM.cc @ 1592

Last change on this file since 1592 was 1592, checked in by Peter, 13 years ago

cleaning up some includes

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.8 KB
RevLine 
[30]1// $Id$
2
[675]3/*
[831]4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
[1275]5  Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
[831]6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
[1275]7  Copyright (C) 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
[295]8
[1437]9  This file is part of the yat library, http://dev.thep.lu.se/yat
[295]10
[675]11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
[1486]13  published by the Free Software Foundation; either version 3 of the
[675]14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
[1487]22  along with yat. If not, see <http://www.gnu.org/licenses/>.
[675]23*/
24
[680]25#include "SVM.h"
[1102]26#include "KernelLookup.h"
[747]27#include "Target.h"
[1121]28#include "yat/utility/Matrix.h"
[1120]29#include "yat/utility/Vector.h"
[675]30
[323]31#include <algorithm>
[463]32#include <cassert>
[569]33#include <cctype>
[55]34#include <cmath>
[323]35#include <limits>
[1049]36#include <sstream>
[964]37#include <stdexcept>
[1100]38#include <string>
[47]39#include <utility>
[61]40#include <vector>
[30]41
[42]42namespace theplu {
[680]43namespace yat {
[450]44namespace classifier { 
[30]45
[1100]46  SVM::SVM(void)
47    : bias_(0),
[491]48      C_inverse_(0),
[1100]49      kernel_(NULL),
[569]50      margin_(0),
[568]51      max_epochs_(100000),
[1100]52      tolerance_(0.00000001),
53      trained_(false)
[68]54  {
55  }
[30]56
[1100]57
[1108]58  SVM::SVM(const SVM& other)
[1177]59    : bias_(other.bias_), C_inverse_(other.C_inverse_), kernel_(other.kernel_),
[1108]60      margin_(0), max_epochs_(other.max_epochs_), tolerance_(other.tolerance_),
61      trained_(other.trained_)
62  {
63  }
64
65
[547]66  SVM::~SVM()
67  {
68  }
69
[1100]70
[1120]71  const utility::Vector& SVM::alpha(void) const
[720]72  {
73    return alpha_;
74  }
[547]75
[1100]76
[720]77  double SVM::C(void) const
78  {
79    return 1.0/C_inverse_;
80  }
81
[1100]82
[569]83  void SVM::calculate_margin(void)
84  {
85    margin_ = 0;
86    for(size_t i = 0; i<alpha_.size(); ++i){
87      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
88      for(size_t j = i+1; j<alpha_.size(); ++j)
89        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
90    }
91  }
92
[1100]93
[1102]94  /*
[722]95  const DataLookup2D& SVM::data(void) const
96  {
97    return *kernel_;
98  }
[1102]99  */
[722]100
101
[720]102  double SVM::kernel_mod(const size_t i, const size_t j) const
103  {
[1100]104    assert(kernel_);
105    assert(i<kernel_->rows());
106    assert(i<kernel_->columns());
[720]107    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
108  }
[569]109
[1100]110
111  SVM* SVM::make_classifier(void) const
[493]112  {
[1175]113    SVM* svm = new SVM(*this);
114    svm->trained_ = false;
115    return svm;
[493]116  }
117
[1100]118
[720]119  long int SVM::max_epochs(void) const
120  {
121    return max_epochs_;
122  }
123
[964]124
125  void SVM::max_epochs(long int n)
126  {
127    max_epochs_=n;
128  }
129
130
[1120]131  const utility::Vector& SVM::output(void) const
[720]132  {
133    return output_;
134  }
135
[1121]136  void SVM::predict(const KernelLookup& input, utility::Matrix& prediction) const
[523]137  {
[1102]138    assert(input.rows()==alpha_.size());
139    prediction.resize(2,input.columns(),0);
140    for (size_t i = 0; i<input.columns(); i++){
141      for (size_t j = 0; j<input.rows(); j++){
142        prediction(0,i) += target(j)*alpha_(j)*input(j,i);
143        assert(target(j));
[559]144      }
[1102]145      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
[523]146    }
[559]147   
[1102]148    for (size_t i = 0; i<prediction.columns(); i++)
149      prediction(1,i) = -prediction(0,i);
[523]150  }
151
[1200]152  /*
[539]153  double SVM::predict(const DataLookup1D& x) const
[523]154  {
155    double y=0;
[539]156    for (size_t i=0; i<alpha_.size(); i++)
157      y += alpha_(i)*target_(i)*kernel_->element(x,i);
[523]158
[569]159    return margin_*(y+bias_);
[523]160  }
161
[628]162  double SVM::predict(const DataLookupWeighted1D& x) const
[542]163  {
164    double y=0;
165    for (size_t i=0; i<alpha_.size(); i++)
[628]166      y += alpha_(i)*target_(i)*kernel_->element(x,i);
[542]167
[569]168    return margin_*(y+bias_);
[542]169  }
[1200]170  */
[542]171
[720]172  int SVM::target(size_t i) const
173  {
[1100]174    assert(i<target_.size());
[720]175    return target_.binary(i) ? 1 : -1;
176  }
177
[1100]178  void SVM::train(const KernelLookup& kernel, const Target& targ) 
[114]179  {
[1100]180    kernel_ = new KernelLookup(kernel);
181    target_ = targ;
182   
[1120]183    alpha_ = utility::Vector(targ.size(), 0.0);
184    output_ = utility::Vector(targ.size(), 0.0);
[323]185    // initializing variables for optimization
[523]186    assert(target_.size()==kernel_->rows());
[323]187    assert(target_.size()==alpha_.size());
[167]188
[323]189    sample_.init(alpha_,tolerance_);
[1120]190    utility::Vector   E(target_.size(),0);
[323]191    for (size_t i=0; i<E.size(); i++) {
192      for (size_t j=0; j<E.size(); j++) 
[509]193        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
[559]194      E(i)-=target(i);
[162]195    }
[323]196    assert(target_.size()==E.size());
[568]197    assert(target_.size()==sample_.size());
[114]198
[162]199    unsigned long int epochs = 0;
200    double alpha_new2;
201    double alpha_new1;
202    double u;
203    double v;
[323]204
[162]205    // Training loop
[323]206    while(choose(E)) {
207      bounds(u,v);       
208      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
209                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
210                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
211     
212      double alpha_old1=alpha_(sample_.value_first());
213      double alpha_old2=alpha_(sample_.value_second());
214      alpha_new2 = ( alpha_(sample_.value_second()) + 
[509]215                     target(sample_.value_second())*
[323]216                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
217     
[162]218      if (alpha_new2 > v)
219        alpha_new2 = v;
220      else if (alpha_new2<u)
221        alpha_new2 = u;
[323]222     
223      // Updating the alphas
224      // if alpha is 'zero' make the sample a non-support vector
225      if (alpha_new2 < tolerance_){
226        sample_.nsv_second();
227      }
228      else{
229        sample_.sv_second();
230      }
231     
232     
233      alpha_new1 = (alpha_(sample_.value_first()) + 
[509]234                    (target(sample_.value_first()) * 
235                     target(sample_.value_second()) * 
[323]236                     (alpha_(sample_.value_second()) - alpha_new2) ));
237           
238      // if alpha is 'zero' make the sample a non-support vector
[162]239      if (alpha_new1 < tolerance_){
[323]240        sample_.nsv_first();
[162]241      }
[323]242      else
243        sample_.sv_first();
244     
245      alpha_(sample_.value_first()) = alpha_new1;
246      alpha_(sample_.value_second()) = alpha_new2;
247     
248      // update E vector
249      // Peter, perhaps one should only update SVs, but what happens in choose?
250      for (size_t i=0; i<E.size(); i++) {
251        E(i)+=( kernel_mod(i,sample_.value_first())*
[509]252                target(sample_.value_first()) *
[323]253                (alpha_new1-alpha_old1) );
254        E(i)+=( kernel_mod(i,sample_.value_second())*
[509]255                target(sample_.value_second()) *
[323]256                (alpha_new2-alpha_old2) );
257      }
258           
[162]259      epochs++; 
[520]260      if (epochs>max_epochs_){
[964]261        throw std::runtime_error("SVM: maximal number of epochs reached.");
[520]262      }
[323]263    }
[569]264    calculate_margin();
[1049]265    calculate_bias();
266    trained_ = true;
[323]267  }
268
269
[1120]270  bool SVM::choose(const theplu::yat::utility::Vector& E)
[323]271  {
272    // First check for violation among SVs
273    // E should be the same for all SVs
274    // Choose that pair having largest violation/difference.
275    sample_.update_second(0);
276    sample_.update_first(0);
277    if (sample_.nof_sv()>1){
[514]278
[323]279      double max = E(sample_(0));
280      double min = max;
281      for (size_t i=1; i<sample_.nof_sv(); i++){ 
282        assert(alpha_(sample_(i))>tolerance_);
283        if (E(sample_(i)) > max){
284          max = E(sample_(i));
285          sample_.update_second(i);
286        }
287        else if (E(sample_(i))<min){
288          min = E(sample_(i));
289          sample_.update_first(i);
290        }
[162]291      }
[323]292      assert(alpha_(sample_.value_first())>tolerance_);
293      assert(alpha_(sample_.value_second())>tolerance_);
294
295      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
296        return true;
297      }
298     
299      // If no violation check among non-support vectors
300      sample_.shuffle();
[568]301      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
[514]302        if (target_.binary(sample_(i))){
[323]303          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
304            sample_.update_second(i);
305            return true;
306          }
307        }
308        else{
309          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
310            sample_.update_first(i);
311            return true;
312          }
313        }
314      }
[68]315    }
[37]316
[323]317    // if no support vectors - special case
318    else{
[559]319      // to avoid getting stuck we shuffle
320      sample_.shuffle();
[568]321      for (size_t i=0; i<sample_.size(); i++) {
322        if (target(sample_(i))==1){
323          for (size_t j=0; j<sample_.size(); j++) {
324            if ( target(sample_(j))==-1 && 
[323]325                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
326              sample_.update_first(i);
327              sample_.update_second(j);
328              return true;
329            }
330          }
331        }
332      }
333    }
[162]334   
[323]335    // If there is no violation then we should stop training
336    return false;
337
[114]338  }
[68]339 
[323]340 
341  void SVM::bounds( double& u, double& v) const
342  {
[509]343    if (target(sample_.value_first())!=target(sample_.value_second())) {
[323]344      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
345        v = std::numeric_limits<double>::max();
346        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
347      }
348      else {
349        v = (std::numeric_limits<double>::max() - 
350             alpha_(sample_.value_first()) + 
351             alpha_(sample_.value_second()));
352        u = 0;
353      }
354    }
355    else {       
356      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
357           std::numeric_limits<double>::max()) {
358        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
359              std::numeric_limits<double>::max());
360        v =  std::numeric_limits<double>::max();   
361      }
362      else {
363        u = 0;
364        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
365      }
366    }
367  }
368 
[1049]369  void SVM::calculate_bias(void)
[323]370  {
[47]371
[323]372    // calculating output without bias
373    for (size_t i=0; i<output_.size(); i++) {
374      output_(i)=0;
375      for (size_t j=0; j<output_.size(); j++) 
[523]376        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
[323]377    }
378
379    if (!sample_.nof_sv()){
[1049]380      std::stringstream ss;
381      ss << "yat::classifier::SVM::train() error: " 
382         << "Cannot calculate bias because there is no support vector"; 
383      throw std::runtime_error(ss.str());
[323]384    }
385
386    // For samples with alpha>0, we have: target*output=1-alpha/C
387    bias_=0;
388    for (size_t i=0; i<sample_.nof_sv(); i++) 
[509]389      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
[323]390                output_(sample_(i)) );
391    bias_=bias_/sample_.nof_sv();
392    for (size_t i=0; i<output_.size(); i++) 
393      output_(i) += bias_;
394  }
395
[571]396  void SVM::set_C(const double C)
397  {
398    C_inverse_ = 1/C;
399  }
400
[680]401}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.