Changeset 812 for trunk/yat/classifier/NBC.cc
 Timestamp:
 Mar 16, 2007, 2:02:07 AM (16 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/yat/classifier/NBC.cc
r808 r812 31 31 32 32 #include <cassert> 33 #include <cmath> 33 34 #include <vector> 34 35 … … 74 75 bool NBC::train() 75 76 { 76 sigma2_=centroids_=utility::matrix(data_.rows(), target_.nof_classes()); 77 sigma2_.resize(data_.rows(), target_.nof_classes()); 78 centroids_.resize(data_.rows(), target_.nof_classes()); 77 79 utility::matrix nof_in_class(data_.rows(), target_.nof_classes()); 78 79 80 80 81 for(size_t i=0; i<data_.rows(); ++i) { … … 99 100 100 101 101 void NBC::predict(const DataLookup2D& input,102 void NBC::predict(const DataLookup2D& x, 102 103 utility::matrix& prediction) const 103 104 { 104 std::cerr << "NBC::predict not implemented\n"; 105 exit(1); 106 assert(data_.rows()==input.rows()); 105 assert(data_.rows()==x.rows()); 107 106 108 // utility 109 //for (size_t i=0; 110 111 112 prediction = utility::matrix(centroids_.columns(),input.columns()); 113 for (size_t c=0; c<centroids_.columns(); ++c) { 107 // each row in prediction corresponds to a sample label (class) 108 prediction.resize(centroids_.columns(), x.columns(), 0); 109 // first calculate lnP = sum sigma_i + (x_im_i)^2/2sigma_i^2 110 for (size_t label=0; label<prediction.columns(); ++label) { 114 111 double sum_ln_sigma=0; 115 for (size_t i=0; i<sigma2_.rows(); ++i) 116 sum_ln_sigma += log(sigma2_(i,c)); 117 sum_ln_sigma /= 2; 118 119 for (size_t s=0; s<input.columns(); ++s) { 120 // lnp = sum{ln(sigma_i)} + sum{(x_im_i)^2/(2sigma_i)} 121 prediction(c,s) = sum_ln_sigma; 122 for (size_t i=0; i<input.columns(); ++i) { 123 prediction(c,s) += std::pow(input(i,s)centroids_(i,c),2)/sigma2_(i,c); 112 for (size_t i=0; i<x.rows(); ++i) 113 sum_ln_sigma += std::log(sigma2_(i, label)); 114 sum_ln_sigma /= 2; // taking sum of log(sigma) not sigma2 115 for (size_t sample=0; sample<prediction.rows(); ++sample) { 116 for (size_t i=0; i<x.rows(); ++i) { 117 prediction(label, sample) += 118 std::pow(x(i, label)centroids_(i, label),2)/sigma2_(i, label); 124 119 } 125 120 } 126 121 } 127 // exponentiate and normalize 122 123 // lnP might be a large number, in order to avoid out of bound 124 // problems when calculating P = exp( lnP), we centralize matrix 125 // by adding a constant. 126 double m=0; 127 for (size_t i=0; i<prediction.rows(); ++i) 128 for (size_t j=0; j<prediction.columns(); ++j) 129 m+=prediction(i,j); 130 prediction = m/prediction.rows()/prediction.columns(); 131 132 // exponentiate 133 for (size_t i=0; i<prediction.rows(); ++i) 134 for (size_t j=0; j<prediction.columns(); ++j) 135 prediction(i,j) = std::exp(prediction(i,j)); 136 137 // normalize each row (label) to sum up to unity (probability) 138 for (size_t i=0; i<prediction.rows(); ++i) 139 utility::vector(prediction,i) *= 140 1.0/utility::sum(utility::vector(prediction,i)); 141 128 142 } 129 143
Note: See TracChangeset
for help on using the changeset viewer.