Ignore:
Timestamp:
Feb 28, 2008, 7:49:57 PM (14 years ago)
Author:
Peter
Message:

refs #335 predict fixed for NBC

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NBC.cc

    r1182 r1184  
    6161    sigma2_.resize(data.rows(), target.nof_classes());
    6262    centroids_.resize(data.rows(), target.nof_classes());
    63     utility::Matrix nof_in_class(data.rows(), target.nof_classes());
    6463   
    6564    for(size_t i=0; i<data.rows(); ++i) {
     
    7271        assert(i<centroids_.rows());
    7372        assert(j<centroids_.columns());
    74         centroids_(i,j) = aver[j].mean();
    7573        assert(i<sigma2_.rows());
    7674        assert(j<sigma2_.columns());
     
    7977          centroids_(i,j) = aver[j].mean();
    8078        }
    81           else {
     79        else {
    8280            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
    8381            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
    84           }
     82        }
    8583      }
    8684    }
     
    9290    sigma2_.resize(data.rows(), target.nof_classes());
    9391    centroids_.resize(data.rows(), target.nof_classes());
    94     utility::Matrix nof_in_class(data.rows(), target.nof_classes());
    9592
    9693    for(size_t i=0; i<data.rows(); ++i) {
     
    132129        prediction(label,sample) = sum_log_sigma;
    133130        for (size_t i=0; i<ml.rows(); ++i)
    134           // Ignoring missing features
    135           if (!std::isnan(sigma2_(i, label)))
    136             prediction(label, sample) +=
    137               std::pow(ml(i, label)-centroids_(i, label),2)/
    138               sigma2_(i, label);
     131          prediction(label, sample) +=
     132            std::pow(ml(i, label)-centroids_(i, label),2)/
     133            sigma2_(i, label);
    139134      }
    140135    }
     
    159154        statistics::AveragerWeighted aw;
    160155        for (size_t i=0; i<mlw.rows(); ++i)
    161           // missing training features
    162           if (!std::isnan(sigma2_(i, label)))
    163             aw.add(std::pow(mlw.data(i, label)-centroids_(i, label),2)/
    164                    sigma2_(i, label), mlw.weight(i, label));
     156          aw.add(std::pow(mlw.data(i, label)-centroids_(i, label),2)/
     157                 sigma2_(i, label), mlw.weight(i, label));
    165158        prediction(label,sample) = sum_log_sigma + mlw.rows()*aw.mean()/2;
    166159      }
     
    171164  void NBC::standardize_lnP(utility::Matrix& prediction) const
    172165  {
    173     // -lnP might be a large number, in order to avoid out of bound
    174     // problems when calculating P = exp(- -lnP), we centralize matrix
    175     // by adding a constant.
    176     statistics::Averager a;
    177     add(a, prediction.begin(), prediction.end());
     166    /// -lnP might be a large number, in order to avoid out of bound
     167    /// problems when calculating P = exp(- -lnP), we centralize matrix
     168    /// by adding a constant.
     169    // lookup of prediction with zero weights for NaNs
     170    MatrixLookupWeighted mlw(prediction);
     171    statistics::AveragerWeighted a;
     172    add(a, mlw.begin(), mlw.end());
    178173    prediction -= a.mean();
    179174   
     
    185180    // normalize each row (label) to sum up to unity (probability)
    186181    for (size_t i=0; i<prediction.rows(); ++i){
    187       prediction.row_view(i) *= 1.0/sum(prediction.row_const_view(i));
     182      // calculate sum of row ignoring NaNs
     183      statistics::AveragerWeighted a;
     184      add(a, mlw.begin_row(i), mlw.end_row(i));
     185      prediction.row_view(i) *= 1.0/a.sum_wx();
    188186    }
    189187  }
     
    195193    assert(label<sigma2_.columns());
    196194    for (size_t i=0; i<sigma2_.rows(); ++i) {
    197       if (!std::isnan(sigma2_(i,label)))
    198         sum_log_sigma += std::log(sigma2_(i, label));
     195      sum_log_sigma += std::log(sigma2_(i, label));
    199196    }
    200197    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
Note: See TracChangeset for help on using the changeset viewer.