Changeset 813 for trunk/yat/classifier/NBC.cc
 Timestamp:
 Mar 16, 2007, 8:30:02 PM (16 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/yat/classifier/NBC.cc
r812 r813 90 90 aver[target_(j)].add(data_(i,j),1.0); 91 91 } 92 for (size_t j=0; target_.nof_classes(); ++j){ 92 assert(centroids_.columns()==target_.nof_classes()); 93 for (size_t j=0; j<target_.nof_classes(); ++j){ 94 assert(i<centroids_.rows()); 95 assert(j<centroids_.columns()); 93 96 centroids_(i,j) = aver[j].mean(); 97 assert(i<sigma2_.rows()); 98 assert(j<sigma2_.columns()); 94 99 sigma2_(i,j) = aver[j].variance(); 95 100 } … … 104 109 { 105 110 assert(data_.rows()==x.rows()); 111 assert(x.rows()==sigma2_.rows()); 112 assert(x.rows()==centroids_.rows()); 113 114 const MatrixLookupWeighted* w = 115 dynamic_cast<const MatrixLookupWeighted*>(&x); 106 116 107 117 // each row in prediction corresponds to a sample label (class) 108 118 prediction.resize(centroids_.columns(), x.columns(), 0); 109 119 // first calculate lnP = sum sigma_i + (x_im_i)^2/2sigma_i^2 110 for (size_t label=0; label< prediction.columns(); ++label) {120 for (size_t label=0; label<centroids_.columns(); ++label) { 111 121 double sum_ln_sigma=0; 112 for (size_t i=0; i<x.rows(); ++i) 122 assert(label<sigma2_.columns()); 123 for (size_t i=0; i<x.rows(); ++i) { 124 assert(i<sigma2_.rows()); 113 125 sum_ln_sigma += std::log(sigma2_(i, label)); 126 } 114 127 sum_ln_sigma /= 2; // taking sum of log(sigma) not sigma2 115 128 for (size_t sample=0; sample<prediction.rows(); ++sample) { 116 129 for (size_t i=0; i<x.rows(); ++i) { 117 prediction(label, sample) += 118 std::pow(x(i, label)centroids_(i, label),2)/sigma2_(i, label); 130 // weighted calculation 131 if (w){ 132 // taking care of NaN 133 if (w>weight(i, label)){ 134 prediction(label, sample) += w>weight(i, label)* 135 std::pow(w>data(i, label)centroids_(i, label),2)/ 136 sigma2_(i, label); 137 } 138 } 139 // no weights 140 else { 141 prediction(label, sample) += 142 std::pow(x(i, label)centroids_(i, label),2)/sigma2_(i, label); 143 } 119 144 } 120 145 }
Note: See TracChangeset
for help on using the changeset viewer.