source: trunk/yat/classifier/SVM.cc @ 676

Last change on this file since 676 was 675, checked in by Jari Häkkinen, 15 years ago

References #83. Changing project name to yat. Compilation will fail in this revision.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.0 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "yat/classifier/SVM.h"
25
26#include "yat/classifier/DataLookup2D.h"
27#include "yat/random/random.h"
28#include "yat/statistics/Averager.h"
29#include "yat/utility/matrix.h"
30#include "yat/utility/vector.h"
31
32#include <algorithm>
33#include <cassert>
34#include <cctype>
35#include <cmath>
36#include <limits>
37#include <utility>
38#include <vector>
39
40
41namespace theplu {
42namespace classifier { 
43
44  SVM::SVM(const KernelLookup& kernel, const Target& target)
45    : SupervisedClassifier(target),
46      alpha_(target.size(),0),
47      bias_(0),
48      C_inverse_(0),
49      kernel_(&kernel),
50      margin_(0),
51      max_epochs_(100000),
52      output_(target.size(),0),
53      owner_(false),
54      sample_(target.size()),
55      trained_(false),
56      tolerance_(0.00000001)
57  {
58#ifndef NDEBUG
59    assert(kernel.columns()==kernel.rows());
60    assert(kernel.columns()==alpha_.size());
61    for (size_t i=0; i<alpha_.size(); i++) 
62      for (size_t j=0; j<alpha_.size(); j++)
63        assert(kernel(i,j)==kernel(j,i));
64    for (size_t i=0; i<alpha_.size(); i++) 
65      for (size_t j=0; j<alpha_.size(); j++) 
66        assert((*kernel_)(i,j)==(*kernel_)(j,i));
67    for (size_t i = 0; i<kernel_->rows(); i++)
68      for (size_t j = 0; j<kernel_->columns(); j++)
69        if (std::isnan((*kernel_)(i,j)))
70          std::cerr << "SVM: Found nan in kernel: " << i << " " 
71                    << j << std::endl;
72#endif       
73  }
74
75  SVM::~SVM()
76  {
77    if (owner_)
78      delete kernel_;
79  }
80
81
82  void SVM::calculate_margin(void)
83  {
84    margin_ = 0;
85    for(size_t i = 0; i<alpha_.size(); ++i){
86      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
87      for(size_t j = i+1; j<alpha_.size(); ++j)
88        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
89    }
90  }
91
92
93  SupervisedClassifier* SVM::make_classifier(const DataLookup2D& data,
94                                             const Target& target) const
95  {
96    SVM* sc=0;
97    try {
98      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
99      assert(data.rows()==data.columns());
100      assert(data.columns()==target.size());
101      sc = new SVM(kernel,target);
102
103    //Copy those variables possible to modify from outside
104    // Peter, in particular C
105    }
106    catch (std::bad_cast) {
107      std::cerr << "Warning: SVM::make_classifier only takes KernelLookup" 
108                << std::endl;
109    }
110    return sc;
111  }
112
113  void SVM::predict(const DataLookup2D& input, utility::matrix& prediction) const
114  {
115    // Peter, should check success of dynamic_cast
116    const KernelLookup& input_kernel = dynamic_cast<const KernelLookup&>(input);
117
118    assert(input.rows()==alpha_.size());
119    prediction = utility::matrix(2,input.columns(),0);
120    for (size_t i = 0; i<input.columns(); i++){
121      for (size_t j = 0; j<input.rows(); j++){
122        prediction(0,i) += target(j)*alpha_(j)*input_kernel(j,i);
123        assert(target(j));
124      }
125      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
126    }
127
128    for (size_t i = 0; i<prediction.columns(); i++)
129      prediction(1,i) = -prediction(0,i);
130   
131    assert(prediction(0,0));
132  }
133
134  double SVM::predict(const DataLookup1D& x) const
135  {
136    double y=0;
137    for (size_t i=0; i<alpha_.size(); i++)
138      y += alpha_(i)*target_(i)*kernel_->element(x,i);
139
140    return margin_*(y+bias_);
141  }
142
143  double SVM::predict(const DataLookupWeighted1D& x) const
144  {
145    double y=0;
146    for (size_t i=0; i<alpha_.size(); i++)
147      y += alpha_(i)*target_(i)*kernel_->element(x,i);
148
149    return margin_*(y+bias_);
150  }
151
152  bool SVM::train(void) 
153  {
154    // initializing variables for optimization
155    assert(target_.size()==kernel_->rows());
156    assert(target_.size()==alpha_.size());
157
158    sample_.init(alpha_,tolerance_);
159    utility::vector   E(target_.size(),0);
160    for (size_t i=0; i<E.size(); i++) {
161      for (size_t j=0; j<E.size(); j++) 
162        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
163      E(i)-=target(i);
164    }
165    assert(target_.size()==E.size());
166    assert(target_.size()==sample_.size());
167
168    unsigned long int epochs = 0;
169    double alpha_new2;
170    double alpha_new1;
171    double u;
172    double v;
173
174    // Training loop
175    while(choose(E)) {
176      bounds(u,v);       
177      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
178                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
179                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
180     
181      double alpha_old1=alpha_(sample_.value_first());
182      double alpha_old2=alpha_(sample_.value_second());
183      alpha_new2 = ( alpha_(sample_.value_second()) + 
184                     target(sample_.value_second())*
185                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
186     
187      if (alpha_new2 > v)
188        alpha_new2 = v;
189      else if (alpha_new2<u)
190        alpha_new2 = u;
191     
192      // Updating the alphas
193      // if alpha is 'zero' make the sample a non-support vector
194      if (alpha_new2 < tolerance_){
195        sample_.nsv_second();
196      }
197      else{
198        sample_.sv_second();
199      }
200     
201     
202      alpha_new1 = (alpha_(sample_.value_first()) + 
203                    (target(sample_.value_first()) * 
204                     target(sample_.value_second()) * 
205                     (alpha_(sample_.value_second()) - alpha_new2) ));
206           
207      // if alpha is 'zero' make the sample a non-support vector
208      if (alpha_new1 < tolerance_){
209        sample_.nsv_first();
210      }
211      else
212        sample_.sv_first();
213     
214      alpha_(sample_.value_first()) = alpha_new1;
215      alpha_(sample_.value_second()) = alpha_new2;
216     
217      // update E vector
218      // Peter, perhaps one should only update SVs, but what happens in choose?
219      for (size_t i=0; i<E.size(); i++) {
220        E(i)+=( kernel_mod(i,sample_.value_first())*
221                target(sample_.value_first()) *
222                (alpha_new1-alpha_old1) );
223        E(i)+=( kernel_mod(i,sample_.value_second())*
224                target(sample_.value_second()) *
225                (alpha_new2-alpha_old2) );
226      }
227           
228      epochs++; 
229      if (epochs>max_epochs_){
230        std::cerr << "WARNING: SVM: maximal number of epochs reached.\n";
231        calculate_bias();
232        calculate_margin();
233        return false;
234      }
235    }
236    calculate_margin();
237    trained_ = calculate_bias();
238    return trained_;
239  }
240
241
242  bool SVM::choose(const theplu::utility::vector& E)
243  {
244    // First check for violation among SVs
245    // E should be the same for all SVs
246    // Choose that pair having largest violation/difference.
247    sample_.update_second(0);
248    sample_.update_first(0);
249    if (sample_.nof_sv()>1){
250
251      double max = E(sample_(0));
252      double min = max;
253      for (size_t i=1; i<sample_.nof_sv(); i++){ 
254        assert(alpha_(sample_(i))>tolerance_);
255        if (E(sample_(i)) > max){
256          max = E(sample_(i));
257          sample_.update_second(i);
258        }
259        else if (E(sample_(i))<min){
260          min = E(sample_(i));
261          sample_.update_first(i);
262        }
263      }
264      assert(alpha_(sample_.value_first())>tolerance_);
265      assert(alpha_(sample_.value_second())>tolerance_);
266
267      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
268        return true;
269      }
270     
271      // If no violation check among non-support vectors
272      sample_.shuffle();
273      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
274        if (target_.binary(sample_(i))){
275          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
276            sample_.update_second(i);
277            return true;
278          }
279        }
280        else{
281          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
282            sample_.update_first(i);
283            return true;
284          }
285        }
286      }
287    }
288
289    // if no support vectors - special case
290    else{
291      // to avoid getting stuck we shuffle
292      sample_.shuffle();
293      for (size_t i=0; i<sample_.size(); i++) {
294        if (target(sample_(i))==1){
295          for (size_t j=0; j<sample_.size(); j++) {
296            if ( target(sample_(j))==-1 && 
297                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
298              sample_.update_first(i);
299              sample_.update_second(j);
300              return true;
301            }
302          }
303        }
304      }
305    }
306   
307    // If there is no violation then we should stop training
308    return false;
309
310  }
311 
312 
313  void SVM::bounds( double& u, double& v) const
314  {
315    if (target(sample_.value_first())!=target(sample_.value_second())) {
316      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
317        v = std::numeric_limits<double>::max();
318        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
319      }
320      else {
321        v = (std::numeric_limits<double>::max() - 
322             alpha_(sample_.value_first()) + 
323             alpha_(sample_.value_second()));
324        u = 0;
325      }
326    }
327    else {       
328      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
329           std::numeric_limits<double>::max()) {
330        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
331              std::numeric_limits<double>::max());
332        v =  std::numeric_limits<double>::max();   
333      }
334      else {
335        u = 0;
336        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
337      }
338    }
339  }
340 
341  bool SVM::calculate_bias(void)
342  {
343
344    // calculating output without bias
345    for (size_t i=0; i<output_.size(); i++) {
346      output_(i)=0;
347      for (size_t j=0; j<output_.size(); j++) 
348        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
349    }
350
351    if (!sample_.nof_sv()){
352      std::cerr << "SVM::train() error: " 
353                << "Cannot calculate bias because there is no support vector" 
354                << std::endl;
355      return false;
356    }
357
358    // For samples with alpha>0, we have: target*output=1-alpha/C
359    bias_=0;
360    for (size_t i=0; i<sample_.nof_sv(); i++) 
361      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
362                output_(sample_(i)) );
363    bias_=bias_/sample_.nof_sv();
364    for (size_t i=0; i<output_.size(); i++) 
365      output_(i) += bias_;
366     
367    return true;
368  }
369
370  void SVM::set_C(const double C)
371  {
372    C_inverse_ = 1/C;
373  }
374
375
376}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.