Ignore:
Timestamp:
Feb 22, 2007, 4:14:40 PM (15 years ago)
Author:
Peter
Message:

Fixes #65

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NBC.cc

    r722 r767  
    3030#include "yat/utility/matrix.h"
    3131
     32#include <cassert>
    3233#include <vector>
    3334
     
    5152
    5253
    53     const DataLookup2D& NBC::data(void) const
    54     {
     54  const DataLookup2D& NBC::data(void) const
     55  {
    5556    return data_;
    56     }
     57  }
    5758
    5859
     
    8384          const MatrixLookupWeighted& data =
    8485            dynamic_cast<const MatrixLookupWeighted&>(data_);
    85             aver[target_(j)].add(data.data(i,j), data.weight(i,j));
     86          aver[target_(j)].add(data.data(i,j), data.weight(i,j));
    8687        }
    8788        else
     
    9091      for (size_t j=0; target_.nof_classes(); ++j){
    9192        centroids_(i,j) = aver[j].mean();
    92         sigma_(i,j) = aver[j].variance();
     93        sigma2_(i,j) = aver[j].variance();
    9394      }
    9495    }   
     
    103104    std::cerr << "NBC::predict not implemented\n";
    104105    exit(1);
     106    assert(data_.rows()==input.rows());
     107
     108    std::log(sigma_(i,c)) +
     109
     110    prediction = utility::matrix(centroids_.columns(),input.columns());
     111    for (size_t c=0; c<centroid_.columns(); ++c) {
     112      double sum_ln_sigma=0;
     113      for (size_t i=0; i<sigma2_.rows(); ++i)
     114        sum_ln_sigma += log(sigma2_(i,c));
     115      sum_ln_sigma /= 2;
     116
     117      for (size_t s=0; s<input.columns(); ++s) {
     118        // -lnp = sum{ln(sigma_i)} + sum{(x_i-m_i)^2/(2sigma_i)}
     119        prediction(c,s) = sum_ln_sigma;
     120        for (size_t i=0; i<input.columns(); ++i) {
     121          prediction(c,s) += std::pow(input(i,s)-mean_(i,c),2)/sigma2_(i,c);
     122        }
     123      }
     124    }
     125    // exponentiate and normalize
    105126  }
    106127
Note: See TracChangeset for help on using the changeset viewer.