1 | // $Id: SVM.h 37 2004-02-13 15:46:31Z peter $ |
---|

2 | |
---|

3 | #ifndef CS_CPP_TOOLS_SVM_H |
---|

4 | #define CS_CPP_TOOLS_SVM_H |
---|

5 | |
---|

6 | // C++ tools include |
---|

7 | ///////////////////// |
---|

8 | #include "vector.h" |
---|

9 | #include "matrix.h" |
---|

10 | |
---|

11 | |
---|

12 | // Standard C++ includes |
---|

13 | //////////////////////// |
---|

14 | |
---|

15 | |
---|

16 | namespace thep_cpp_tools |
---|

17 | { |
---|

18 | |
---|

19 | class SVM |
---|

20 | { |
---|

21 | |
---|

22 | public: |
---|

23 | /** |
---|

24 | Constructor taking the kernel matrix and the target vector as input |
---|

25 | */ |
---|

26 | SVM(const thep_gsl_api::matrix&, const thep_gsl_api::vector&); |
---|

27 | |
---|

28 | /** |
---|

29 | Training the SVM using the SMO algorithm |
---|

30 | */ |
---|

31 | void train(); |
---|

32 | |
---|

33 | /** |
---|

34 | Function will return \f$\alpha\f$ |
---|

35 | */ |
---|

36 | inline thep_gsl_api::vector get_alpha() const; |
---|

37 | |
---|

38 | /** |
---|

39 | Function will return the output from SVM |
---|

40 | */ |
---|

41 | inline thep_gsl_api::vector get_output() const; |
---|

42 | |
---|

43 | |
---|

44 | |
---|

45 | private: |
---|

46 | bool trained_; |
---|

47 | thep_gsl_api::matrix kernel_; |
---|

48 | thep_gsl_api::vector target_; |
---|

49 | thep_gsl_api::vector alpha_; |
---|

50 | double bias_; |
---|

51 | |
---|

52 | /** |
---|

53 | Private function that determines when to stop the training. |
---|

54 | The test is done in two steps. First, we check that the |
---|

55 | functional margin is at least 2 - epsilon. Second, we check |
---|

56 | that the gap between the primal and the dual object is less |
---|

57 | than epsilon. |
---|

58 | |
---|

59 | */ |
---|

60 | bool SVM::stop(const thep_gsl_api::vector& target_, |
---|

61 | const thep_gsl_api::matrix& kernel_, |
---|

62 | const thep_gsl_api::vector& alpha_); |
---|

63 | }; |
---|

64 | |
---|

65 | // Inline member function definitions of class SVM |
---|

66 | thep_gsl_api::vector SVM::get_alpha() const |
---|

67 | { |
---|

68 | return alpha_; |
---|

69 | } |
---|

70 | |
---|

71 | thep_gsl_api::vector SVM::get_output() const |
---|

72 | { |
---|

73 | thep_gsl_api::vector bias(target_.size(), false, false); |
---|

74 | bias.set_all(bias_); |
---|

75 | return kernel_ * alpha_.mul_elements(target_) + bias; |
---|

76 | |
---|

77 | } |
---|

78 | }; // namespace thep_c++_tools |
---|

79 | |
---|

80 | #endif |
---|

81 | |
---|