Changeset 1100 for trunk/yat/classifier
- Timestamp:
- Feb 18, 2008, 5:37:50 AM (15 years ago)
- Location:
- trunk/yat/classifier
- Files:
-
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/yat/classifier/SVM.cc
r1098 r1100 39 39 #include <sstream> 40 40 #include <stdexcept> 41 #include <string> 41 42 #include <utility> 42 43 #include <vector> … … 46 47 namespace classifier { 47 48 48 SVM::SVM(const KernelLookup& kernel, const Target& target) 49 : alpha_(target.size(),0), 50 bias_(0), 49 SVM::SVM(void) 50 : bias_(0), 51 51 C_inverse_(0), 52 kernel_( &kernel),52 kernel_(NULL), 53 53 margin_(0), 54 54 max_epochs_(100000), 55 output_(target.size(),0), 56 owner_(false), 57 sample_(target.size()), 58 target_(target), 59 trained_(false), 60 tolerance_(0.00000001) 61 { 62 #ifndef NDEBUG 63 assert(kernel.columns()==kernel.rows()); 64 assert(kernel.columns()==alpha_.size()); 65 for (size_t i=0; i<alpha_.size(); i++) 66 for (size_t j=0; j<alpha_.size(); j++) 67 assert(kernel(i,j)==kernel(j,i)); 68 for (size_t i=0; i<alpha_.size(); i++) 69 for (size_t j=0; j<alpha_.size(); j++) 70 assert((*kernel_)(i,j)==(*kernel_)(j,i)); 71 for (size_t i = 0; i<kernel_->rows(); i++) 72 for (size_t j = 0; j<kernel_->columns(); j++) 73 if (std::isnan((*kernel_)(i,j))) 74 std::cerr << "SVM: Found nan in kernel: " << i << " " 75 << j << std::endl; 76 #endif 77 } 55 tolerance_(0.00000001), 56 trained_(false) 57 { 58 } 59 78 60 79 61 SVM::~SVM() 80 62 { 81 if ( owner_)63 if (kernel_) 82 64 delete kernel_; 83 65 } 84 66 67 85 68 const utility::vector& SVM::alpha(void) const 86 69 { … … 88 71 } 89 72 73 90 74 double SVM::C(void) const 91 75 { 92 76 return 1.0/C_inverse_; 93 77 } 78 94 79 95 80 void SVM::calculate_margin(void) … … 103 88 } 104 89 90 105 91 const DataLookup2D& SVM::data(void) const 106 92 { … … 111 97 double SVM::kernel_mod(const size_t i, const size_t j) const 112 98 { 99 assert(kernel_); 100 assert(i<kernel_->rows()); 101 assert(i<kernel_->columns()); 113 102 return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_; 114 103 } 115 104 116 SVM* SVM::make_classifier(const DataLookup2D& data, 117 const Target& target) const 118 { 119 SVM* sc=0; 120 try { 121 const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data); 122 assert(data.rows()==data.columns()); 123 assert(data.columns()==target.size()); 124 sc = new SVM(kernel,target); 125 //Copy those variables possible to modify from outside 126 sc->set_C(this->C()); 127 sc->max_epochs(max_epochs()); 128 } 129 catch (std::bad_cast) { 130 std::string str = 131 "Error in SVM::make_classifier: DataLookup2D of unexpected class."; 132 throw std::runtime_error(str); 133 } 134 135 return sc; 136 } 105 106 SVM* SVM::make_classifier(void) const 107 { 108 return new SVM(*this); 109 } 110 137 111 138 112 long int SVM::max_epochs(void) const … … 196 170 } 197 171 198 void SVM::reset(void)199 {200 trained_=false;201 alpha_ = utility::vector(target_.size(), 0);202 }203 204 172 int SVM::target(size_t i) const 205 173 { 174 assert(i<target_.size()); 206 175 return target_.binary(i) ? 1 : -1; 207 176 } 208 177 209 void SVM::train(void) 210 { 178 void SVM::train(const KernelLookup& kernel, const Target& targ) 179 { 180 if (kernel_) 181 delete kernel_; 182 kernel_ = new KernelLookup(kernel); 183 target_ = targ; 184 185 alpha_ = utility::vector(targ.size(), 0.0); 186 output_ = utility::vector(targ.size(), 0.0); 211 187 // initializing variables for optimization 212 188 assert(target_.size()==kernel_->rows()); -
trunk/yat/classifier/SVM.h
r1087 r1100 57 57 public: 58 58 /// 59 /// Constructor taking the kernel and the target vector as 60 /// input. 61 /// 62 /// @note if the @a target or @a kernel 63 /// is destroyed the behaviour is undefined. 64 /// 65 SVM(const KernelLookup& kernel, const Target& target); 59 /// \brief Constructor 60 /// 61 SVM(void); 66 62 67 63 /// … … 73 69 74 70 /// 75 /// If DataLookup2D is not a KernelLookup a bad_cast exception is thrown.76 /// 77 SVM* make_classifier( const DataLookup2D&, const Target&) const;71 /// 72 /// 73 SVM* make_classifier(void) const; 78 74 79 75 /// … … 145 141 /// 146 142 double predict(const DataLookupWeighted1D& input) const; 147 148 ///149 /// @brief Function sets \f$ \alpha=0 \f$ and makes SVM untrained.150 ///151 void reset(void);152 143 153 144 /// … … 187 178 @return true if succesful 188 179 */ 189 void train( );180 void train(const KernelLookup& kernel, const Target& target); 190 181 191 182 … … 243 234 unsigned long int max_epochs_; 244 235 utility::vector output_; 245 bool owner_;246 236 SVindex sample_; 247 237 Target target_; 248 bool trained_;249 238 double tolerance_; 239 bool trained_; 250 240 251 241 }; -
trunk/yat/classifier/SVindex.cc
r1004 r1100 69 69 nof_sv_=0; 70 70 size_t nof_nsv=0; 71 vec_.resize(alpha.size()); 71 72 for (size_t i=0; i<alpha.size(); i++) 72 73 if (alpha(i)<tol){ -
trunk/yat/classifier/Target.cc
r1004 r1100 41 41 namespace yat { 42 42 namespace classifier { 43 44 Target::Target(void) 45 { 46 } 47 43 48 44 49 Target::Target(const std::vector<std::string>& label) -
trunk/yat/classifier/Target.h
r1000 r1100 42 42 /// @brief Class for containing sample labels. 43 43 /// 44 45 44 class Target 46 45 { 47 46 48 47 public: 48 /** 49 \brief default constructor 50 */ 51 Target(void); 52 49 53 /// 50 54 /// @brief Constructor creating target with @a labels
Note: See TracChangeset
for help on using the changeset viewer.