source: trunk/yat/classifier/SVM.cc @ 792

Last change on this file since 792 was 792, checked in by Jari Häkkinen, 15 years ago

Addresses #193. matrix now works as outlined in the ticket
discussion. Added support for const views. Added a clone function that
facilitates resizing of matrices. clone is needed since assignement
operator functionality is changed.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.6 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "SVM.h"
25#include "DataLookup2D.h"
26#include "Target.h"
27#include "yat/random/random.h"
28#include "yat/statistics/Averager.h"
29#include "yat/utility/matrix.h"
30#include "yat/utility/vector.h"
31
32#include <algorithm>
33#include <cassert>
34#include <cctype>
35#include <cmath>
36#include <limits>
37#include <utility>
38#include <vector>
39
40namespace theplu {
41namespace yat {
42namespace classifier { 
43
44  SVM::SVM(const KernelLookup& kernel, const Target& target)
45    : SupervisedClassifier(target),
46      alpha_(target.size(),0),
47      bias_(0),
48      C_inverse_(0),
49      kernel_(&kernel),
50      margin_(0),
51      max_epochs_(100000),
52      output_(target.size(),0),
53      owner_(false),
54      sample_(target.size()),
55      trained_(false),
56      tolerance_(0.00000001)
57  {
58#ifndef NDEBUG
59    assert(kernel.columns()==kernel.rows());
60    assert(kernel.columns()==alpha_.size());
61    for (size_t i=0; i<alpha_.size(); i++) 
62      for (size_t j=0; j<alpha_.size(); j++)
63        assert(kernel(i,j)==kernel(j,i));
64    for (size_t i=0; i<alpha_.size(); i++) 
65      for (size_t j=0; j<alpha_.size(); j++) 
66        assert((*kernel_)(i,j)==(*kernel_)(j,i));
67    for (size_t i = 0; i<kernel_->rows(); i++)
68      for (size_t j = 0; j<kernel_->columns(); j++)
69        if (std::isnan((*kernel_)(i,j)))
70          std::cerr << "SVM: Found nan in kernel: " << i << " " 
71                    << j << std::endl;
72#endif       
73  }
74
75  SVM::~SVM()
76  {
77    if (owner_)
78      delete kernel_;
79  }
80
81  const utility::vector& SVM::alpha(void) const
82  {
83    return alpha_;
84  }
85
86  double SVM::C(void) const
87  {
88    return 1.0/C_inverse_;
89  }
90
91  void SVM::calculate_margin(void)
92  {
93    margin_ = 0;
94    for(size_t i = 0; i<alpha_.size(); ++i){
95      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
96      for(size_t j = i+1; j<alpha_.size(); ++j)
97        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
98    }
99  }
100
101  const DataLookup2D& SVM::data(void) const
102  {
103    return *kernel_;
104  }
105
106
107  double SVM::kernel_mod(const size_t i, const size_t j) const
108  {
109    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
110  }
111
112  SupervisedClassifier* SVM::make_classifier(const DataLookup2D& data,
113                                             const Target& target) const
114  {
115    SVM* sc=0;
116    try {
117      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
118      assert(data.rows()==data.columns());
119      assert(data.columns()==target.size());
120      sc = new SVM(kernel,target);
121
122    //Copy those variables possible to modify from outside
123    // Peter, in particular C
124    }
125    catch (std::bad_cast) {
126      std::cerr << "Warning: SVM::make_classifier only takes KernelLookup" 
127                << std::endl;
128    }
129    return sc;
130  }
131
132  long int SVM::max_epochs(void) const
133  {
134    return max_epochs_;
135  }
136
137  const theplu::yat::utility::vector& SVM::output(void) const
138  {
139    return output_;
140  }
141
142  void SVM::predict(const DataLookup2D& input, utility::matrix& prediction) const
143  {
144    // Peter, should check success of dynamic_cast
145    const KernelLookup& input_kernel = dynamic_cast<const KernelLookup&>(input);
146
147    assert(input.rows()==alpha_.size());
148    prediction.clone(utility::matrix(2,input.columns(),0));
149    for (size_t i = 0; i<input.columns(); i++){
150      for (size_t j = 0; j<input.rows(); j++){
151        prediction(0,i) += target(j)*alpha_(j)*input_kernel(j,i);
152        assert(target(j));
153      }
154      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
155    }
156
157    for (size_t i = 0; i<prediction.columns(); i++)
158      prediction(1,i) = -prediction(0,i);
159   
160    assert(prediction(0,0));
161  }
162
163  double SVM::predict(const DataLookup1D& x) const
164  {
165    double y=0;
166    for (size_t i=0; i<alpha_.size(); i++)
167      y += alpha_(i)*target_(i)*kernel_->element(x,i);
168
169    return margin_*(y+bias_);
170  }
171
172  double SVM::predict(const DataLookupWeighted1D& x) const
173  {
174    double y=0;
175    for (size_t i=0; i<alpha_.size(); i++)
176      y += alpha_(i)*target_(i)*kernel_->element(x,i);
177
178    return margin_*(y+bias_);
179  }
180
181  void SVM::reset(void)
182  {
183    trained_=false;
184    alpha_.clone(utility::vector(target_.size(), 0));
185  }
186
187  int SVM::target(size_t i) const
188  {
189    return target_.binary(i) ? 1 : -1;
190  }
191
192  bool SVM::train(void) 
193  {
194    // initializing variables for optimization
195    assert(target_.size()==kernel_->rows());
196    assert(target_.size()==alpha_.size());
197
198    sample_.init(alpha_,tolerance_);
199    utility::vector   E(target_.size(),0);
200    for (size_t i=0; i<E.size(); i++) {
201      for (size_t j=0; j<E.size(); j++) 
202        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
203      E(i)-=target(i);
204    }
205    assert(target_.size()==E.size());
206    assert(target_.size()==sample_.size());
207
208    unsigned long int epochs = 0;
209    double alpha_new2;
210    double alpha_new1;
211    double u;
212    double v;
213
214    // Training loop
215    while(choose(E)) {
216      bounds(u,v);       
217      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
218                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
219                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
220     
221      double alpha_old1=alpha_(sample_.value_first());
222      double alpha_old2=alpha_(sample_.value_second());
223      alpha_new2 = ( alpha_(sample_.value_second()) + 
224                     target(sample_.value_second())*
225                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
226     
227      if (alpha_new2 > v)
228        alpha_new2 = v;
229      else if (alpha_new2<u)
230        alpha_new2 = u;
231     
232      // Updating the alphas
233      // if alpha is 'zero' make the sample a non-support vector
234      if (alpha_new2 < tolerance_){
235        sample_.nsv_second();
236      }
237      else{
238        sample_.sv_second();
239      }
240     
241     
242      alpha_new1 = (alpha_(sample_.value_first()) + 
243                    (target(sample_.value_first()) * 
244                     target(sample_.value_second()) * 
245                     (alpha_(sample_.value_second()) - alpha_new2) ));
246           
247      // if alpha is 'zero' make the sample a non-support vector
248      if (alpha_new1 < tolerance_){
249        sample_.nsv_first();
250      }
251      else
252        sample_.sv_first();
253     
254      alpha_(sample_.value_first()) = alpha_new1;
255      alpha_(sample_.value_second()) = alpha_new2;
256     
257      // update E vector
258      // Peter, perhaps one should only update SVs, but what happens in choose?
259      for (size_t i=0; i<E.size(); i++) {
260        E(i)+=( kernel_mod(i,sample_.value_first())*
261                target(sample_.value_first()) *
262                (alpha_new1-alpha_old1) );
263        E(i)+=( kernel_mod(i,sample_.value_second())*
264                target(sample_.value_second()) *
265                (alpha_new2-alpha_old2) );
266      }
267           
268      epochs++; 
269      if (epochs>max_epochs_){
270        std::cerr << "WARNING: SVM: maximal number of epochs reached.\n";
271        calculate_bias();
272        calculate_margin();
273        return false;
274      }
275    }
276    calculate_margin();
277    trained_ = calculate_bias();
278    return trained_;
279  }
280
281
282  bool SVM::choose(const theplu::yat::utility::vector& E)
283  {
284    // First check for violation among SVs
285    // E should be the same for all SVs
286    // Choose that pair having largest violation/difference.
287    sample_.update_second(0);
288    sample_.update_first(0);
289    if (sample_.nof_sv()>1){
290
291      double max = E(sample_(0));
292      double min = max;
293      for (size_t i=1; i<sample_.nof_sv(); i++){ 
294        assert(alpha_(sample_(i))>tolerance_);
295        if (E(sample_(i)) > max){
296          max = E(sample_(i));
297          sample_.update_second(i);
298        }
299        else if (E(sample_(i))<min){
300          min = E(sample_(i));
301          sample_.update_first(i);
302        }
303      }
304      assert(alpha_(sample_.value_first())>tolerance_);
305      assert(alpha_(sample_.value_second())>tolerance_);
306
307      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
308        return true;
309      }
310     
311      // If no violation check among non-support vectors
312      sample_.shuffle();
313      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
314        if (target_.binary(sample_(i))){
315          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
316            sample_.update_second(i);
317            return true;
318          }
319        }
320        else{
321          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
322            sample_.update_first(i);
323            return true;
324          }
325        }
326      }
327    }
328
329    // if no support vectors - special case
330    else{
331      // to avoid getting stuck we shuffle
332      sample_.shuffle();
333      for (size_t i=0; i<sample_.size(); i++) {
334        if (target(sample_(i))==1){
335          for (size_t j=0; j<sample_.size(); j++) {
336            if ( target(sample_(j))==-1 && 
337                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
338              sample_.update_first(i);
339              sample_.update_second(j);
340              return true;
341            }
342          }
343        }
344      }
345    }
346   
347    // If there is no violation then we should stop training
348    return false;
349
350  }
351 
352 
353  void SVM::bounds( double& u, double& v) const
354  {
355    if (target(sample_.value_first())!=target(sample_.value_second())) {
356      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
357        v = std::numeric_limits<double>::max();
358        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
359      }
360      else {
361        v = (std::numeric_limits<double>::max() - 
362             alpha_(sample_.value_first()) + 
363             alpha_(sample_.value_second()));
364        u = 0;
365      }
366    }
367    else {       
368      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
369           std::numeric_limits<double>::max()) {
370        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
371              std::numeric_limits<double>::max());
372        v =  std::numeric_limits<double>::max();   
373      }
374      else {
375        u = 0;
376        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
377      }
378    }
379  }
380 
381  bool SVM::calculate_bias(void)
382  {
383
384    // calculating output without bias
385    for (size_t i=0; i<output_.size(); i++) {
386      output_(i)=0;
387      for (size_t j=0; j<output_.size(); j++) 
388        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
389    }
390
391    if (!sample_.nof_sv()){
392      std::cerr << "SVM::train() error: " 
393                << "Cannot calculate bias because there is no support vector" 
394                << std::endl;
395      return false;
396    }
397
398    // For samples with alpha>0, we have: target*output=1-alpha/C
399    bias_=0;
400    for (size_t i=0; i<sample_.nof_sv(); i++) 
401      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
402                output_(sample_(i)) );
403    bias_=bias_/sample_.nof_sv();
404    for (size_t i=0; i<output_.size(); i++) 
405      output_(i) += bias_;
406     
407    return true;
408  }
409
410  void SVM::set_C(const double C)
411  {
412    C_inverse_ = 1/C;
413  }
414
415}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.