Changeset 1184 for trunk/yat/classifier


Ignore:
Timestamp:
Feb 28, 2008, 7:49:57 PM (16 years ago)
Author:
Peter
Message:

refs #335 predict fixed for NBC

Location:
trunk/yat/classifier
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NBC.cc

    r1182 r1184  
    6161    sigma2_.resize(data.rows(), target.nof_classes());
    6262    centroids_.resize(data.rows(), target.nof_classes());
    63     utility::Matrix nof_in_class(data.rows(), target.nof_classes());
    6463   
    6564    for(size_t i=0; i<data.rows(); ++i) {
     
    7271        assert(i<centroids_.rows());
    7372        assert(j<centroids_.columns());
    74         centroids_(i,j) = aver[j].mean();
    7573        assert(i<sigma2_.rows());
    7674        assert(j<sigma2_.columns());
     
    7977          centroids_(i,j) = aver[j].mean();
    8078        }
    81           else {
     79        else {
    8280            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
    8381            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
    84           }
     82        }
    8583      }
    8684    }
     
    9290    sigma2_.resize(data.rows(), target.nof_classes());
    9391    centroids_.resize(data.rows(), target.nof_classes());
    94     utility::Matrix nof_in_class(data.rows(), target.nof_classes());
    9592
    9693    for(size_t i=0; i<data.rows(); ++i) {
     
    132129        prediction(label,sample) = sum_log_sigma;
    133130        for (size_t i=0; i<ml.rows(); ++i)
    134           // Ignoring missing features
    135           if (!std::isnan(sigma2_(i, label)))
    136             prediction(label, sample) +=
    137               std::pow(ml(i, label)-centroids_(i, label),2)/
    138               sigma2_(i, label);
     131          prediction(label, sample) +=
     132            std::pow(ml(i, label)-centroids_(i, label),2)/
     133            sigma2_(i, label);
    139134      }
    140135    }
     
    159154        statistics::AveragerWeighted aw;
    160155        for (size_t i=0; i<mlw.rows(); ++i)
    161           // missing training features
    162           if (!std::isnan(sigma2_(i, label)))
    163             aw.add(std::pow(mlw.data(i, label)-centroids_(i, label),2)/
    164                    sigma2_(i, label), mlw.weight(i, label));
     156          aw.add(std::pow(mlw.data(i, label)-centroids_(i, label),2)/
     157                 sigma2_(i, label), mlw.weight(i, label));
    165158        prediction(label,sample) = sum_log_sigma + mlw.rows()*aw.mean()/2;
    166159      }
     
    171164  void NBC::standardize_lnP(utility::Matrix& prediction) const
    172165  {
    173     // -lnP might be a large number, in order to avoid out of bound
    174     // problems when calculating P = exp(- -lnP), we centralize matrix
    175     // by adding a constant.
    176     statistics::Averager a;
    177     add(a, prediction.begin(), prediction.end());
     166    /// -lnP might be a large number, in order to avoid out of bound
     167    /// problems when calculating P = exp(- -lnP), we centralize matrix
     168    /// by adding a constant.
     169    // lookup of prediction with zero weights for NaNs
     170    MatrixLookupWeighted mlw(prediction);
     171    statistics::AveragerWeighted a;
     172    add(a, mlw.begin(), mlw.end());
    178173    prediction -= a.mean();
    179174   
     
    185180    // normalize each row (label) to sum up to unity (probability)
    186181    for (size_t i=0; i<prediction.rows(); ++i){
    187       prediction.row_view(i) *= 1.0/sum(prediction.row_const_view(i));
     182      // calculate sum of row ignoring NaNs
     183      statistics::AveragerWeighted a;
     184      add(a, mlw.begin_row(i), mlw.end_row(i));
     185      prediction.row_view(i) *= 1.0/a.sum_wx();
    188186    }
    189187  }
     
    195193    assert(label<sigma2_.columns());
    196194    for (size_t i=0; i<sigma2_.rows(); ++i) {
    197       if (!std::isnan(sigma2_(i,label)))
    198         sum_log_sigma += std::log(sigma2_(i, label));
     195      sum_log_sigma += std::log(sigma2_(i, label));
    199196    }
    200197    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
  • trunk/yat/classifier/NBC.h

    r1182 r1184  
    4141 
    4242     Each class is modelled as a multinormal distribution with
    43      features being independent: \f$ p(x|c) = \prod
     43     features being independent: \f$ P(x|c) \propto \prod
    4444     \frac{1}{\sqrt{2\pi\sigma_i^2}} \exp \left(
    45      \frac{(x_i-m_i)^2}{2\sigma_i^2)} \right)\f$
     45     -\frac{(x_i-\mu_i)^2}{2\sigma_i^2)} \right)\f$
    4646  */
    4747  class NBC : public SupervisedClassifier
     
    6464   
    6565    ///
    66     /// Train the classifier using training data and targets.
     66    /// \brief Train the %classifier using training data and targets.
    6767    ///
    6868    /// For each class mean and variance are estimated for each
    69     /// feature (see Averager and AveragerWeighted for details).
     69    /// feature (see statistics::Averager for details).
    7070    ///
    71     /// If variance can not be estimated (only one valid data point)
    72     /// for a feature and label, then that feature is ignored for that
    73     /// specific label.
     71    /// If there is only one (or zero) samples in a class, parameters
     72    /// cannot be estimated. In that case, parameters are set to NaN
     73    /// for that particular class.
    7474    ///
    7575    void train(const MatrixLookup&, const Target&);
    7676
    7777    ///
    78     /// Train the classifier using weighted training data and targets.
     78    /// \brief Train the %classifier using weighted training data and
     79    /// targets.
     80    ///
     81    /// For each class mean and variance are estimated for each
     82    /// feature (see statistics::AveragerWeighted for details).
     83    ///
     84    /// To estimate the parameters of a class, each feature of the
     85    /// class must have at least two non-zero data points. Otherwise
     86    /// the parameters are set to NaN and any prediction will result
     87    /// in NaN for that particular class.
    7988    ///
    8089    void train(const MatrixLookupWeighted&, const Target&);
    81 
    82 
    8390   
    8491    /**
     92       \brief Predict samples using unweighted data
     93
    8594       Each sample (column) in \a data is predicted and predictions
    86        are returned in the corresponding column in passed \a res. Each
    87        row in \a res corresponds to a class. The prediction is the
    88        estimated probability that sample belong to class \f$ j \f$
     95       are returned in the corresponding column in passed \a
     96       result. Each row in \a result corresponds to a class. The
     97       prediction is the estimated probability that sample belong to
     98       class \f$ j \f$:
    8999
    90        \f$ P_j = \frac{1}{Z}\prod_i{\frac{1}{\sqrt{2\pi\sigma_i^2}}}
    91        \exp(\frac{(x_i-\mu_i)^2}{\sigma_i^2})\f$, where \f$ \mu_i
     100       \f$ P_j = \frac{1}{Z}\prod_i\frac{1}{\sqrt{2\pi\sigma_i^2}}
     101       \exp\left(-\frac{(x_i-\mu_i)^2}{2\sigma_i^2}\right)\f$, where \f$ \mu_i
    92102       \f$ and \f$ \sigma_i^2 \f$ are the estimated mean and variance,
    93        respectively. If a \f$ \sigma_i \f$ could not be estimated
    94        during training, corresponding factor is set to unity, in other
    95        words, that feature is ignored for the prediction of that
    96        particular class. Z is chosen such that total probability, \f$
    97        \sum P_j \f$, equals unity.
     103       respectively. Z is chosen such that total probability equals unity, \f$
     104       \sum P_j = 1 \f$.
     105
     106       \note If parameters could not be estimated during training, due
     107       to lack of number of sufficient data points, the output for
     108       that class is NaN and not included in calculation of
     109       normalization factor \f$ Z \f$.
    98110    */
    99     void predict(const MatrixLookup& data, utility::Matrix& res) const;
     111    void predict(const MatrixLookup& data, utility::Matrix& result) const;
    100112
    101113    /**
     114       \brief Predict samples using weighted data
     115
    102116       Each sample (column) in \a data is predicted and predictions
    103        are returned in the corresponding column in passed \a res. Each
    104        row in \a res corresponds to a class. The prediction is the
    105        estimated probability that sample belong to class \f$ j \f$
     117       are returned in the corresponding column in passed \a
     118       result. Each row in \a result corresponds to a class. The
     119       prediction is the estimated probability that sample belong to
     120       class \f$ j \f$:
    106121
    107        \f$ P_j = \frac{1}{Z}\prod_i\({\frac{1}{\sqrt{2\pi\sigma_i^2}}}\)
    108        \exp(\frac{\sum{w_i(x_i-\mu_i)^2}{\sigma_i^2}}{\sum w_i})\f$,
    109        where \f$ \mu_i
    110        \f$ and \f$ \sigma_i^2 \f$ are the estimated mean and variance,
    111        respectively. If a \f$ \sigma_i \f$ could not be estimated
    112        during training, corresponding factor is set to unity, in other
    113        words, that feature is ignored for the prediction of that
    114        particular class. Z is chosen such that total probability, \f$
    115        \sum P_j \f$, equals unity.
     122       \f$ P_j = \frac{1}{Z} \exp\left(-N\frac{\sum
     123       {w_i(x_i-\mu_i)^2}/(2\sigma_i^2)}{\sum w_i}\right)\f$,
     124       where \f$ \mu_i \f$ and \f$ \sigma_i^2 \f$ are the estimated
     125       mean and variance, respectively. Z is chosen such that
     126       total probability equals unity, \f$ \sum P_j = 1 \f$.
     127
     128       \note If parameters could not be estimated during training, due
     129       to lack of number of sufficient data points, the output for
     130       that class is NaN and not included in calculation of
     131       normalization factor \f$ Z \f$.
    116132     */
    117     void predict(const MatrixLookupWeighted& data, utility::Matrix& res) const;
     133    void predict(const MatrixLookupWeighted& data,utility::Matrix& result) const;
    118134
    119135
Note: See TracChangeset for help on using the changeset viewer.