source: trunk/yat/classifier/SVM.cc @ 1098

Last change on this file since 1098 was 1098, checked in by Jari Häkkinen, 14 years ago

Fixes #299. Memory leak in matrix was found and removed.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.9 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "SVM.h"
27#include "DataLookup2D.h"
28#include "Target.h"
29#include "yat/random/random.h"
30#include "yat/statistics/Averager.h"
31#include "yat/utility/matrix.h"
32#include "yat/utility/vector.h"
33
34#include <algorithm>
35#include <cassert>
36#include <cctype>
37#include <cmath>
38#include <limits>
39#include <sstream>
40#include <stdexcept>
41#include <utility>
42#include <vector>
43
44namespace theplu {
45namespace yat {
46namespace classifier { 
47
48  SVM::SVM(const KernelLookup& kernel, const Target& target)
49    : alpha_(target.size(),0),
50      bias_(0),
51      C_inverse_(0),
52      kernel_(&kernel),
53      margin_(0),
54      max_epochs_(100000),
55      output_(target.size(),0),
56      owner_(false),
57      sample_(target.size()),
58      target_(target),
59      trained_(false),
60      tolerance_(0.00000001)
61  {
62#ifndef NDEBUG
63    assert(kernel.columns()==kernel.rows());
64    assert(kernel.columns()==alpha_.size());
65    for (size_t i=0; i<alpha_.size(); i++) 
66      for (size_t j=0; j<alpha_.size(); j++)
67        assert(kernel(i,j)==kernel(j,i));
68    for (size_t i=0; i<alpha_.size(); i++) 
69      for (size_t j=0; j<alpha_.size(); j++) 
70        assert((*kernel_)(i,j)==(*kernel_)(j,i));
71    for (size_t i = 0; i<kernel_->rows(); i++)
72      for (size_t j = 0; j<kernel_->columns(); j++)
73        if (std::isnan((*kernel_)(i,j)))
74          std::cerr << "SVM: Found nan in kernel: " << i << " " 
75                    << j << std::endl;
76#endif       
77  }
78
79  SVM::~SVM()
80  {
81    if (owner_)
82      delete kernel_;
83  }
84
85  const utility::vector& SVM::alpha(void) const
86  {
87    return alpha_;
88  }
89
90  double SVM::C(void) const
91  {
92    return 1.0/C_inverse_;
93  }
94
95  void SVM::calculate_margin(void)
96  {
97    margin_ = 0;
98    for(size_t i = 0; i<alpha_.size(); ++i){
99      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
100      for(size_t j = i+1; j<alpha_.size(); ++j)
101        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
102    }
103  }
104
105  const DataLookup2D& SVM::data(void) const
106  {
107    return *kernel_;
108  }
109
110
111  double SVM::kernel_mod(const size_t i, const size_t j) const
112  {
113    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
114  }
115
116  SVM* SVM::make_classifier(const DataLookup2D& data,
117                            const Target& target) const
118  {
119    SVM* sc=0;
120    try {
121      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
122      assert(data.rows()==data.columns());
123      assert(data.columns()==target.size());
124      sc = new SVM(kernel,target);
125      //Copy those variables possible to modify from outside
126      sc->set_C(this->C());
127      sc->max_epochs(max_epochs());
128    }
129    catch (std::bad_cast) {
130      std::string str = 
131        "Error in SVM::make_classifier: DataLookup2D of unexpected class.";
132      throw std::runtime_error(str);
133    }
134 
135    return sc;
136  }
137
138  long int SVM::max_epochs(void) const
139  {
140    return max_epochs_;
141  }
142
143
144  void SVM::max_epochs(long int n)
145  {
146    max_epochs_=n;
147  }
148
149
150  const theplu::yat::utility::vector& SVM::output(void) const
151  {
152    return output_;
153  }
154
155  void SVM::predict(const DataLookup2D& input, utility::matrix& prediction) const
156  {
157    try {
158      const KernelLookup& input_kernel =dynamic_cast<const KernelLookup&>(input);
159      assert(input.rows()==alpha_.size());
160      prediction.resize(2,input.columns(),0);
161      for (size_t i = 0; i<input.columns(); i++){
162        for (size_t j = 0; j<input.rows(); j++){
163          prediction(0,i) += target(j)*alpha_(j)*input_kernel(j,i);
164          assert(target(j));
165        }
166        prediction(0,i) = margin_ * (prediction(0,i) + bias_);
167      }
168     
169      for (size_t i = 0; i<prediction.columns(); i++)
170        prediction(1,i) = -prediction(0,i);
171    }
172    catch (std::bad_cast) {
173      std::string str = 
174        "Error in SVM::predict: DataLookup2D of unexpected class.";
175      throw std::runtime_error(str);
176    }
177   
178  }
179
180  double SVM::predict(const DataLookup1D& x) const
181  {
182    double y=0;
183    for (size_t i=0; i<alpha_.size(); i++)
184      y += alpha_(i)*target_(i)*kernel_->element(x,i);
185
186    return margin_*(y+bias_);
187  }
188
189  double SVM::predict(const DataLookupWeighted1D& x) const
190  {
191    double y=0;
192    for (size_t i=0; i<alpha_.size(); i++)
193      y += alpha_(i)*target_(i)*kernel_->element(x,i);
194
195    return margin_*(y+bias_);
196  }
197
198  void SVM::reset(void)
199  {
200    trained_=false;
201    alpha_ = utility::vector(target_.size(), 0);
202  }
203
204  int SVM::target(size_t i) const
205  {
206    return target_.binary(i) ? 1 : -1;
207  }
208
209  void SVM::train(void) 
210  {
211    // initializing variables for optimization
212    assert(target_.size()==kernel_->rows());
213    assert(target_.size()==alpha_.size());
214
215    sample_.init(alpha_,tolerance_);
216    utility::vector   E(target_.size(),0);
217    for (size_t i=0; i<E.size(); i++) {
218      for (size_t j=0; j<E.size(); j++) 
219        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
220      E(i)-=target(i);
221    }
222    assert(target_.size()==E.size());
223    assert(target_.size()==sample_.size());
224
225    unsigned long int epochs = 0;
226    double alpha_new2;
227    double alpha_new1;
228    double u;
229    double v;
230
231    // Training loop
232    while(choose(E)) {
233      bounds(u,v);       
234      double k = ( kernel_mod(sample_.value_first(), sample_.value_first()) + 
235                   kernel_mod(sample_.value_second(), sample_.value_second()) - 
236                   2*kernel_mod(sample_.value_first(), sample_.value_second()));
237     
238      double alpha_old1=alpha_(sample_.value_first());
239      double alpha_old2=alpha_(sample_.value_second());
240      alpha_new2 = ( alpha_(sample_.value_second()) + 
241                     target(sample_.value_second())*
242                     ( E(sample_.value_first())-E(sample_.value_second()) )/k );
243     
244      if (alpha_new2 > v)
245        alpha_new2 = v;
246      else if (alpha_new2<u)
247        alpha_new2 = u;
248     
249      // Updating the alphas
250      // if alpha is 'zero' make the sample a non-support vector
251      if (alpha_new2 < tolerance_){
252        sample_.nsv_second();
253      }
254      else{
255        sample_.sv_second();
256      }
257     
258     
259      alpha_new1 = (alpha_(sample_.value_first()) + 
260                    (target(sample_.value_first()) * 
261                     target(sample_.value_second()) * 
262                     (alpha_(sample_.value_second()) - alpha_new2) ));
263           
264      // if alpha is 'zero' make the sample a non-support vector
265      if (alpha_new1 < tolerance_){
266        sample_.nsv_first();
267      }
268      else
269        sample_.sv_first();
270     
271      alpha_(sample_.value_first()) = alpha_new1;
272      alpha_(sample_.value_second()) = alpha_new2;
273     
274      // update E vector
275      // Peter, perhaps one should only update SVs, but what happens in choose?
276      for (size_t i=0; i<E.size(); i++) {
277        E(i)+=( kernel_mod(i,sample_.value_first())*
278                target(sample_.value_first()) *
279                (alpha_new1-alpha_old1) );
280        E(i)+=( kernel_mod(i,sample_.value_second())*
281                target(sample_.value_second()) *
282                (alpha_new2-alpha_old2) );
283      }
284           
285      epochs++; 
286      if (epochs>max_epochs_){
287        throw std::runtime_error("SVM: maximal number of epochs reached.");
288      }
289    }
290    calculate_margin();
291    calculate_bias();
292    trained_ = true;
293  }
294
295
296  bool SVM::choose(const theplu::yat::utility::vector& E)
297  {
298    // First check for violation among SVs
299    // E should be the same for all SVs
300    // Choose that pair having largest violation/difference.
301    sample_.update_second(0);
302    sample_.update_first(0);
303    if (sample_.nof_sv()>1){
304
305      double max = E(sample_(0));
306      double min = max;
307      for (size_t i=1; i<sample_.nof_sv(); i++){ 
308        assert(alpha_(sample_(i))>tolerance_);
309        if (E(sample_(i)) > max){
310          max = E(sample_(i));
311          sample_.update_second(i);
312        }
313        else if (E(sample_(i))<min){
314          min = E(sample_(i));
315          sample_.update_first(i);
316        }
317      }
318      assert(alpha_(sample_.value_first())>tolerance_);
319      assert(alpha_(sample_.value_second())>tolerance_);
320
321      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
322        return true;
323      }
324     
325      // If no violation check among non-support vectors
326      sample_.shuffle();
327      for (size_t i=sample_.nof_sv(); i<sample_.size();i++){
328        if (target_.binary(sample_(i))){
329          if(E(sample_(i)) < E(sample_.value_first()) - 2*tolerance_){
330            sample_.update_second(i);
331            return true;
332          }
333        }
334        else{
335          if(E(sample_(i)) > E(sample_.value_second()) + 2*tolerance_){
336            sample_.update_first(i);
337            return true;
338          }
339        }
340      }
341    }
342
343    // if no support vectors - special case
344    else{
345      // to avoid getting stuck we shuffle
346      sample_.shuffle();
347      for (size_t i=0; i<sample_.size(); i++) {
348        if (target(sample_(i))==1){
349          for (size_t j=0; j<sample_.size(); j++) {
350            if ( target(sample_(j))==-1 && 
351                 E(sample_(i)) < E(sample_(j))+2*tolerance_ ){
352              sample_.update_first(i);
353              sample_.update_second(j);
354              return true;
355            }
356          }
357        }
358      }
359    }
360   
361    // If there is no violation then we should stop training
362    return false;
363
364  }
365 
366 
367  void SVM::bounds( double& u, double& v) const
368  {
369    if (target(sample_.value_first())!=target(sample_.value_second())) {
370      if (alpha_(sample_.value_second()) > alpha_(sample_.value_first())) {
371        v = std::numeric_limits<double>::max();
372        u = alpha_(sample_.value_second()) - alpha_(sample_.value_first());
373      }
374      else {
375        v = (std::numeric_limits<double>::max() - 
376             alpha_(sample_.value_first()) + 
377             alpha_(sample_.value_second()));
378        u = 0;
379      }
380    }
381    else {       
382      if (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) > 
383           std::numeric_limits<double>::max()) {
384        u = (alpha_(sample_.value_second()) + alpha_(sample_.value_first()) - 
385              std::numeric_limits<double>::max());
386        v =  std::numeric_limits<double>::max();   
387      }
388      else {
389        u = 0;
390        v = alpha_(sample_.value_first()) + alpha_(sample_.value_second());
391      }
392    }
393  }
394 
395  void SVM::calculate_bias(void)
396  {
397
398    // calculating output without bias
399    for (size_t i=0; i<output_.size(); i++) {
400      output_(i)=0;
401      for (size_t j=0; j<output_.size(); j++) 
402        output_(i)+=alpha_(j)*target(j) * (*kernel_)(i,j);
403    }
404
405    if (!sample_.nof_sv()){
406      std::stringstream ss;
407      ss << "yat::classifier::SVM::train() error: " 
408         << "Cannot calculate bias because there is no support vector"; 
409      throw std::runtime_error(ss.str());
410    }
411
412    // For samples with alpha>0, we have: target*output=1-alpha/C
413    bias_=0;
414    for (size_t i=0; i<sample_.nof_sv(); i++) 
415      bias_+= ( target(sample_(i)) * (1-alpha_(sample_(i))*C_inverse_) - 
416                output_(sample_(i)) );
417    bias_=bias_/sample_.nof_sv();
418    for (size_t i=0; i<output_.size(); i++) 
419      output_(i) += bias_;
420  }
421
422  void SVM::set_C(const double C)
423  {
424    C_inverse_ = 1/C;
425  }
426
427}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.