Changeset 813


Ignore:
Timestamp:
Mar 16, 2007, 8:30:02 PM (17 years ago)
Author:
Peter
Message:

Predict in NBC. Fixes #57

Location:
trunk
Files:
1 added
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/Makefile.am

    r806 r813  
    3030  ensemble_test feature_selection_test fileutil_test inputranker_test \
    3131  kernel_test kernel_lookup_test matrix_test matrix_lookup_test       \
     32  nbc_test \
    3233  ncc_test nni_test pca_test regression_test rnd_test roc_test \
    3334  score_test      \
     
    5758matrix_test_SOURCES = matrix_test.cc
    5859matrix_lookup_test_SOURCES = matrix_lookup_test.cc
     60nbc_test_SOURCES = nbc_test.cc
    5961ncc_test_SOURCES = ncc_test.cc
    6062nni_test_SOURCES = nni_test.cc
  • trunk/yat/classifier/NBC.cc

    r812 r813  
    9090          aver[target_(j)].add(data_(i,j),1.0);
    9191      }
    92       for (size_t j=0; target_.nof_classes(); ++j){
     92      assert(centroids_.columns()==target_.nof_classes());
     93      for (size_t j=0; j<target_.nof_classes(); ++j){
     94        assert(i<centroids_.rows());
     95        assert(j<centroids_.columns());
    9396        centroids_(i,j) = aver[j].mean();
     97        assert(i<sigma2_.rows());
     98        assert(j<sigma2_.columns());
    9499        sigma2_(i,j) = aver[j].variance();
    95100      }
     
    104109  {   
    105110    assert(data_.rows()==x.rows());
     111    assert(x.rows()==sigma2_.rows());
     112    assert(x.rows()==centroids_.rows());
     113
     114    const MatrixLookupWeighted* w =
     115      dynamic_cast<const MatrixLookupWeighted*>(&x);
    106116
    107117    // each row in prediction corresponds to a sample label (class)
    108118    prediction.resize(centroids_.columns(), x.columns(), 0);
    109119    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
    110     for (size_t label=0; label<prediction.columns(); ++label) {
     120    for (size_t label=0; label<centroids_.columns(); ++label) {
    111121      double sum_ln_sigma=0;
    112       for (size_t i=0; i<x.rows(); ++i)
     122      assert(label<sigma2_.columns());
     123      for (size_t i=0; i<x.rows(); ++i) {
     124        assert(i<sigma2_.rows());
    113125        sum_ln_sigma += std::log(sigma2_(i, label));
     126      }
    114127      sum_ln_sigma /= 2; // taking sum of log(sigma) not sigma2
    115128      for (size_t sample=0; sample<prediction.rows(); ++sample) {
    116129        for (size_t i=0; i<x.rows(); ++i) {
    117           prediction(label, sample) +=
    118             std::pow(x(i, label)-centroids_(i, label),2)/sigma2_(i, label);
     130          // weighted calculation
     131          if (w){
     132            // taking care of NaN
     133            if (w->weight(i, label)){
     134            prediction(label, sample) += w->weight(i, label)*
     135              std::pow(w->data(i, label)-centroids_(i, label),2)/
     136              sigma2_(i, label);
     137            }
     138          }
     139          // no weights
     140          else {
     141            prediction(label, sample) +=
     142              std::pow(x(i, label)-centroids_(i, label),2)/sigma2_(i, label);
     143          }
    119144        }
    120145      }
  • trunk/yat/classifier/NBC.h

    r812 r813  
    8383   
    8484    /**
    85        For each sample, calculate the probabilities the sample belong
    86        to the corresponding class.
     85       Each sample (column) in \a data is predicted and predictions
     86       are returned in the corresponding column in passed \a res. Each
     87       row in \a res corresponds to a class. The prediction is the
     88       estimated probability that sample belong to class \f$ j \f$
     89
     90       \f$ P_j = \frac{1}{Z}\prod_i{\frac{1}{\sigma_i}}
     91       \exp(\frac{w_i(x_i-\mu_i)^2}{\sigma_i^2})\f$, where \f$ \mu_i
     92       \f$ and \f$ \sigma_i^2 \f$ are the estimated mean and variance,
     93       respectively. If \a data is a MatrixLookup is equivalent to
     94       using all weight equal to unity.
    8795    */
    8896    void predict(const DataLookup2D& data, utility::matrix& res) const;
  • trunk/yat/utility/matrix.cc

    r810 r813  
    379379  {
    380380    assert(m_);
     381    assert(row<rows());
     382    assert(column<columns());
    381383    double* d=gsl_matrix_ptr(m_, row, column);
    382384    if (!d)
     
    388390  const double& matrix::operator()(size_t row, size_t column) const
    389391  {
     392    assert(row<rows());
     393    assert(column<columns());
    390394    const double* d=gsl_matrix_const_ptr(proxy_m_, row, column);
    391395    if (!d)
Note: See TracChangeset for help on using the changeset viewer.