Ignore:
Timestamp:
Feb 26, 2008, 4:29:50 PM (14 years ago)
Author:
Markus Ringnér
Message:

Fixes #333

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NBC.cc

    r1157 r1160  
    2424
    2525#include "NBC.h"
    26 #include "DataLookup2D.h"
    2726#include "MatrixLookup.h"
    2827#include "MatrixLookupWeighted.h"
    2928#include "Target.h"
     29#include "yat/statistics/Averager.h"
    3030#include "yat/statistics/AveragerWeighted.h"
    3131#include "yat/utility/Matrix.h"
     
    120120
    121121
    122   void NBC::predict(const DataLookup2D& x,                   
     122  void NBC::predict(const MatrixLookup& ml,                     
    123123                    utility::Matrix& prediction) const
    124124  {   
    125     assert(x.rows()==sigma2_.rows());
    126     assert(x.rows()==centroids_.rows());
    127    
    128    
     125    assert(ml.rows()==sigma2_.rows());
     126    assert(ml.rows()==centroids_.rows());
    129127    // each row in prediction corresponds to a sample label (class)
    130     prediction.resize(centroids_.columns(), x.columns(), 0);
    131     // weighted calculation
    132     if (const MatrixLookupWeighted* mlw =
    133         dynamic_cast<const MatrixLookupWeighted*>(&x)) {
    134       // first calculate -lnP = sum ln_sigma_i + (x_i-m_i)^2/2sigma_i^2
    135       for (size_t label=0; label<centroids_.columns(); ++label) {
    136         double sum_log_sigma = sum_logsigma(label);
    137         for (size_t sample=0; sample<prediction.rows(); ++sample) {
    138           prediction(label,sample) = sum_log_sigma;
    139           for (size_t i=0; i<x.rows(); ++i)
    140             // taking care of NaN and missing training features
    141             if (mlw->weight(i, label) && !std::isnan(sigma2_(i, label))) {
    142               prediction(label, sample) += mlw->weight(i, label)*
    143                 std::pow(mlw->data(i, label)-centroids_(i, label),2)/
    144                 sigma2_(i, label);
    145             }
     128    prediction.resize(centroids_.columns(), ml.columns(), 0);
     129
     130    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
     131    for (size_t label=0; label<centroids_.columns(); ++label) {
     132      double sum_log_sigma = sum_logsigma(label);
     133      for (size_t sample=0; sample<prediction.rows(); ++sample) {
     134        prediction(label,sample) = sum_log_sigma;
     135        for (size_t i=0; i<ml.rows(); ++i)
     136          // Ignoring missing features
     137          if (!std::isnan(sigma2_(i, label)))
     138            prediction(label, sample) +=
     139              std::pow(ml(i, label)-centroids_(i, label),2)/
     140              sigma2_(i, label);
     141      }
     142    }
     143    standardize_lnP(prediction);
     144  }
     145
    146146     
    147         }
    148       }
    149     }
    150       // no weights
    151     else if (const MatrixLookup* ml = dynamic_cast<const MatrixLookup*>(&x)) {
    152       // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
    153       for (size_t label=0; label<centroids_.columns(); ++label) {
    154         double sum_log_sigma = sum_logsigma(label);
    155         for (size_t sample=0; sample<prediction.rows(); ++sample) {
    156           prediction(label,sample) = sum_log_sigma;
    157           for (size_t i=0; i<ml->rows(); ++i)
    158             // Ignoring missing features
    159             if (!std::isnan(sigma2_(i, label)))
    160               prediction(label, sample) +=
    161                 std::pow((*ml)(i, label)-centroids_(i, label),2)/
    162                 sigma2_(i, label);
    163         }
    164       }
    165     }
    166     else {
    167       std::string str =
    168         "Error in NBC::predict: DataLookup2D of unexpected class.";
    169       throw std::runtime_error(str);
    170     }
    171 
    172 
     147  void NBC::predict(const MatrixLookupWeighted& mlw,                   
     148                    utility::Matrix& prediction) const
     149  {   
     150    assert(mlw.rows()==sigma2_.rows());
     151    assert(mlw.rows()==centroids_.rows());
     152   
     153    // each row in prediction corresponds to a sample label (class)
     154    prediction.resize(centroids_.columns(), mlw.columns(), 0);
     155
     156    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
     157    for (size_t label=0; label<centroids_.columns(); ++label) {
     158      double sum_log_sigma = sum_logsigma(label);
     159      for (size_t sample=0; sample<prediction.rows(); ++sample) {
     160        prediction(label,sample) = sum_log_sigma;
     161        for (size_t i=0; i<mlw.rows(); ++i)
     162          // taking care of NaN and missing training features
     163          if (mlw.weight(i, label) && !std::isnan(sigma2_(i, label))) {
     164            prediction(label, sample) += mlw.weight(i, label)*
     165              std::pow(mlw.data(i, label)-centroids_(i, label),2)/
     166              sigma2_(i, label);
     167          }
     168       
     169      }
     170    }
     171    standardize_lnP(prediction);
     172  }
     173
     174  void NBC::standardize_lnP(utility::Matrix& prediction) const
     175  {
    173176    // -lnP might be a large number, in order to avoid out of bound
    174177    // problems when calculating P = exp(- -lnP), we centralize matrix
     
    177180    add(a, prediction.begin(), prediction.end());
    178181    prediction -= a.mean();
    179 
     182   
    180183    // exponentiate
    181184    for (size_t i=0; i<prediction.rows(); ++i)
    182185      for (size_t j=0; j<prediction.columns(); ++j)
    183186        prediction(i,j) = std::exp(prediction(i,j));
    184 
     187   
    185188    // normalize each row (label) to sum up to unity (probability)
    186189    for (size_t i=0; i<prediction.rows(); ++i){
Note: See TracChangeset for help on using the changeset viewer.