Changeset 1157 for trunk/yat/classifier/NBC.cc
 Timestamp:
 Feb 26, 2008, 2:25:19 PM (15 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/yat/classifier/NBC.cc
r1144 r1157 40 40 namespace classifier { 41 41 42 NBC::NBC( const MatrixLookup& data, const Target& target)43 : SupervisedClassifier( target), data_(data)42 NBC::NBC() 43 : SupervisedClassifier() 44 44 { 45 45 } 46 46 47 NBC::NBC(const MatrixLookupWeighted& data, const Target& target)48 : SupervisedClassifier(target), data_(data)49 {50 }51 47 52 48 NBC::~NBC() … … 55 51 56 52 57 const DataLookup2D& NBC::data(void) const 58 { 59 return data_; 60 } 61 62 63 NBC* 64 NBC::make_classifier(const DataLookup2D& data, const Target& target) const 53 NBC* NBC::make_classifier() const 65 54 { 66 NBC* nbc=0; 67 try { 68 if(data.weighted()) { 69 nbc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target); 70 } 71 else { 72 nbc=new NBC(dynamic_cast<const MatrixLookup&>(data),target); 73 } 74 } 75 catch (std::bad_cast) { 76 std::string str = 77 "Error in NBC::make_classifier: DataLookup2D of unexpected class."; 78 throw std::runtime_error(str); 79 } 80 return nbc; 81 } 82 83 84 void NBC::train() 55 return new NBC(); 56 } 57 58 59 void NBC::train(const MatrixLookup& data, const Target& target) 85 60 { 86 sigma2_.resize(data _.rows(), target_.nof_classes());87 centroids_.resize(data _.rows(), target_.nof_classes());88 utility::Matrix nof_in_class(data _.rows(), target_.nof_classes());61 sigma2_.resize(data.rows(), target.nof_classes()); 62 centroids_.resize(data.rows(), target.nof_classes()); 63 utility::Matrix nof_in_class(data.rows(), target.nof_classes()); 89 64 90 // unweighted 91 if (data_.weighted()){ 92 const MatrixLookupWeighted& data = 93 dynamic_cast<const MatrixLookupWeighted&>(data_); 94 for(size_t i=0; i<data_.rows(); ++i) { 95 std::vector<statistics::AveragerWeighted> aver(target_.nof_classes()); 96 for(size_t j=0; j<data_.columns(); ++j) 97 aver[target_(j)].add(data.data(i,j), data.weight(i,j)); 98 99 assert(centroids_.columns()==target_.nof_classes()); 100 for (size_t j=0; j<target_.nof_classes(); ++j){ 101 assert(i<centroids_.rows()); 102 assert(j<centroids_.columns()); 103 assert(i<sigma2_.rows()); 104 assert(j<sigma2_.columns()); 105 if (aver[j].n()>1){ 106 sigma2_(i,j) = aver[j].variance(); 107 centroids_(i,j) = aver[j].mean(); 108 } 65 for(size_t i=0; i<data.rows(); ++i) { 66 std::vector<statistics::Averager> aver(target.nof_classes()); 67 for(size_t j=0; j<data.columns(); ++j) 68 aver[target(j)].add(data(i,j)); 69 70 assert(centroids_.columns()==target.nof_classes()); 71 for (size_t j=0; j<target.nof_classes(); ++j){ 72 assert(i<centroids_.rows()); 73 assert(j<centroids_.columns()); 74 centroids_(i,j) = aver[j].mean(); 75 assert(i<sigma2_.rows()); 76 assert(j<sigma2_.columns()); 77 if (aver[j].n()>1){ 78 sigma2_(i,j) = aver[j].variance(); 79 centroids_(i,j) = aver[j].mean(); 80 } 109 81 else { 110 82 sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN(); 111 83 centroids_(i,j) = std::numeric_limits<double>::quiet_NaN(); 112 84 } 113 } 114 } 115 } 116 else { 117 const MatrixLookup& data = dynamic_cast<const MatrixLookup&>(data_); 118 for(size_t i=0; i<data_.rows(); ++i) { 119 std::vector<statistics::Averager> aver(target_.nof_classes()); 120 for(size_t j=0; j<data_.columns(); ++j) 121 aver[target_(j)].add(data(i,j)); 122 123 assert(centroids_.columns()==target_.nof_classes()); 124 for (size_t j=0; j<target_.nof_classes(); ++j){ 125 assert(i<centroids_.rows()); 126 assert(j<centroids_.columns()); 85 } 86 } 87 trained_=true; 88 } 89 90 91 void NBC::train(const MatrixLookupWeighted& data, const Target& target) 92 { 93 sigma2_.resize(data.rows(), target.nof_classes()); 94 centroids_.resize(data.rows(), target.nof_classes()); 95 utility::Matrix nof_in_class(data.rows(), target.nof_classes()); 96 97 for(size_t i=0; i<data.rows(); ++i) { 98 std::vector<statistics::AveragerWeighted> aver(target.nof_classes()); 99 for(size_t j=0; j<data.columns(); ++j) 100 aver[target(j)].add(data.data(i,j), data.weight(i,j)); 101 102 assert(centroids_.columns()==target.nof_classes()); 103 for (size_t j=0; j<target.nof_classes(); ++j) { 104 assert(i<centroids_.rows()); 105 assert(j<centroids_.columns()); 106 assert(i<sigma2_.rows()); 107 assert(j<sigma2_.columns()); 108 if (aver[j].n()>1){ 109 sigma2_(i,j) = aver[j].variance(); 127 110 centroids_(i,j) = aver[j].mean(); 128 assert(i<sigma2_.rows()); 129 assert(j<sigma2_.columns()); 130 if (aver[j].n()>1){ 131 sigma2_(i,j) = aver[j].variance(); 132 centroids_(i,j) = aver[j].mean(); 133 } 134 else { 135 sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN(); 136 centroids_(i,j) = std::numeric_limits<double>::quiet_NaN(); 137 } 138 } 139 } 140 } 111 } 112 else { 113 sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN(); 114 centroids_(i,j) = std::numeric_limits<double>::quiet_NaN(); 115 } 116 } 117 } 141 118 trained_=true; 142 119 } … … 146 123 utility::Matrix& prediction) const 147 124 { 148 assert(data_.rows()==x.rows());149 125 assert(x.rows()==sigma2_.rows()); 150 126 assert(x.rows()==centroids_.rows());
Note: See TracChangeset
for help on using the changeset viewer.