source: trunk/yat/classifier/SVM.cc @ 747

Last change on this file since 747 was 747, checked in by Peter, 15 years ago

replaced includes in header files with forward declarations when possible. Added some includes in cc files.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.6 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "SVM.h"
25#include "DataLookup2D.h"
26#include "Target.h"
27#include "yat/random/random.h"
28#include "yat/statistics/Averager.h"
29#include "yat/utility/matrix.h"
30#include "yat/utility/vector.h"
31
32#include <algorithm>
33#include <cassert>
34#include <cctype>
35#include <cmath>
36#include <limits>
37#include <utility>
38#include <vector>
39
40namespace theplu {
41namespace yat {
42namespace classifier { 
43
44  SVM::SVM(const KernelLookup& kernel, const Target& target)
45    : SupervisedClassifier(target),
46      alpha_(target.size(),0),
47      bias_(0),
48      C_inverse_(0),
49      kernel_(&kernel),
50      margin_(0),
51      max_epochs_(100000),
52      output_(target.size(),0),
53      owner_(false),
54      sample_(target.size()),
55      trained_(false),
56      tolerance_(0.00000001)
57  {
58#ifndef NDEBUG
59    assert(kernel.columns()==kernel.rows());
60    assert(kernel.columns()==alpha_.size());
61    for (size_t i=0; i<alpha_.size(); i++) 
62      for (size_t j=0; j<alpha_.size(); j++)
63        assert(kernel(i,j)==kernel(j,i));
64    for (size_t i=0; i<alpha_.size(); i++) 
65      for (size_t j=0; j<alpha_.size(); j++) 
66        assert((*kernel_)(i,j)==(*kernel_)(j,i));
67    for (size_t i = 0; i<kernel_->rows(); i++)
68      for (size_t j = 0; j<kernel_->columns(); j++)
69        if (std::isnan((*kernel_)(i,j)))
70          std::cerr << "SVM: Found nan in kernel: " << i << " " 
71                    << j << std::endl;
72#endif       
73  }
74
75  SVM::~SVM()
76  {
77    if (owner_)
78      delete kernel_;
79  }
80
81  const utility::vector& SVM::alpha(void) const
82  {
83    return alpha_;
84  }
85
86  double SVM::C(void) const
87  {
88    return 1.0/C_inverse_;
89  }
90
91  void SVM::calculate_margin(void)
92  {
93    margin_ = 0;
94    for(size_t i = 0; i<alpha_.size(); ++i){
95      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
96      for(size_t j = i+1; j<alpha_.size(); ++j)
97        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
98    }
99  }
100
101  const DataLookup2D& SVM::data(void) const
102  {
103    return *kernel_;
104  }
105
106
107  double SVM::kernel_mod(const size_t i, const size_t j) const
108  {
109    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
110  }
111
112  SupervisedClassifier* SVM::make_classifier(const DataLookup2D& data,
113                                             const Target& target) const
114  {
115    SVM* sc=0;
116    try {
117      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
118      assert(data.rows()==data.columns());
119      assert(data.columns()==target.size());
120      sc = new SVM(kernel,target);
121
122    //Copy those variables possible to modify from outside
123    // Peter, in particular C
124    }
125    catch (std::bad_cast) {
126      std::cerr << "Warning: SVM::make_classifier only takes KernelLookup" 
127                << std::endl;
128    }
129    return sc;
130  }
131
132  long int SVM::max_epochs(void) const
133  {
134    return max_epochs_;
135  }
136
137  const theplu::yat::utility::vector& SVM::output(void) const
138  {
139    return output_;
140  }
141
142  void SVM::predict(const DataLookup2D& input, utility::matrix& prediction) const
143  {
144    // Peter, should check success of dynamic_cast
145    const KernelLookup& input_kernel = dynamic_cast<const KernelLookup&>(input);
146
147    assert(input.rows()==alpha_.size());
148    prediction = utility::matrix(2,input.columns(),0);
149    for (size_t i = 0; i<input.columns(); i++){
150      for (size_t j = 0; j<input.rows(); j++){
151        prediction(0,i) += target(j)*alpha_(j)*input_kernel(j,i);
152        assert(target(j));
153      }
154      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
155    }
156
157    for (size_t i = 0; i<prediction.columns(); i++)
158      prediction(1,i) = -prediction(0,i);
159   
160    assert(prediction(0,0));
161  }
162
163  double SVM::predict(const DataLookup1D& x) const
164  {
165    double y=0;
166    for (size_t i=0; i<alpha_.size(); i++)
167      y += alpha_(i)*target_(i)*kernel_->element(x,i);
168
169    return margin_*(y+bias_);
170  }
171
172  double SVM::predict(const DataLookupWeighted1D& x) const
173  {
174    double y=0;
175    for (size_t i=0; i<alpha_.size(); i++)
176      y += alpha_(i)*target_(i)*kernel_->element(x,i);
177
178    return margin_*(y+bias_);
179  }
180
181  void SVM::reset(void)
182  {
183    trained_=false;
184    alpha_=utility::vector(target_.size(), 0);
185  }
186
187  int SVM::target(size_t i) const
188  {
189    return target_.binary(i) ? 1 : -1;
190  }
191
192  bool SVM::train(void) 
193  {
194    // initializing variables for optimization
195    assert(target_.size()==kernel_->rows());
196    assert(target_.size()==alpha_.size());
197
198    sample_.init(alpha_,tolerance_);
199    utility::vector   E(target_.size(),0);
200    for (size_t i=0; i<E.size(); i++) {
201      for (size_t j=0; j<E.size(); j++) 
202        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
203      E(i)-=target(i);
204    }
205    assert(target_.size()==E.size());
206    assert(target_.size()==sample_.size());
207
208    unsigned long int epochs = 0;
209    double alpha_new2;
210    double alpha_new1;
211    double u;
212    double v;
213
214    // Training loop
215    while(choose(E)) {
216      bounds(u,v);       
217      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
218                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
219                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
220     
221      double alpha_old1=alpha_(sample_.value_first());
222      double alpha_old2=alpha_(sample_.value_second());
223      alpha_new2 = ( alpha_(sample_.value_second()) + 
224                     target(sample_.value_second())*
225                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
226     
227      if (alpha_new2 > v)
228        alpha_new2 = v;
229      else if (alpha_new2<u)
230        alpha_new2 = u;
231     
232      // Updating the alphas
233      // if alpha is 'zero' make the sample a non-support vector
234      if (alpha_new2 < tolerance_){
235        sample_.nsv_second();
236      }
237      else{
238        sample_.sv_second();
239      }
240     
241     
242      alpha_new1 = (alpha_(sample_.value_first()) + 
243                    (target(sample_.value_first()) * 
244                     target(sample_.value_second()) * 
245                     (alpha_(sample_.value_second()) - alpha_new2) ));
246           
247      // if alpha is 'zero' make the sample a non-support vector
248      if (alpha_new1 < tolerance_){
249        sample_.nsv_first();
250      }
251      else
252        sample_.sv_first();
253     
254      alpha_(sample_.value_first()) = alpha_new1;
255      alpha_(sample_.value_second()) = alpha_new2;
256     
257      // update E vector
258      // Peter, perhaps one should only update SVs, but what happens in choose?
259      for (size_t i=0; i<E.size(); i++) {
260        E(i)+=( kernel_mod(i,sample_.value_first())*
261                target(sample_.value_first()) *
262                (alpha_new1-alpha_old1) );
263        E(i)+=( kernel_mod(i,sample_.value_second())*
264                target(sample_.value_second()) *
265                (alpha_new2-alpha_old2) );
266      }
267           
268      epochs++; 
269      if (epochs>max_epochs_){
270        std::cerr << "WARNING: SVM: maximal number of epochs reached.\n";
271        calculate_bias();
272        calculate_margin();
273        return false;
274      }
275    }
276    calculate_margin();
277    trained_ = calculate_bias();
278    return trained_;
279  }
280
281
282  bool SVM::choose(const theplu::yat::utility::vector& E)
283  {
284    // First check for violation among SVs
285    // E should be the same for all SVs
286    // Choose that pair having largest violation/difference.
287    sample_.update_second(0);
288    sample_.update_first(0);
289    if (sample_.nof_sv()>1){
290
291      double max = E(sample_(0));
292      double min = max;
293      for (size_t i=1; i<sample_.nof_sv(); i++){ 
294        assert(alpha_(sample_(i))>tolerance_);
295        if (E(sample_(i)) > max){
296          max = E(sample_(i));
297          sample_.update_second(i);
298        }
299        else if (E(sample_(i))<min){
300          min = E(sample_(i));
301          sample_.update_first(i);
302        }
303      }
304      assert(alpha_(sample_.value_first())>tolerance_);
305      assert(alpha_(sample_.value_second())>tolerance_);
306
307      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
308        return true;
309      }
310     
311      // If no violation check among non-support vectors
312      sample_.shuffle();
313      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
314        if (target_.binary(sample_(i))){
315          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
316            sample_.update_second(i);
317            return true;
318          }
319        }
320        else{
321          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
322            sample_.update_first(i);
323            return true;
324          }
325        }
326      }
327    }
328
329    // if no support vectors - special case
330    else{
331      // to avoid getting stuck we shuffle
332      sample_.shuffle();
333      for (size_t i=0; i<sample_.size(); i++) {
334        if (target(sample_(i))==1){
335          for (size_t j=0; j<sample_.size(); j++) {
336            if ( target(sample_(j))==-1 && 
337                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
338              sample_.update_first(i);
339              sample_.update_second(j);
340              return true;
341            }
342          }
343        }
344      }
345    }
346   
347    // If there is no violation then we should stop training
348    return false;
349
350  }
351 
352 
353  void SVM::bounds( double& u, double& v) const
354  {
355    if (target(sample_.value_first())!=target(sample_.value_second())) {
356      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
357        v = std::numeric_limits<double>::max();
358        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
359      }
360      else {
361        v = (std::numeric_limits<double>::max() - 
362             alpha_(sample_.value_first()) + 
363             alpha_(sample_.value_second()));
364        u = 0;
365      }
366    }
367    else {       
368      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
369           std::numeric_limits<double>::max()) {
370        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
371              std::numeric_limits<double>::max());
372        v =  std::numeric_limits<double>::max();   
373      }
374      else {
375        u = 0;
376        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
377      }
378    }
379  }
380 
381  bool SVM::calculate_bias(void)
382  {
383
384    // calculating output without bias
385    for (size_t i=0; i<output_.size(); i++) {
386      output_(i)=0;
387      for (size_t j=0; j<output_.size(); j++) 
388        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
389    }
390
391    if (!sample_.nof_sv()){
392      std::cerr << "SVM::train() error: " 
393                << "Cannot calculate bias because there is no support vector" 
394                << std::endl;
395      return false;
396    }
397
398    // For samples with alpha>0, we have: target*output=1-alpha/C
399    bias_=0;
400    for (size_t i=0; i<sample_.nof_sv(); i++) 
401      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
402                output_(sample_(i)) );
403    bias_=bias_/sample_.nof_sv();
404    for (size_t i=0; i<output_.size(); i++) 
405      output_(i) += bias_;
406     
407    return true;
408  }
409
410  void SVM::set_C(const double C)
411  {
412    C_inverse_ = 1/C;
413  }
414
415
416}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.