source: trunk/c++_tools/classifier/SVM.h @ 675

Last change on this file since 675 was 675, checked in by Jari Häkkinen, 17 years ago

References #83. Changing project name to yat. Compilation will fail in this revision.

#ifndef _theplu_classifier_svm_
#define _theplu_classifier_svm_

// $Id$

/*
  Copyright (C) The authors contributing to this file.

  This file is part of the yat library, http://lev.thep.lu.se/trac/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 2 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
  02111-1307, USA.
*/

#include "yat/classifier/KernelLookup.h"
#include "yat/classifier/SupervisedClassifier.h"
#include "yat/classifier/SVindex.h"
#include "yat/classifier/Target.h"
#include "yat/utility/vector.h"

#include <utility>
#include <vector>


namespace theplu {
namespace classifier {

  class DataLookup2D;

  ///
  /// @brief Support Vector Machine
  ///
  /// Class for SVM using Keerthi's second modification of Platt's
  /// Sequential Minimal Optimization. The SVM uses all data given
  /// for training. If validation or testing is wanted, this should
  /// be taken care of outside the class (in the kernel).
  ///
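  /// A minimal usage sketch (assuming a KernelLookup @c kernel and a
  /// Target @c target constructed elsewhere; the value given to
  /// set_C() is illustrative only):
  /// \code
  /// theplu::classifier::SVM svm(kernel, target);
  /// svm.set_C(100);      // finite C gives a soft margin
  /// if (svm.train()) {   // SMO optimisation of alpha and bias
  ///   const theplu::utility::vector& o = svm.output();
  ///   // o holds the decision values o_i for the training samples
  /// }
  /// \endcode
  ///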
  class SVM : public SupervisedClassifier
  {

  public:
    ///
    /// Constructor taking the kernel and the target vector as
    /// input.
    ///
    /// @note If @a target or @a kernel is destroyed, the behaviour
    /// is undefined.
    ///
    SVM(const KernelLookup& kernel, const Target& target);

    ///
    /// Destructor
    ///
    virtual ~SVM();

    ///
    /// If the DataLookup2D is not a KernelLookup, a bad_cast
    /// exception is thrown.
    ///
    SupervisedClassifier*
    make_classifier(const DataLookup2D&, const Target&) const;

    ///
    /// @return \f$ \alpha \f$
    ///
    inline const utility::vector& alpha(void) const { return alpha_; }

    ///
    /// The C-parameter is the balance term (see train()). A very
    /// large C means the training will be focused on getting samples
    /// correctly classified, with a risk of overfitting and poor
    /// generalisation. A too small C will result in a training where
    /// misclassifications are hardly penalized. C is weighted with
    /// respect to the group sizes, so \f$ n_+C_+ = n_-C_- \f$,
    /// meaning a misclassification of the smaller group is penalized
    /// harder. This balance is equivalent to the one occurring in
    /// regression with regularisation, or in ANN training with a
    /// weight-decay term. Default is C set to infinity.
    ///
    /// @return mean of vector \f$ C_i \f$
    ///
    inline double C(void) const { return 1/C_inverse_; }

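    // Worked example of the weighting above (illustrative numbers,
    // not from the library): with n_+ = 10 positive and n_- = 40
    // negative samples and C_- = 1, the relation n_+C_+ = n_-C_-
    // gives C_+ = 4, i.e., a misclassified sample from the smaller
    // group is penalized four times as hard.
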
    ///
    /// Default is max_epochs set to 10,000,000.
    ///
    /// @return maximal number of epochs
    ///
    inline long int max_epochs(void) const {return max_epochs_;}

    ///
    /// The output is calculated as \f$ o_i = \sum_j \alpha_j t_j
    /// K_{ij} + bias \f$, where \f$ t \f$ is the target.
    ///
    /// @return output
    ///
    inline const theplu::utility::vector& output(void) const { return output_; }

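    // A sketch in plain C++ of the output formula above (standalone
    // illustration with raw arrays; the class itself works on the
    // KernelLookup and utility::vector types):
    //
    //   double o_i = bias;
    //   for (size_t j = 0; j < n; ++j)
    //     o_i += alpha[j] * t[j] * K[i][j];  // t[j] is +1 or -1
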
    ///
    /// Generate prediction @a predict from @a input. The prediction
    /// is calculated as the output times the margin, i.e., the
    /// geometric distance from the decision hyperplane: \f$
    /// \frac{\sum_j \alpha_j t_j K_{ij} + bias}{w} \f$. The @a
    /// predict matrix has two rows. The first row is for binary
    /// target true, and the second is for binary target false. The
    /// second row is superfluous, as it is the first row negated; it
    /// exists only to be aligned with multi-class
    /// SupervisedClassifiers. Each column in @a input and @a predict
    /// corresponds to a sample to predict. Each row in @a input
    /// corresponds to a training sample; more exactly, row i in @a
    /// input should correspond to row i in the KernelLookup that was
    /// used for training.
    ///
    void predict(const DataLookup2D& input, utility::matrix& predict) const;

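    // Usage sketch (names are illustrative): test_kernel is assumed
    // to be a KernelLookup with one row per training sample and one
    // column per sample to predict.
    //
    //   utility::matrix result;
    //   svm.predict(test_kernel, result);
    //   // row 0 of result holds the margin-scaled predictions;
    //   // row 1 is row 0 negated.
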
    ///
    /// @return output times margin (i.e., geometric distance from
    /// the decision hyperplane) from data @a input
    ///
    double predict(const DataLookup1D& input) const;

    ///
    /// @return output times margin from data @a input with
    /// corresponding @a weight
    ///
    double predict(const DataLookupWeighted1D& input) const;

    ///
    /// Function sets \f$ \alpha=0 \f$ and makes the SVM untrained.
    ///
    inline void reset(void)
    { trained_=false; alpha_=utility::vector(target_.size(),0); }

    ///
    /// @brief Sets the C-parameter
    ///
    void set_C(const double);

    /**
       Training the SVM following Platt's SMO, with Keerthi's
       modification. Minimizing \f$ \frac{1}{2}\sum
       y_iy_j\alpha_i\alpha_j(K_{ij}+\frac{1}{C_i}\delta_{ij}) \f$,
       which corresponds to minimizing \f$ \sum w_i^2+\sum C_i\xi_i^2
       \f$.
    */
    bool train();
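
    // Illustrative outline of one SMO pass (a sketch, not the
    // verbatim implementation): choose() picks a pair of violating
    // samples (i, j); bounds() gives the interval to which the
    // update of alpha_j is clipped; alpha_i is updated so that the
    // constraint sum_k alpha_k t_k = 0 still holds; finally the
    // output and the bias (calculate_bias()) are recomputed.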
  private:
    ///
    /// Copy constructor. (not implemented)
    ///
    SVM(const SVM&);

    ///
    /// Calculates bounds for alpha2
    ///
    void bounds(double&, double&) const;

    ///
    /// @brief calculates the bias term
    ///
    /// @return true if successful
    ///
    bool calculate_bias(void);

    ///
    /// Calculates the margin, which is the inverse of w
    ///
    void calculate_margin(void);

    ///
    /// Private function choosing which two elements should be
    /// updated. First it checks for the largest violation of output
    /// - target = 0 among the support vectors (alpha!=0). If no
    /// violation is found, it checks sequentially among the other
    /// samples. If no violation is found there either, training is
    /// completed.
    ///
    /// @return true if a pair of samples that violate the conditions
    /// can be found
    ///
    bool choose(const theplu::utility::vector&);

    ///
    /// @return kernel modified with diagonal term (soft margin)
    ///
    inline double kernel_mod(const size_t i, const size_t j) const
    { return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_; }
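
    // The term added on the diagonal is what realises the soft
    // margin: minimizing the dual with kernel_mod() corresponds to
    // the sum w_i^2 + sum C_i xi_i^2 objective described at train().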

    /// @return 1 if i belongs to binary target true, otherwise -1
    inline int target(size_t i) const { return target_.binary(i) ? 1 : -1; }

    utility::vector alpha_;
    double bias_;
    double C_inverse_;
    const KernelLookup* kernel_;
    double margin_;
    unsigned long int max_epochs_;
    utility::vector output_;
    bool owner_;
    SVindex sample_;
    bool trained_;
    double tolerance_;

  };

}} // of namespace classifier and namespace theplu

#endif