source: trunk/yat/classifier/SVM.cc @ 1042

Last change on this file since 1042 was 1042, checked in by Peter, 14 years ago

fixes #268 - remove return value in SupervisedClassifier::train()

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.9 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "SVM.h"
27#include "DataLookup2D.h"
28#include "Target.h"
29#include "yat/random/random.h"
30#include "yat/statistics/Averager.h"
31#include "yat/utility/matrix.h"
32#include "yat/utility/vector.h"
33
34#include <algorithm>
35#include <cassert>
36#include <cctype>
37#include <cmath>
38#include <limits>
39#include <stdexcept>
40#include <utility>
41#include <vector>
42
43namespace theplu {
44namespace yat {
45namespace classifier { 
46
47  SVM::SVM(const KernelLookup& kernel, const Target& target)
48    : SupervisedClassifier(target),
49      alpha_(target.size(),0),
50      bias_(0),
51      C_inverse_(0),
52      kernel_(&kernel),
53      margin_(0),
54      max_epochs_(100000),
55      output_(target.size(),0),
56      owner_(false),
57      sample_(target.size()),
58      trained_(false),
59      tolerance_(0.00000001)
60  {
61#ifndef NDEBUG
62    assert(kernel.columns()==kernel.rows());
63    assert(kernel.columns()==alpha_.size());
64    for (size_t i=0; i<alpha_.size(); i++) 
65      for (size_t j=0; j<alpha_.size(); j++)
66        assert(kernel(i,j)==kernel(j,i));
67    for (size_t i=0; i<alpha_.size(); i++) 
68      for (size_t j=0; j<alpha_.size(); j++) 
69        assert((*kernel_)(i,j)==(*kernel_)(j,i));
70    for (size_t i = 0; i<kernel_->rows(); i++)
71      for (size_t j = 0; j<kernel_->columns(); j++)
72        if (std::isnan((*kernel_)(i,j)))
73          std::cerr << "SVM: Found nan in kernel: " << i << " " 
74                    << j << std::endl;
75#endif       
76  }
77
78  SVM::~SVM()
79  {
80    if (owner_)
81      delete kernel_;
82  }
83
84  const utility::vector& SVM::alpha(void) const
85  {
86    return alpha_;
87  }
88
89  double SVM::C(void) const
90  {
91    return 1.0/C_inverse_;
92  }
93
94  void SVM::calculate_margin(void)
95  {
96    margin_ = 0;
97    for(size_t i = 0; i<alpha_.size(); ++i){
98      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
99      for(size_t j = i+1; j<alpha_.size(); ++j)
100        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
101    }
102  }
103
104  const DataLookup2D& SVM::data(void) const
105  {
106    return *kernel_;
107  }
108
109
110  double SVM::kernel_mod(const size_t i, const size_t j) const
111  {
112    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
113  }
114
115  SupervisedClassifier* SVM::make_classifier(const DataLookup2D& data,
116                                             const Target& target) const
117  {
118    SVM* sc=0;
119    try {
120      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
121      assert(data.rows()==data.columns());
122      assert(data.columns()==target.size());
123      sc = new SVM(kernel,target);
124      //Copy those variables possible to modify from outside
125      sc->set_C(this->C());
126      sc->max_epochs(max_epochs());
127    }
128    catch (std::bad_cast) {
129      std::string str = 
130        "Error in SVM::make_classifier: DataLookup2D of unexpected class.";
131      throw std::runtime_error(str);
132    }
133 
134    return sc;
135  }
136
137  long int SVM::max_epochs(void) const
138  {
139    return max_epochs_;
140  }
141
142
143  void SVM::max_epochs(long int n)
144  {
145    max_epochs_=n;
146  }
147
148
149  const theplu::yat::utility::vector& SVM::output(void) const
150  {
151    return output_;
152  }
153
154  void SVM::predict(const DataLookup2D& input, utility::matrix& prediction) const
155  {
156    try {
157      const KernelLookup& input_kernel =dynamic_cast<const KernelLookup&>(input);
158      assert(input.rows()==alpha_.size());
159      prediction.clone(utility::matrix(2,input.columns(),0));
160      for (size_t i = 0; i<input.columns(); i++){
161        for (size_t j = 0; j<input.rows(); j++){
162          prediction(0,i) += target(j)*alpha_(j)*input_kernel(j,i);
163          assert(target(j));
164        }
165        prediction(0,i) = margin_ * (prediction(0,i) + bias_);
166      }
167     
168      for (size_t i = 0; i<prediction.columns(); i++)
169        prediction(1,i) = -prediction(0,i);
170    }
171    catch (std::bad_cast) {
172      std::string str = 
173        "Error in SVM::predict: DataLookup2D of unexpected class.";
174      throw std::runtime_error(str);
175    }
176   
177  }
178
179  double SVM::predict(const DataLookup1D& x) const
180  {
181    double y=0;
182    for (size_t i=0; i<alpha_.size(); i++)
183      y += alpha_(i)*target_(i)*kernel_->element(x,i);
184
185    return margin_*(y+bias_);
186  }
187
188  double SVM::predict(const DataLookupWeighted1D& x) const
189  {
190    double y=0;
191    for (size_t i=0; i<alpha_.size(); i++)
192      y += alpha_(i)*target_(i)*kernel_->element(x,i);
193
194    return margin_*(y+bias_);
195  }
196
197  void SVM::reset(void)
198  {
199    trained_=false;
200    alpha_ = utility::vector(target_.size(), 0);
201  }
202
203  int SVM::target(size_t i) const
204  {
205    return target_.binary(i) ? 1 : -1;
206  }
207
208  void SVM::train(void) 
209  {
210    // initializing variables for optimization
211    assert(target_.size()==kernel_->rows());
212    assert(target_.size()==alpha_.size());
213
214    sample_.init(alpha_,tolerance_);
215    utility::vector   E(target_.size(),0);
216    for (size_t i=0; i<E.size(); i++) {
217      for (size_t j=0; j<E.size(); j++) 
218        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
219      E(i)-=target(i);
220    }
221    assert(target_.size()==E.size());
222    assert(target_.size()==sample_.size());
223
224    unsigned long int epochs = 0;
225    double alpha_new2;
226    double alpha_new1;
227    double u;
228    double v;
229
230    // Training loop
231    while(choose(E)) {
232      bounds(u,v);       
233      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
234                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
235                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
236     
237      double alpha_old1=alpha_(sample_.value_first());
238      double alpha_old2=alpha_(sample_.value_second());
239      alpha_new2 = ( alpha_(sample_.value_second()) + 
240                     target(sample_.value_second())*
241                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
242     
243      if (alpha_new2 > v)
244        alpha_new2 = v;
245      else if (alpha_new2<u)
246        alpha_new2 = u;
247     
248      // Updating the alphas
249      // if alpha is 'zero' make the sample a non-support vector
250      if (alpha_new2 < tolerance_){
251        sample_.nsv_second();
252      }
253      else{
254        sample_.sv_second();
255      }
256     
257     
258      alpha_new1 = (alpha_(sample_.value_first()) + 
259                    (target(sample_.value_first()) * 
260                     target(sample_.value_second()) * 
261                     (alpha_(sample_.value_second()) - alpha_new2) ));
262           
263      // if alpha is 'zero' make the sample a non-support vector
264      if (alpha_new1 < tolerance_){
265        sample_.nsv_first();
266      }
267      else
268        sample_.sv_first();
269     
270      alpha_(sample_.value_first()) = alpha_new1;
271      alpha_(sample_.value_second()) = alpha_new2;
272     
273      // update E vector
274      // Peter, perhaps one should only update SVs, but what happens in choose?
275      for (size_t i=0; i<E.size(); i++) {
276        E(i)+=( kernel_mod(i,sample_.value_first())*
277                target(sample_.value_first()) *
278                (alpha_new1-alpha_old1) );
279        E(i)+=( kernel_mod(i,sample_.value_second())*
280                target(sample_.value_second()) *
281                (alpha_new2-alpha_old2) );
282      }
283           
284      epochs++; 
285      if (epochs>max_epochs_){
286        throw std::runtime_error("SVM: maximal number of epochs reached.");
287      }
288    }
289    calculate_margin();
290    trained_ = calculate_bias();
291  }
292
293
294  bool SVM::choose(const theplu::yat::utility::vector& E)
295  {
296    // First check for violation among SVs
297    // E should be the same for all SVs
298    // Choose that pair having largest violation/difference.
299    sample_.update_second(0);
300    sample_.update_first(0);
301    if (sample_.nof_sv()>1){
302
303      double max = E(sample_(0));
304      double min = max;
305      for (size_t i=1; i<sample_.nof_sv(); i++){ 
306        assert(alpha_(sample_(i))>tolerance_);
307        if (E(sample_(i)) > max){
308          max = E(sample_(i));
309          sample_.update_second(i);
310        }
311        else if (E(sample_(i))<min){
312          min = E(sample_(i));
313          sample_.update_first(i);
314        }
315      }
316      assert(alpha_(sample_.value_first())>tolerance_);
317      assert(alpha_(sample_.value_second())>tolerance_);
318
319      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
320        return true;
321      }
322     
323      // If no violation check among non-support vectors
324      sample_.shuffle();
325      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
326        if (target_.binary(sample_(i))){
327          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
328            sample_.update_second(i);
329            return true;
330          }
331        }
332        else{
333          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
334            sample_.update_first(i);
335            return true;
336          }
337        }
338      }
339    }
340
341    // if no support vectors - special case
342    else{
343      // to avoid getting stuck we shuffle
344      sample_.shuffle();
345      for (size_t i=0; i<sample_.size(); i++) {
346        if (target(sample_(i))==1){
347          for (size_t j=0; j<sample_.size(); j++) {
348            if ( target(sample_(j))==-1 && 
349                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
350              sample_.update_first(i);
351              sample_.update_second(j);
352              return true;
353            }
354          }
355        }
356      }
357    }
358   
359    // If there is no violation then we should stop training
360    return false;
361
362  }
363 
364 
365  void SVM::bounds( double& u, double& v) const
366  {
367    if (target(sample_.value_first())!=target(sample_.value_second())) {
368      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
369        v = std::numeric_limits<double>::max();
370        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
371      }
372      else {
373        v = (std::numeric_limits<double>::max() - 
374             alpha_(sample_.value_first()) + 
375             alpha_(sample_.value_second()));
376        u = 0;
377      }
378    }
379    else {       
380      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
381           std::numeric_limits<double>::max()) {
382        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
383              std::numeric_limits<double>::max());
384        v =  std::numeric_limits<double>::max();   
385      }
386      else {
387        u = 0;
388        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
389      }
390    }
391  }
392 
393  bool SVM::calculate_bias(void)
394  {
395
396    // calculating output without bias
397    for (size_t i=0; i<output_.size(); i++) {
398      output_(i)=0;
399      for (size_t j=0; j<output_.size(); j++) 
400        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
401    }
402
403    if (!sample_.nof_sv()){
404      std::cerr << "SVM::train() error: " 
405                << "Cannot calculate bias because there is no support vector" 
406                << std::endl;
407      return false;
408    }
409
410    // For samples with alpha>0, we have: target*output=1-alpha/C
411    bias_=0;
412    for (size_t i=0; i<sample_.nof_sv(); i++) 
413      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
414                output_(sample_(i)) );
415    bias_=bias_/sample_.nof_sv();
416    for (size_t i=0; i<output_.size(); i++) 
417      output_(i) += bias_;
418     
419    return true;
420  }
421
422  void SVM::set_C(const double C)
423  {
424    C_inverse_ = 1/C;
425  }
426
427}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.