Ignore:
Timestamp:
Feb 26, 2008, 2:25:19 PM (14 years ago)
Author:
Markus Ringnér
Message:

Refs #318

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NBC.cc

    r1144 r1157  
    4040namespace classifier {
    4141
    42   NBC::NBC(const MatrixLookup& data, const Target& target)
    43     : SupervisedClassifier(target), data_(data)
     42  NBC::NBC()
     43    : SupervisedClassifier()
    4444  {
    4545  }
    4646
    47   NBC::NBC(const MatrixLookupWeighted& data, const Target& target)
    48     : SupervisedClassifier(target), data_(data)
    49   {
    50   }
    5147
    5248  NBC::~NBC()   
     
    5551
    5652
    57   const DataLookup2D& NBC::data(void) const
    58   {
    59     return data_;
    60   }
    61 
    62 
    63   NBC*
    64   NBC::make_classifier(const DataLookup2D& data, const Target& target) const
     53  NBC* NBC::make_classifier() const
    6554  {     
    66     NBC* nbc=0;
    67     try {
    68       if(data.weighted()) {
    69         nbc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target);
    70       }
    71       else {
    72         nbc=new NBC(dynamic_cast<const MatrixLookup&>(data),target);
    73       }     
    74     }
    75     catch (std::bad_cast) {
    76       std::string str =
    77         "Error in NBC::make_classifier: DataLookup2D of unexpected class.";
    78       throw std::runtime_error(str);
    79     }
    80     return nbc;
    81   }
    82 
    83 
    84   void NBC::train()
     55    return new NBC();
     56  }
     57
     58
     59  void NBC::train(const MatrixLookup& data, const Target& target)
    8560  {   
    86     sigma2_.resize(data_.rows(), target_.nof_classes());
    87     centroids_.resize(data_.rows(), target_.nof_classes());
    88     utility::Matrix nof_in_class(data_.rows(), target_.nof_classes());
     61    sigma2_.resize(data.rows(), target.nof_classes());
     62    centroids_.resize(data.rows(), target.nof_classes());
     63    utility::Matrix nof_in_class(data.rows(), target.nof_classes());
    8964   
    90     // unweighted
    91     if (data_.weighted()){
    92       const MatrixLookupWeighted& data =
    93         dynamic_cast<const MatrixLookupWeighted&>(data_);
    94       for(size_t i=0; i<data_.rows(); ++i) {
    95         std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
    96         for(size_t j=0; j<data_.columns(); ++j)
    97           aver[target_(j)].add(data.data(i,j), data.weight(i,j));
    98 
    99         assert(centroids_.columns()==target_.nof_classes());
    100         for (size_t j=0; j<target_.nof_classes(); ++j){
    101           assert(i<centroids_.rows());
    102           assert(j<centroids_.columns());
    103           assert(i<sigma2_.rows());
    104           assert(j<sigma2_.columns());
    105           if (aver[j].n()>1){
    106             sigma2_(i,j) = aver[j].variance();
    107             centroids_(i,j) = aver[j].mean();
    108           }
     65    for(size_t i=0; i<data.rows(); ++i) {
     66      std::vector<statistics::Averager> aver(target.nof_classes());
     67      for(size_t j=0; j<data.columns(); ++j)
     68        aver[target(j)].add(data(i,j));
     69     
     70      assert(centroids_.columns()==target.nof_classes());
     71      for (size_t j=0; j<target.nof_classes(); ++j){
     72        assert(i<centroids_.rows());
     73        assert(j<centroids_.columns());
     74        centroids_(i,j) = aver[j].mean();
     75        assert(i<sigma2_.rows());
     76        assert(j<sigma2_.columns());
     77        if (aver[j].n()>1){
     78          sigma2_(i,j) = aver[j].variance();
     79          centroids_(i,j) = aver[j].mean();
     80        }
    10981          else {
    11082            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
    11183            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
    11284          }
    113         }
    114       }
    115     }
    116     else { 
    117       const MatrixLookup& data = dynamic_cast<const MatrixLookup&>(data_);
    118       for(size_t i=0; i<data_.rows(); ++i) {
    119         std::vector<statistics::Averager> aver(target_.nof_classes());
    120         for(size_t j=0; j<data_.columns(); ++j)
    121           aver[target_(j)].add(data(i,j));
    122 
    123         assert(centroids_.columns()==target_.nof_classes());
    124         for (size_t j=0; j<target_.nof_classes(); ++j){
    125           assert(i<centroids_.rows());
    126           assert(j<centroids_.columns());
     85      }
     86    }
     87    trained_=true;
     88  }   
     89
     90
     91  void NBC::train(const MatrixLookupWeighted& data, const Target& target)
     92  {   
     93    sigma2_.resize(data.rows(), target.nof_classes());
     94    centroids_.resize(data.rows(), target.nof_classes());
     95    utility::Matrix nof_in_class(data.rows(), target.nof_classes());
     96
     97    for(size_t i=0; i<data.rows(); ++i) {
     98      std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
     99      for(size_t j=0; j<data.columns(); ++j)
     100        aver[target(j)].add(data.data(i,j), data.weight(i,j));
     101     
     102      assert(centroids_.columns()==target.nof_classes());
     103      for (size_t j=0; j<target.nof_classes(); ++j) {
     104        assert(i<centroids_.rows());
     105        assert(j<centroids_.columns());
     106        assert(i<sigma2_.rows());
     107        assert(j<sigma2_.columns());
     108        if (aver[j].n()>1){
     109          sigma2_(i,j) = aver[j].variance();
    127110          centroids_(i,j) = aver[j].mean();
    128           assert(i<sigma2_.rows());
    129           assert(j<sigma2_.columns());
    130           if (aver[j].n()>1){
    131             sigma2_(i,j) = aver[j].variance();
    132             centroids_(i,j) = aver[j].mean();
    133           }
    134           else {
    135             sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
    136             centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
    137           }
    138         }
    139       }
    140     }   
     111        }
     112        else {
     113          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
     114          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
     115        }
     116      }
     117    }
    141118    trained_=true;
    142119  }
     
    146123                    utility::Matrix& prediction) const
    147124  {   
    148     assert(data_.rows()==x.rows());
    149125    assert(x.rows()==sigma2_.rows());
    150126    assert(x.rows()==centroids_.rows());
Note: See TracChangeset for help on using the changeset viewer.