Ignore:
Timestamp:
Oct 10, 2007, 6:49:39 PM (14 years ago)
Author:
Peter
Message:

Fixed so NBC and SVM are throwing when unexpected DataLookup2D is
apssed to make_classifier or predict.

Speeding up NBC::predict by separating weighted code from
unweighted. Also fixed some bugs in NBC.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NBC.cc

    r950 r959  
    120120    assert(x.rows()==centroids_.rows());
    121121
    122     const MatrixLookupWeighted* w =
    123       dynamic_cast<const MatrixLookupWeighted*>(&x);
     122   
    124123   
    125124    // each row in prediction corresponds to a sample label (class)
    126125    prediction.resize(centroids_.columns(), x.columns(), 0);
    127     // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
    128     for (size_t label=0; label<centroids_.columns(); ++label) {
    129       double sum_ln_sigma=0;
    130       assert(label<sigma2_.columns());
    131       for (size_t i=0; i<x.rows(); ++i) {
    132         assert(i<sigma2_.rows());
    133         sum_ln_sigma += std::log(sigma2_(i, label));
    134       }
    135       sum_ln_sigma /= 2; // taking sum of log(sigma) not sigma2
    136       for (size_t sample=0; sample<prediction.rows(); ++sample) {
    137         for (size_t i=0; i<x.rows(); ++i) {
    138           // weighted calculation
    139           if (w){
     126    // weighted calculation
     127    if (const MatrixLookupWeighted* mlw =
     128        dynamic_cast<const MatrixLookupWeighted*>(&x)) {
     129      // first calculate -lnP = sum ln_sigma_i + (x_i-m_i)^2/2sigma_i^2
     130      for (size_t label=0; label<centroids_.columns(); ++label) {
     131        double sum_log_sigma = sum_logsigma(label);
     132        for (size_t sample=0; sample<prediction.rows(); ++sample) {
     133          prediction(label,sample) = sum_log_sigma;
     134          for (size_t i=0; i<x.rows(); ++i)
    140135            // taking care of NaN
    141             if (w->weight(i, label)){
    142             prediction(label, sample) += w->weight(i, label)*
    143               std::pow(w->data(i, label)-centroids_(i, label),2)/
    144               sigma2_(i, label);
     136            if (mlw->weight(i, label)) {
     137              prediction(label, sample) += mlw->weight(i, label)*
     138                std::pow(mlw->data(i, label)-centroids_(i, label),2)/
     139                sigma2_(i, label);
    145140            }
    146           }
    147           // no weights
    148           else {
    149             prediction(label, sample) +=
    150               std::pow(x(i, label)-centroids_(i, label),2)/sigma2_(i, label);
    151           }
     141     
    152142        }
    153143      }
    154144    }
     145      // no weights
     146    else if (const MatrixLookup* ml = dynamic_cast<const MatrixLookup*>(&x)) {
     147      // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
     148      for (size_t label=0; label<centroids_.columns(); ++label) {
     149        double sum_log_sigma = sum_logsigma(label);
     150        for (size_t sample=0; sample<prediction.rows(); ++sample) {
     151          prediction(label,sample) = sum_log_sigma;
     152          for (size_t i=0; i<ml->rows(); ++i)
     153            prediction(label, sample) +=
     154              std::pow((*ml)(i, label)-centroids_(i, label),2)/sigma2_(i, label);
     155        }
     156      }
     157    }
     158    else {
     159      std::string str =
     160        "Error in NBC::predict: DataLookup2D of unexpected class.";
     161      throw std::runtime_error(str);
     162    }
     163
    155164
    156165    // -lnP might be a large number, in order to avoid out of bound
     
    176185
    177186
     187  double NBC::sum_logsigma(size_t label) const
     188  {
     189    double sum_log_sigma=0;
     190    assert(label<sigma2_.columns());
     191    for (size_t i=0; i<sigma2_.rows(); ++i) {
     192      sum_log_sigma += std::log(sigma2_(i, label));
     193    }
     194    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
     195  }
     196
    178197}}} // of namespace classifier, yat, and theplu
Note: See TracChangeset for help on using the changeset viewer.