source: trunk/yat/classifier/SVM.cc @ 865

Last change on this file since 865 was 865, checked in by Peter, 14 years ago

changing URL to http://trac.thep.lu.se/trac/yat

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.8 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "SVM.h"
27#include "DataLookup2D.h"
28#include "Target.h"
29#include "yat/random/random.h"
30#include "yat/statistics/Averager.h"
31#include "yat/utility/matrix.h"
32#include "yat/utility/vector.h"
33
34#include <algorithm>
35#include <cassert>
36#include <cctype>
37#include <cmath>
38#include <limits>
39#include <utility>
40#include <vector>
41
42namespace theplu {
43namespace yat {
44namespace classifier { 
45
46  SVM::SVM(const KernelLookup& kernel, const Target& target)
47    : SupervisedClassifier(target),
48      alpha_(target.size(),0),
49      bias_(0),
50      C_inverse_(0),
51      kernel_(&kernel),
52      margin_(0),
53      max_epochs_(100000),
54      output_(target.size(),0),
55      owner_(false),
56      sample_(target.size()),
57      trained_(false),
58      tolerance_(0.00000001)
59  {
60#ifndef NDEBUG
61    assert(kernel.columns()==kernel.rows());
62    assert(kernel.columns()==alpha_.size());
63    for (size_t i=0; i<alpha_.size(); i++) 
64      for (size_t j=0; j<alpha_.size(); j++)
65        assert(kernel(i,j)==kernel(j,i));
66    for (size_t i=0; i<alpha_.size(); i++) 
67      for (size_t j=0; j<alpha_.size(); j++) 
68        assert((*kernel_)(i,j)==(*kernel_)(j,i));
69    for (size_t i = 0; i<kernel_->rows(); i++)
70      for (size_t j = 0; j<kernel_->columns(); j++)
71        if (std::isnan((*kernel_)(i,j)))
72          std::cerr << "SVM: Found nan in kernel: " << i << " " 
73                    << j << std::endl;
74#endif       
75  }
76
77  SVM::~SVM()
78  {
79    if (owner_)
80      delete kernel_;
81  }
82
83  const utility::vector& SVM::alpha(void) const
84  {
85    return alpha_;
86  }
87
88  double SVM::C(void) const
89  {
90    return 1.0/C_inverse_;
91  }
92
93  void SVM::calculate_margin(void)
94  {
95    margin_ = 0;
96    for(size_t i = 0; i<alpha_.size(); ++i){
97      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
98      for(size_t j = i+1; j<alpha_.size(); ++j)
99        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
100    }
101  }
102
103  const DataLookup2D& SVM::data(void) const
104  {
105    return *kernel_;
106  }
107
108
109  double SVM::kernel_mod(const size_t i, const size_t j) const
110  {
111    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
112  }
113
114  SupervisedClassifier* SVM::make_classifier(const DataLookup2D& data,
115                                             const Target& target) const
116  {
117    SVM* sc=0;
118    try {
119      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
120      assert(data.rows()==data.columns());
121      assert(data.columns()==target.size());
122      sc = new SVM(kernel,target);
123
124    //Copy those variables possible to modify from outside
125    // Peter, in particular C
126    }
127    catch (std::bad_cast) {
128      std::cerr << "Warning: SVM::make_classifier only takes KernelLookup" 
129                << std::endl;
130    }
131    return sc;
132  }
133
134  long int SVM::max_epochs(void) const
135  {
136    return max_epochs_;
137  }
138
139  const theplu::yat::utility::vector& SVM::output(void) const
140  {
141    return output_;
142  }
143
144  void SVM::predict(const DataLookup2D& input, utility::matrix& prediction) const
145  {
146    // Peter, should check success of dynamic_cast
147    const KernelLookup& input_kernel = dynamic_cast<const KernelLookup&>(input);
148
149    assert(input.rows()==alpha_.size());
150    prediction.clone(utility::matrix(2,input.columns(),0));
151    for (size_t i = 0; i<input.columns(); i++){
152      for (size_t j = 0; j<input.rows(); j++){
153        prediction(0,i) += target(j)*alpha_(j)*input_kernel(j,i);
154        assert(target(j));
155      }
156      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
157    }
158
159    for (size_t i = 0; i<prediction.columns(); i++)
160      prediction(1,i) = -prediction(0,i);
161   
162    assert(prediction(0,0));
163  }
164
165  double SVM::predict(const DataLookup1D& x) const
166  {
167    double y=0;
168    for (size_t i=0; i<alpha_.size(); i++)
169      y += alpha_(i)*target_(i)*kernel_->element(x,i);
170
171    return margin_*(y+bias_);
172  }
173
174  double SVM::predict(const DataLookupWeighted1D& x) const
175  {
176    double y=0;
177    for (size_t i=0; i<alpha_.size(); i++)
178      y += alpha_(i)*target_(i)*kernel_->element(x,i);
179
180    return margin_*(y+bias_);
181  }
182
183  void SVM::reset(void)
184  {
185    trained_=false;
186    alpha_.clone(utility::vector(target_.size(), 0));
187  }
188
189  int SVM::target(size_t i) const
190  {
191    return target_.binary(i) ? 1 : -1;
192  }
193
194  bool SVM::train(void) 
195  {
196    // initializing variables for optimization
197    assert(target_.size()==kernel_->rows());
198    assert(target_.size()==alpha_.size());
199
200    sample_.init(alpha_,tolerance_);
201    utility::vector   E(target_.size(),0);
202    for (size_t i=0; i<E.size(); i++) {
203      for (size_t j=0; j<E.size(); j++) 
204        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
205      E(i)-=target(i);
206    }
207    assert(target_.size()==E.size());
208    assert(target_.size()==sample_.size());
209
210    unsigned long int epochs = 0;
211    double alpha_new2;
212    double alpha_new1;
213    double u;
214    double v;
215
216    // Training loop
217    while(choose(E)) {
218      bounds(u,v);       
219      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
220                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
221                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
222     
223      double alpha_old1=alpha_(sample_.value_first());
224      double alpha_old2=alpha_(sample_.value_second());
225      alpha_new2 = ( alpha_(sample_.value_second()) + 
226                     target(sample_.value_second())*
227                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
228     
229      if (alpha_new2 > v)
230        alpha_new2 = v;
231      else if (alpha_new2<u)
232        alpha_new2 = u;
233     
234      // Updating the alphas
235      // if alpha is 'zero' make the sample a non-support vector
236      if (alpha_new2 < tolerance_){
237        sample_.nsv_second();
238      }
239      else{
240        sample_.sv_second();
241      }
242     
243     
244      alpha_new1 = (alpha_(sample_.value_first()) + 
245                    (target(sample_.value_first()) * 
246                     target(sample_.value_second()) * 
247                     (alpha_(sample_.value_second()) - alpha_new2) ));
248           
249      // if alpha is 'zero' make the sample a non-support vector
250      if (alpha_new1 < tolerance_){
251        sample_.nsv_first();
252      }
253      else
254        sample_.sv_first();
255     
256      alpha_(sample_.value_first()) = alpha_new1;
257      alpha_(sample_.value_second()) = alpha_new2;
258     
259      // update E vector
260      // Peter, perhaps one should only update SVs, but what happens in choose?
261      for (size_t i=0; i<E.size(); i++) {
262        E(i)+=( kernel_mod(i,sample_.value_first())*
263                target(sample_.value_first()) *
264                (alpha_new1-alpha_old1) );
265        E(i)+=( kernel_mod(i,sample_.value_second())*
266                target(sample_.value_second()) *
267                (alpha_new2-alpha_old2) );
268      }
269           
270      epochs++; 
271      if (epochs>max_epochs_){
272        std::cerr << "WARNING: SVM: maximal number of epochs reached.\n";
273        calculate_bias();
274        calculate_margin();
275        return false;
276      }
277    }
278    calculate_margin();
279    trained_ = calculate_bias();
280    return trained_;
281  }
282
283
284  bool SVM::choose(const theplu::yat::utility::vector& E)
285  {
286    // First check for violation among SVs
287    // E should be the same for all SVs
288    // Choose that pair having largest violation/difference.
289    sample_.update_second(0);
290    sample_.update_first(0);
291    if (sample_.nof_sv()>1){
292
293      double max = E(sample_(0));
294      double min = max;
295      for (size_t i=1; i<sample_.nof_sv(); i++){ 
296        assert(alpha_(sample_(i))>tolerance_);
297        if (E(sample_(i)) > max){
298          max = E(sample_(i));
299          sample_.update_second(i);
300        }
301        else if (E(sample_(i))<min){
302          min = E(sample_(i));
303          sample_.update_first(i);
304        }
305      }
306      assert(alpha_(sample_.value_first())>tolerance_);
307      assert(alpha_(sample_.value_second())>tolerance_);
308
309      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
310        return true;
311      }
312     
313      // If no violation check among non-support vectors
314      sample_.shuffle();
315      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
316        if (target_.binary(sample_(i))){
317          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
318            sample_.update_second(i);
319            return true;
320          }
321        }
322        else{
323          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
324            sample_.update_first(i);
325            return true;
326          }
327        }
328      }
329    }
330
331    // if no support vectors - special case
332    else{
333      // to avoid getting stuck we shuffle
334      sample_.shuffle();
335      for (size_t i=0; i<sample_.size(); i++) {
336        if (target(sample_(i))==1){
337          for (size_t j=0; j<sample_.size(); j++) {
338            if ( target(sample_(j))==-1 && 
339                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
340              sample_.update_first(i);
341              sample_.update_second(j);
342              return true;
343            }
344          }
345        }
346      }
347    }
348   
349    // If there is no violation then we should stop training
350    return false;
351
352  }
353 
354 
355  void SVM::bounds( double& u, double& v) const
356  {
357    if (target(sample_.value_first())!=target(sample_.value_second())) {
358      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
359        v = std::numeric_limits<double>::max();
360        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
361      }
362      else {
363        v = (std::numeric_limits<double>::max() - 
364             alpha_(sample_.value_first()) + 
365             alpha_(sample_.value_second()));
366        u = 0;
367      }
368    }
369    else {       
370      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
371           std::numeric_limits<double>::max()) {
372        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
373              std::numeric_limits<double>::max());
374        v =  std::numeric_limits<double>::max();   
375      }
376      else {
377        u = 0;
378        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
379      }
380    }
381  }
382 
383  bool SVM::calculate_bias(void)
384  {
385
386    // calculating output without bias
387    for (size_t i=0; i<output_.size(); i++) {
388      output_(i)=0;
389      for (size_t j=0; j<output_.size(); j++) 
390        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
391    }
392
393    if (!sample_.nof_sv()){
394      std::cerr << "SVM::train() error: " 
395                << "Cannot calculate bias because there is no support vector" 
396                << std::endl;
397      return false;
398    }
399
400    // For samples with alpha>0, we have: target*output=1-alpha/C
401    bias_=0;
402    for (size_t i=0; i<sample_.nof_sv(); i++) 
403      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
404                output_(sample_(i)) );
405    bias_=bias_/sample_.nof_sv();
406    for (size_t i=0; i<output_.size(); i++) 
407      output_(i) += bias_;
408     
409    return true;
410  }
411
412  void SVM::set_C(const double C)
413  {
414    C_inverse_ = 1/C;
415  }
416
417}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.