Changeset 960 for trunk/yat/classifier


Ignore:
Timestamp:
Oct 10, 2007, 7:44:31 PM (16 years ago)
Author:
Peter
Message:

fixed ticket:271

Location:
trunk/yat/classifier
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NBC.cc

    r959 r960  
    8787    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
    8888   
    89     for(size_t i=0; i<data_.rows(); ++i) {
    90       std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
    91       for(size_t j=0; j<data_.columns(); ++j) {
    92         if (data_.weighted()){
    93           const MatrixLookupWeighted& data =
    94             dynamic_cast<const MatrixLookupWeighted&>(data_);
     89    // unweighted
     90    if (data_.weighted()){
     91      const MatrixLookupWeighted& data =
     92        dynamic_cast<const MatrixLookupWeighted&>(data_);
     93      for(size_t i=0; i<data_.rows(); ++i) {
     94        std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
     95        for(size_t j=0; j<data_.columns(); ++j)
    9596          aver[target_(j)].add(data.data(i,j), data.weight(i,j));
    96         }
    97         else
    98           aver[target_(j)].add(data_(i,j),1.0);
    99       }
    100       assert(centroids_.columns()==target_.nof_classes());
    101       for (size_t j=0; j<target_.nof_classes(); ++j){
    102         assert(i<centroids_.rows());
    103         assert(j<centroids_.columns());
    104         centroids_(i,j) = aver[j].mean();
    105         assert(i<sigma2_.rows());
    106         assert(j<sigma2_.columns());
    107         sigma2_(i,j) = aver[j].variance();
     97
     98        assert(centroids_.columns()==target_.nof_classes());
     99        for (size_t j=0; j<target_.nof_classes(); ++j){
     100          assert(i<centroids_.rows());
     101          assert(j<centroids_.columns());
     102          centroids_(i,j) = aver[j].mean();
     103          assert(i<sigma2_.rows());
     104          assert(j<sigma2_.columns());
     105          if (aver[j].variance())
     106            sigma2_(i,j) = aver[j].variance();
     107          else
     108            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
     109        }
     110      }
     111    }
     112    else { 
     113      const MatrixLookup& data = dynamic_cast<const MatrixLookup&>(data_);
     114      for(size_t i=0; i<data_.rows(); ++i) {
     115        std::vector<statistics::Averager> aver(target_.nof_classes());
     116        for(size_t j=0; j<data_.columns(); ++j)
     117          aver[target_(j)].add(data(i,j));
     118
     119        assert(centroids_.columns()==target_.nof_classes());
     120        for (size_t j=0; j<target_.nof_classes(); ++j){
     121          assert(i<centroids_.rows());
     122          assert(j<centroids_.columns());
     123          centroids_(i,j) = aver[j].mean();
     124          assert(i<sigma2_.rows());
     125          assert(j<sigma2_.columns());
     126          if (aver[j].variance())
     127            sigma2_(i,j) = aver[j].variance();
     128          else
     129            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
     130        }
    108131      }
    109132    }   
     
    133156          prediction(label,sample) = sum_log_sigma;
    134157          for (size_t i=0; i<x.rows(); ++i)
    135             // taking care of NaN
    136             if (mlw->weight(i, label)) {
     158            // taking care of NaN and missing training features
     159            if (mlw->weight(i, label) && !std::isnan(sigma2_(i, label))) {
    137160              prediction(label, sample) += mlw->weight(i, label)*
    138161                std::pow(mlw->data(i, label)-centroids_(i, label),2)/
     
    151174          prediction(label,sample) = sum_log_sigma;
    152175          for (size_t i=0; i<ml->rows(); ++i)
    153             prediction(label, sample) +=
    154               std::pow((*ml)(i, label)-centroids_(i, label),2)/sigma2_(i, label);
     176            // Ignoring missing features
     177            if (!std::isnan(sigma2_(i, label)))
     178              prediction(label, sample) +=
     179                std::pow((*ml)(i, label)-centroids_(i, label),2)/
     180                sigma2_(i, label);
    155181        }
    156182      }
     
    190216    assert(label<sigma2_.columns());
    191217    for (size_t i=0; i<sigma2_.rows(); ++i) {
    192       sum_log_sigma += std::log(sigma2_(i, label));
     218      if (!std::isnan(sigma2_(i,label)))
     219        sum_log_sigma += std::log(sigma2_(i, label));
    193220    }
    194221    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
  • trunk/yat/classifier/NBC.h

    r959 r960  
    7878    /// feature (see Averager and AveragerWeighted for details).
    7979    ///
     80    /// If variance can not be estimated (too few data points or all
     81    /// points identical) for a feature and label, then that feature
     82    /// is ignored for that specific label.
     83    ///
    8084    /// @return true if training succedeed.
    8185    ///
Note: See TracChangeset for help on using the changeset viewer.