Changeset 813 for trunk/yat/classifier


Ignore:
Timestamp:
Mar 16, 2007, 8:30:02 PM (14 years ago)
Author:
Peter
Message:

Predict in NBC. Fixes #57

Location:
trunk/yat/classifier
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NBC.cc

    r812 r813  
    9090          aver[target_(j)].add(data_(i,j),1.0);
    9191      }
    92       for (size_t j=0; target_.nof_classes(); ++j){
     92      assert(centroids_.columns()==target_.nof_classes());
     93      for (size_t j=0; j<target_.nof_classes(); ++j){
     94        assert(i<centroids_.rows());
     95        assert(j<centroids_.columns());
    9396        centroids_(i,j) = aver[j].mean();
     97        assert(i<sigma2_.rows());
     98        assert(j<sigma2_.columns());
    9499        sigma2_(i,j) = aver[j].variance();
    95100      }
     
    104109  {   
    105110    assert(data_.rows()==x.rows());
     111    assert(x.rows()==sigma2_.rows());
     112    assert(x.rows()==centroids_.rows());
     113
     114    const MatrixLookupWeighted* w =
     115      dynamic_cast<const MatrixLookupWeighted*>(&x);
    106116
    107117    // each row in prediction corresponds to a sample label (class)
    108118    prediction.resize(centroids_.columns(), x.columns(), 0);
    109119    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
    110     for (size_t label=0; label<prediction.columns(); ++label) {
     120    for (size_t label=0; label<centroids_.columns(); ++label) {
    111121      double sum_ln_sigma=0;
    112       for (size_t i=0; i<x.rows(); ++i)
     122      assert(label<sigma2_.columns());
     123      for (size_t i=0; i<x.rows(); ++i) {
     124        assert(i<sigma2_.rows());
    113125        sum_ln_sigma += std::log(sigma2_(i, label));
     126      }
    114127      sum_ln_sigma /= 2; // taking sum of log(sigma) not sigma2
    115128      for (size_t sample=0; sample<prediction.rows(); ++sample) {
    116129        for (size_t i=0; i<x.rows(); ++i) {
    117           prediction(label, sample) +=
    118             std::pow(x(i, label)-centroids_(i, label),2)/sigma2_(i, label);
     130          // weighted calculation
     131          if (w){
     132            // taking care of NaN
     133            if (w->weight(i, label)){
     134            prediction(label, sample) += w->weight(i, label)*
     135              std::pow(w->data(i, label)-centroids_(i, label),2)/
     136              sigma2_(i, label);
     137            }
     138          }
     139          // no weights
     140          else {
     141            prediction(label, sample) +=
     142              std::pow(x(i, label)-centroids_(i, label),2)/sigma2_(i, label);
     143          }
    119144        }
    120145      }
  • trunk/yat/classifier/NBC.h

    r812 r813  
    8383   
    8484    /**
    85        For each sample, calculate the probabilities the sample belong
    86        to the corresponding class.
     85       Each sample (column) in \a data is predicted and predictions
     86       are returned in the corresponding column in passed \a res. Each
     87       row in \a res corresponds to a class. The prediction is the
     88       estimated probability that sample belong to class \f$ j \f$
     89
     90       \f$ P_j = \frac{1}{Z}\prod_i{\frac{1}{\sigma_i}}
     91       \exp(\frac{w_i(x_i-\mu_i)^2}{\sigma_i^2})\f$, where \f$ \mu_i
     92       \f$ and \f$ \sigma_i^2 \f$ are the estimated mean and variance,
     93       respectively. If \a data is a MatrixLookup is equivalent to
     94       using all weight equal to unity.
    8795    */
    8896    void predict(const DataLookup2D& data, utility::matrix& res) const;
Note: See TracChangeset for help on using the changeset viewer.