Ignore:
Timestamp:
Mar 16, 2007, 8:30:02 PM (15 years ago)
Author:
Peter
Message:

Predict in NBC. Fixes #57

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NBC.cc

    r812 r813  
    9090          aver[target_(j)].add(data_(i,j),1.0);
    9191      }
    92       for (size_t j=0; target_.nof_classes(); ++j){
     92      assert(centroids_.columns()==target_.nof_classes());
     93      for (size_t j=0; j<target_.nof_classes(); ++j){
     94        assert(i<centroids_.rows());
     95        assert(j<centroids_.columns());
    9396        centroids_(i,j) = aver[j].mean();
     97        assert(i<sigma2_.rows());
     98        assert(j<sigma2_.columns());
    9499        sigma2_(i,j) = aver[j].variance();
    95100      }
     
    104109  {   
    105110    assert(data_.rows()==x.rows());
     111    assert(x.rows()==sigma2_.rows());
     112    assert(x.rows()==centroids_.rows());
     113
     114    const MatrixLookupWeighted* w =
     115      dynamic_cast<const MatrixLookupWeighted*>(&x);
    106116
    107117    // each row in prediction corresponds to a sample label (class)
    108118    prediction.resize(centroids_.columns(), x.columns(), 0);
    109119    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
    110     for (size_t label=0; label<prediction.columns(); ++label) {
     120    for (size_t label=0; label<centroids_.columns(); ++label) {
    111121      double sum_ln_sigma=0;
    112       for (size_t i=0; i<x.rows(); ++i)
     122      assert(label<sigma2_.columns());
     123      for (size_t i=0; i<x.rows(); ++i) {
     124        assert(i<sigma2_.rows());
    113125        sum_ln_sigma += std::log(sigma2_(i, label));
     126      }
    114127      sum_ln_sigma /= 2; // taking sum of log(sigma) not sigma2
    115128      for (size_t sample=0; sample<prediction.rows(); ++sample) {
    116129        for (size_t i=0; i<x.rows(); ++i) {
    117           prediction(label, sample) +=
    118             std::pow(x(i, label)-centroids_(i, label),2)/sigma2_(i, label);
     130          // weighted calculation
     131          if (w){
     132            // taking care of NaN
     133            if (w->weight(i, label)){
     134            prediction(label, sample) += w->weight(i, label)*
     135              std::pow(w->data(i, label)-centroids_(i, label),2)/
     136              sigma2_(i, label);
     137            }
     138          }
     139          // no weights
     140          else {
     141            prediction(label, sample) +=
     142              std::pow(x(i, label)-centroids_(i, label),2)/sigma2_(i, label);
     143          }
    119144        }
    120145      }
Note: See TracChangeset for help on using the changeset viewer.