Changeset 931 for trunk/yat/classifier
- Timestamp:
- Oct 5, 2007, 5:42:25 PM (16 years ago)
- Location:
- trunk/yat/classifier
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/yat/classifier/KNN.h
r916 r931 168 168 { 169 169 KNN* knn=0; 170 if(data.weighted()) { 171 knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data), 172 target); 173 } 174 knn->k(this->k()); 170 try { 171 if(data.weighted()) { 172 knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data), 173 target); 174 } 175 knn->k(this->k()); 176 } 177 catch (std::bad_cast) { 178 std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class."; 179 throw std::runtime_error(str); 180 } 175 181 return knn; 176 182 } -
trunk/yat/classifier/NCC.h
r930 r931 35 35 #include "Target.h" 36 36 37 #include "yat/statistics/Averager.h" 38 #include "yat/statistics/AveragerWeighted.h" 37 39 #include "yat/statistics/vector_distance.h" 38 40 … … 158 160 { 159 161 NCC* ncc=0; 160 if(data.weighted()) { 161 ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookupWeighted*>(&data), 162 target); 163 } 164 else { 165 ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookup*>(&data), 166 target); 167 } 168 ncc->centroids_=0; 162 try { 163 if(data.weighted()) { 164 ncc=new NCC<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data), 165 target); 166 } 167 else { 168 ncc=new NCC<Distance>(dynamic_cast<const MatrixLookup&>(data), 169 target); 170 } 171 ncc->centroids_=0; 172 } 173 catch (std::bad_cast) { 174 std::string str = "Error in NCC<Distance>::make_classifier: DataLookup2D of unexpected class."; 175 throw std::runtime_error(str); 176 } 169 177 return ncc; 170 178 } … … 177 185 delete centroids_; 178 186 centroids_= new utility::matrix(data_.rows(), target_.nof_classes()); 179 utility::matrix nof_in_class(data_.rows(), target_.nof_classes()); 180 const MatrixLookupWeighted* weighted_data = 181 dynamic_cast<const MatrixLookupWeighted*>(&data_); 182 bool weighted = weighted_data; 183 184 for(size_t i=0; i<data_.rows(); i++) { 185 for(size_t j=0; j<data_.columns(); j++) { 186 (*centroids_)(i,target_(j)) += data_(i,j); 187 if (weighted) 188 nof_in_class(i,target_(j))+= weighted_data->weight(i,j); 189 else 190 nof_in_class(i,target_(j))+=1.0; 191 } 192 } 193 centroids_->div(nof_in_class); 187 // data_ is a MatrixLookup or a MatrixLookupWeighted 188 if(data_.weighted()) { 189 const MatrixLookupWeighted* weighted_data = 190 dynamic_cast<const MatrixLookupWeighted*>(&data_); 191 for(size_t i=0; i<data_.rows(); i++) { 192 std::vector<statistics::AveragerWeighted> class_averager; 193 class_averager.resize(target_.nof_classes()); 194 for(size_t j=0; j<data_.columns(); j++) { 195 class_averager[target_(j)].add((*weighted_data)(i,j), 196 weighted_data->weight(i,j)); 197 } 198 for(size_t c=0;c<target_.nof_classes();c++) { 199 (*centroids_)(i,c) = class_averager[c].mean(); 200 } 201 } 202 } 203 else { 204 const MatrixLookup* unweighted_data = 205 dynamic_cast<const MatrixLookup*>(&data_); 206 for(size_t i=0; i<data_.rows(); i++) { 207 std::vector<statistics::Averager> class_averager; 208 class_averager.resize(target_.nof_classes()); 209 for(size_t j=0; j<data_.columns(); j++) { 210 class_averager[target_(j)].add((*unweighted_data)(i,j)); 211 } 212 for(size_t c=0;c<target_.nof_classes();c++) { 213 (*centroids_)(i,c) = class_averager[c].mean(); 214 } 215 } 216 } 194 217 trained_=true; 195 218 return trained_; … … 200 223 utility::matrix& prediction) const 201 224 { 202 prediction.clone(utility::matrix(centroids_->columns(), input.columns())); 203 204 // Weighted case205 const MatrixLookup Weighted* testdata=206 dynamic_cast<const MatrixLookup Weighted*>(&input);207 if (test data) {208 MatrixLookup Weightedweighted_centroids(*centroids_);225 prediction.clone(utility::matrix(centroids_->columns(), input.columns())); 226 // If both training and test are unweighted: unweighted 227 // calculations are used 228 const MatrixLookup* test_unweighted = 229 dynamic_cast<const MatrixLookup*>(&input); 230 if (test_unweighted && !data_.weighted()) { 231 MatrixLookup unweighted_centroids(*centroids_); 209 232 for(size_t j=0; j<input.columns();j++) { 210 DataLookup Weighted1D in(*testdata,j,false);233 DataLookup1D in(*test_unweighted,j,false); 211 234 for(size_t k=0; k<centroids_->columns();k++) { 212 DataLookup Weighted1D centroid(weighted_centroids,k,false);235 DataLookup1D centroid(unweighted_centroids,k,false); 213 236 yat_assert(in.size()==centroid.size()); 214 237 prediction(k,j)=statistics:: … … 218 241 } 219 242 } 220 // Non-weighted case 221 else { 222 const MatrixLookup* testdata = 223 dynamic_cast<const MatrixLookup*>(&input); 224 if (testdata) { 225 MatrixLookup unweighted_centroids(*centroids_); 243 // if either training or test is weighted: weighted 244 // calculations are used 245 else { 246 const MatrixLookupWeighted* test_weighted = 247 dynamic_cast<const MatrixLookupWeighted*>(&input); 248 MatrixLookupWeighted weighted_centroids(*centroids_); 249 if(test_weighted) { 226 250 for(size_t j=0; j<input.columns();j++) { 227 DataLookup 1D in(*testdata,j,false);251 DataLookupWeighted1D in(*test_weighted,j,false); 228 252 for(size_t k=0; k<centroids_->columns();k++) { 229 DataLookup 1D centroid(unweighted_centroids,k,false);253 DataLookupWeighted1D centroid(weighted_centroids,k,false); 230 254 yat_assert(in.size()==centroid.size()); 231 255 prediction(k,j)=statistics:: … … 234 258 } 235 259 } 236 } 260 } 261 else if(data_.weighted() && test_unweighted) { 262 // MatrixLookupWeighted test2weighted(*test_unweighted); 263 // Need to convert MatrixLookup to MatrixLookupWeighted here 264 // and use it in the code below 265 for(size_t j=0; j<input.columns();j++) { 266 DataLookupWeighted1D in(*test_weighted,j,false); 267 for(size_t k=0; k<centroids_->columns();k++) { 268 DataLookupWeighted1D centroid(weighted_centroids,k,false); 269 yat_assert(in.size()==centroid.size()); 270 prediction(k,j)=statistics:: 271 vector_distance(in.begin(),in.end(),centroid.begin(), 272 typename statistics::vector_distance_traits<Distance>::distance()); 273 } 274 } 275 } 237 276 else { 238 277 std::string str;
Note: See TracChangeset
for help on using the changeset viewer.