1 | // $Id: SVM.h 37 2004-02-13 15:46:31Z peter $ |
---|
2 | |
---|
3 | #ifndef CS_CPP_TOOLS_SVM_H |
---|
4 | #define CS_CPP_TOOLS_SVM_H |
---|
5 | |
---|
6 | // C++ tools include |
---|
7 | ///////////////////// |
---|
8 | #include "vector.h" |
---|
9 | #include "matrix.h" |
---|
10 | |
---|
11 | |
---|
12 | // Standard C++ includes |
---|
13 | //////////////////////// |
---|
14 | |
---|
15 | |
---|
16 | namespace thep_cpp_tools |
---|
17 | { |
---|
18 | |
---|
19 | class SVM |
---|
20 | { |
---|
21 | |
---|
22 | public: |
---|
23 | /** |
---|
24 | Constructor taking the kernel matrix and the target vector as input |
---|
25 | */ |
---|
26 | SVM(const thep_gsl_api::matrix&, const thep_gsl_api::vector&); |
---|
27 | |
---|
28 | /** |
---|
29 | Training the SVM using the SMO algorithm |
---|
30 | */ |
---|
31 | void train(); |
---|
32 | |
---|
33 | /** |
---|
34 | Function will return \f$\alpha\f$ |
---|
35 | */ |
---|
36 | inline thep_gsl_api::vector get_alpha() const; |
---|
37 | |
---|
38 | /** |
---|
39 | Function will return the output from SVM |
---|
40 | */ |
---|
41 | inline thep_gsl_api::vector get_output() const; |
---|
42 | |
---|
43 | |
---|
44 | |
---|
45 | private: |
---|
46 | bool trained_; |
---|
47 | thep_gsl_api::matrix kernel_; |
---|
48 | thep_gsl_api::vector target_; |
---|
49 | thep_gsl_api::vector alpha_; |
---|
50 | double bias_; |
---|
51 | |
---|
52 | /** |
---|
53 | Private function that determines when to stop the training. |
---|
54 | The test is done in two steps. First, we check that the |
---|
55 | functional margin is at least 2 - epsilon. Second, we check |
---|
56 | that the gap between the primal and the dual object is less |
---|
57 | than epsilon. |
---|
58 | |
---|
59 | */ |
---|
60 | bool SVM::stop(const thep_gsl_api::vector& target_, |
---|
61 | const thep_gsl_api::matrix& kernel_, |
---|
62 | const thep_gsl_api::vector& alpha_); |
---|
63 | }; |
---|
64 | |
---|
65 | // class SVM |
---|
66 | thep_gsl_api::vector SVM::get_alpha() const |
---|
67 | { |
---|
68 | return alpha_; |
---|
69 | } |
---|
70 | |
---|
71 | thep_gsl_api::vector SVM::get_output() const |
---|
72 | { |
---|
73 | thep_gsl_api::vector bias(target_.size(), false, false); |
---|
74 | bias.set_all(bias_); |
---|
75 | return kernel_ * alpha_.mul_elements(target_) + bias; |
---|
76 | |
---|
77 | } |
---|
78 | }; // namespace thep_c++_tools |
---|
79 | |
---|
80 | #endif |
---|
81 | |
---|