Changeset 960 for trunk/yat/classifier/NBC.cc
 Timestamp:
 Oct 10, 2007, 7:44:31 PM (14 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/yat/classifier/NBC.cc
r959 r960 87 87 utility::matrix nof_in_class(data_.rows(), target_.nof_classes()); 88 88 89 for(size_t i=0; i<data_.rows(); ++i) { 90 std::vector<statistics::AveragerWeighted> aver(target_.nof_classes()); 91 for(size_t j=0; j<data_.columns(); ++j) { 92 if (data_.weighted()){ 93 const MatrixLookupWeighted& data = 94 dynamic_cast<const MatrixLookupWeighted&>(data_); 89 // unweighted 90 if (data_.weighted()){ 91 const MatrixLookupWeighted& data = 92 dynamic_cast<const MatrixLookupWeighted&>(data_); 93 for(size_t i=0; i<data_.rows(); ++i) { 94 std::vector<statistics::AveragerWeighted> aver(target_.nof_classes()); 95 for(size_t j=0; j<data_.columns(); ++j) 95 96 aver[target_(j)].add(data.data(i,j), data.weight(i,j)); 96 } 97 else 98 aver[target_(j)].add(data_(i,j),1.0); 99 } 100 assert(centroids_.columns()==target_.nof_classes()); 101 for (size_t j=0; j<target_.nof_classes(); ++j){ 102 assert(i<centroids_.rows()); 103 assert(j<centroids_.columns()); 104 centroids_(i,j) = aver[j].mean(); 105 assert(i<sigma2_.rows()); 106 assert(j<sigma2_.columns()); 107 sigma2_(i,j) = aver[j].variance(); 97 98 assert(centroids_.columns()==target_.nof_classes()); 99 for (size_t j=0; j<target_.nof_classes(); ++j){ 100 assert(i<centroids_.rows()); 101 assert(j<centroids_.columns()); 102 centroids_(i,j) = aver[j].mean(); 103 assert(i<sigma2_.rows()); 104 assert(j<sigma2_.columns()); 105 if (aver[j].variance()) 106 sigma2_(i,j) = aver[j].variance(); 107 else 108 sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN(); 109 } 110 } 111 } 112 else { 113 const MatrixLookup& data = dynamic_cast<const MatrixLookup&>(data_); 114 for(size_t i=0; i<data_.rows(); ++i) { 115 std::vector<statistics::Averager> aver(target_.nof_classes()); 116 for(size_t j=0; j<data_.columns(); ++j) 117 aver[target_(j)].add(data(i,j)); 118 119 assert(centroids_.columns()==target_.nof_classes()); 120 for (size_t j=0; j<target_.nof_classes(); ++j){ 121 assert(i<centroids_.rows()); 122 assert(j<centroids_.columns()); 123 centroids_(i,j) = aver[j].mean(); 124 assert(i<sigma2_.rows()); 125 assert(j<sigma2_.columns()); 126 if (aver[j].variance()) 127 sigma2_(i,j) = aver[j].variance(); 128 else 129 sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN(); 130 } 108 131 } 109 132 } … … 133 156 prediction(label,sample) = sum_log_sigma; 134 157 for (size_t i=0; i<x.rows(); ++i) 135 // taking care of NaN 136 if (mlw>weight(i, label) ) {158 // taking care of NaN and missing training features 159 if (mlw>weight(i, label) && !std::isnan(sigma2_(i, label))) { 137 160 prediction(label, sample) += mlw>weight(i, label)* 138 161 std::pow(mlw>data(i, label)centroids_(i, label),2)/ … … 151 174 prediction(label,sample) = sum_log_sigma; 152 175 for (size_t i=0; i<ml>rows(); ++i) 153 prediction(label, sample) += 154 std::pow((*ml)(i, label)centroids_(i, label),2)/sigma2_(i, label); 176 // Ignoring missing features 177 if (!std::isnan(sigma2_(i, label))) 178 prediction(label, sample) += 179 std::pow((*ml)(i, label)centroids_(i, label),2)/ 180 sigma2_(i, label); 155 181 } 156 182 } … … 190 216 assert(label<sigma2_.columns()); 191 217 for (size_t i=0; i<sigma2_.rows(); ++i) { 192 sum_log_sigma += std::log(sigma2_(i, label)); 218 if (!std::isnan(sigma2_(i,label))) 219 sum_log_sigma += std::log(sigma2_(i, label)); 193 220 } 194 221 return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
Note: See TracChangeset
for help on using the changeset viewer.