source: trunk/yat/classifier/SVM.cc @ 680

Last change on this file since 680 was 680, checked in by Jari Häkkinen, 15 years ago

Addresses #153. Introduced yat namespace. Removed alignment namespace. Clean up of code.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.0 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "SVM.h"
25#include "DataLookup2D.h"
26#include "yat/random/random.h"
27#include "yat/statistics/Averager.h"
28#include "yat/utility/matrix.h"
29#include "yat/utility/vector.h"
30
31#include <algorithm>
32#include <cassert>
33#include <cctype>
34#include <cmath>
35#include <limits>
36#include <utility>
37#include <vector>
38
39namespace theplu {
40namespace yat {
41namespace classifier { 
42
43  SVM::SVM(const KernelLookup& kernel, const Target& target)
44    : SupervisedClassifier(target),
45      alpha_(target.size(),0),
46      bias_(0),
47      C_inverse_(0),
48      kernel_(&kernel),
49      margin_(0),
50      max_epochs_(100000),
51      output_(target.size(),0),
52      owner_(false),
53      sample_(target.size()),
54      trained_(false),
55      tolerance_(0.00000001)
56  {
57#ifndef NDEBUG
58    assert(kernel.columns()==kernel.rows());
59    assert(kernel.columns()==alpha_.size());
60    for (size_t i=0; i<alpha_.size(); i++) 
61      for (size_t j=0; j<alpha_.size(); j++)
62        assert(kernel(i,j)==kernel(j,i));
63    for (size_t i=0; i<alpha_.size(); i++) 
64      for (size_t j=0; j<alpha_.size(); j++) 
65        assert((*kernel_)(i,j)==(*kernel_)(j,i));
66    for (size_t i = 0; i<kernel_->rows(); i++)
67      for (size_t j = 0; j<kernel_->columns(); j++)
68        if (std::isnan((*kernel_)(i,j)))
69          std::cerr << "SVM: Found nan in kernel: " << i << " " 
70                    << j << std::endl;
71#endif       
72  }
73
74  SVM::~SVM()
75  {
76    if (owner_)
77      delete kernel_;
78  }
79
80
81  void SVM::calculate_margin(void)
82  {
83    margin_ = 0;
84    for(size_t i = 0; i<alpha_.size(); ++i){
85      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
86      for(size_t j = i+1; j<alpha_.size(); ++j)
87        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
88    }
89  }
90
91
92  SupervisedClassifier* SVM::make_classifier(const DataLookup2D& data,
93                                             const Target& target) const
94  {
95    SVM* sc=0;
96    try {
97      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
98      assert(data.rows()==data.columns());
99      assert(data.columns()==target.size());
100      sc = new SVM(kernel,target);
101
102    //Copy those variables possible to modify from outside
103    // Peter, in particular C
104    }
105    catch (std::bad_cast) {
106      std::cerr << "Warning: SVM::make_classifier only takes KernelLookup" 
107                << std::endl;
108    }
109    return sc;
110  }
111
112  void SVM::predict(const DataLookup2D& input, utility::matrix& prediction) const
113  {
114    // Peter, should check success of dynamic_cast
115    const KernelLookup& input_kernel = dynamic_cast<const KernelLookup&>(input);
116
117    assert(input.rows()==alpha_.size());
118    prediction = utility::matrix(2,input.columns(),0);
119    for (size_t i = 0; i<input.columns(); i++){
120      for (size_t j = 0; j<input.rows(); j++){
121        prediction(0,i) += target(j)*alpha_(j)*input_kernel(j,i);
122        assert(target(j));
123      }
124      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
125    }
126
127    for (size_t i = 0; i<prediction.columns(); i++)
128      prediction(1,i) = -prediction(0,i);
129   
130    assert(prediction(0,0));
131  }
132
133  double SVM::predict(const DataLookup1D& x) const
134  {
135    double y=0;
136    for (size_t i=0; i<alpha_.size(); i++)
137      y += alpha_(i)*target_(i)*kernel_->element(x,i);
138
139    return margin_*(y+bias_);
140  }
141
142  double SVM::predict(const DataLookupWeighted1D& x) const
143  {
144    double y=0;
145    for (size_t i=0; i<alpha_.size(); i++)
146      y += alpha_(i)*target_(i)*kernel_->element(x,i);
147
148    return margin_*(y+bias_);
149  }
150
151  bool SVM::train(void) 
152  {
153    // initializing variables for optimization
154    assert(target_.size()==kernel_->rows());
155    assert(target_.size()==alpha_.size());
156
157    sample_.init(alpha_,tolerance_);
158    utility::vector   E(target_.size(),0);
159    for (size_t i=0; i<E.size(); i++) {
160      for (size_t j=0; j<E.size(); j++) 
161        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
162      E(i)-=target(i);
163    }
164    assert(target_.size()==E.size());
165    assert(target_.size()==sample_.size());
166
167    unsigned long int epochs = 0;
168    double alpha_new2;
169    double alpha_new1;
170    double u;
171    double v;
172
173    // Training loop
174    while(choose(E)) {
175      bounds(u,v);       
176      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
177                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
178                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
179     
180      double alpha_old1=alpha_(sample_.value_first());
181      double alpha_old2=alpha_(sample_.value_second());
182      alpha_new2 = ( alpha_(sample_.value_second()) + 
183                     target(sample_.value_second())*
184                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
185     
186      if (alpha_new2 > v)
187        alpha_new2 = v;
188      else if (alpha_new2<u)
189        alpha_new2 = u;
190     
191      // Updating the alphas
192      // if alpha is 'zero' make the sample a non-support vector
193      if (alpha_new2 < tolerance_){
194        sample_.nsv_second();
195      }
196      else{
197        sample_.sv_second();
198      }
199     
200     
201      alpha_new1 = (alpha_(sample_.value_first()) + 
202                    (target(sample_.value_first()) * 
203                     target(sample_.value_second()) * 
204                     (alpha_(sample_.value_second()) - alpha_new2) ));
205           
206      // if alpha is 'zero' make the sample a non-support vector
207      if (alpha_new1 < tolerance_){
208        sample_.nsv_first();
209      }
210      else
211        sample_.sv_first();
212     
213      alpha_(sample_.value_first()) = alpha_new1;
214      alpha_(sample_.value_second()) = alpha_new2;
215     
216      // update E vector
217      // Peter, perhaps one should only update SVs, but what happens in choose?
218      for (size_t i=0; i<E.size(); i++) {
219        E(i)+=( kernel_mod(i,sample_.value_first())*
220                target(sample_.value_first()) *
221                (alpha_new1-alpha_old1) );
222        E(i)+=( kernel_mod(i,sample_.value_second())*
223                target(sample_.value_second()) *
224                (alpha_new2-alpha_old2) );
225      }
226           
227      epochs++; 
228      if (epochs>max_epochs_){
229        std::cerr << "WARNING: SVM: maximal number of epochs reached.\n";
230        calculate_bias();
231        calculate_margin();
232        return false;
233      }
234    }
235    calculate_margin();
236    trained_ = calculate_bias();
237    return trained_;
238  }
239
240
241  bool SVM::choose(const theplu::yat::utility::vector& E)
242  {
243    // First check for violation among SVs
244    // E should be the same for all SVs
245    // Choose that pair having largest violation/difference.
246    sample_.update_second(0);
247    sample_.update_first(0);
248    if (sample_.nof_sv()>1){
249
250      double max = E(sample_(0));
251      double min = max;
252      for (size_t i=1; i<sample_.nof_sv(); i++){ 
253        assert(alpha_(sample_(i))>tolerance_);
254        if (E(sample_(i)) > max){
255          max = E(sample_(i));
256          sample_.update_second(i);
257        }
258        else if (E(sample_(i))<min){
259          min = E(sample_(i));
260          sample_.update_first(i);
261        }
262      }
263      assert(alpha_(sample_.value_first())>tolerance_);
264      assert(alpha_(sample_.value_second())>tolerance_);
265
266      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
267        return true;
268      }
269     
270      // If no violation check among non-support vectors
271      sample_.shuffle();
272      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
273        if (target_.binary(sample_(i))){
274          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
275            sample_.update_second(i);
276            return true;
277          }
278        }
279        else{
280          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
281            sample_.update_first(i);
282            return true;
283          }
284        }
285      }
286    }
287
288    // if no support vectors - special case
289    else{
290      // to avoid getting stuck we shuffle
291      sample_.shuffle();
292      for (size_t i=0; i<sample_.size(); i++) {
293        if (target(sample_(i))==1){
294          for (size_t j=0; j<sample_.size(); j++) {
295            if ( target(sample_(j))==-1 && 
296                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
297              sample_.update_first(i);
298              sample_.update_second(j);
299              return true;
300            }
301          }
302        }
303      }
304    }
305   
306    // If there is no violation then we should stop training
307    return false;
308
309  }
310 
311 
312  void SVM::bounds( double& u, double& v) const
313  {
314    if (target(sample_.value_first())!=target(sample_.value_second())) {
315      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
316        v = std::numeric_limits<double>::max();
317        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
318      }
319      else {
320        v = (std::numeric_limits<double>::max() - 
321             alpha_(sample_.value_first()) + 
322             alpha_(sample_.value_second()));
323        u = 0;
324      }
325    }
326    else {       
327      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
328           std::numeric_limits<double>::max()) {
329        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
330              std::numeric_limits<double>::max());
331        v =  std::numeric_limits<double>::max();   
332      }
333      else {
334        u = 0;
335        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
336      }
337    }
338  }
339 
340  bool SVM::calculate_bias(void)
341  {
342
343    // calculating output without bias
344    for (size_t i=0; i<output_.size(); i++) {
345      output_(i)=0;
346      for (size_t j=0; j<output_.size(); j++) 
347        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
348    }
349
350    if (!sample_.nof_sv()){
351      std::cerr << "SVM::train() error: " 
352                << "Cannot calculate bias because there is no support vector" 
353                << std::endl;
354      return false;
355    }
356
357    // For samples with alpha>0, we have: target*output=1-alpha/C
358    bias_=0;
359    for (size_t i=0; i<sample_.nof_sv(); i++) 
360      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
361                output_(sample_(i)) );
362    bias_=bias_/sample_.nof_sv();
363    for (size_t i=0; i<output_.size(); i++) 
364      output_(i) += bias_;
365     
366    return true;
367  }
368
369  void SVM::set_C(const double C)
370  {
371    C_inverse_ = 1/C;
372  }
373
374
375}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.