Changeset 1033 for trunk/yat/classifier
- Timestamp:
- Feb 5, 2008, 12:12:12 PM (16 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/yat/classifier/NCC.h
r1031 r1033 107 107 private: 108 108 109 void predict_unweighted(const MatrixLookup&, utility::matrix&) const; 110 void predict_weighted(const MatrixLookupWeighted&, utility::matrix&) const; 111 109 112 utility::matrix* centroids_; 113 bool centroids_nan_; 110 114 111 115 // data_ has to be of type DataLookup2D to accomodate both 112 116 // MatrixLookup and MatrixLookupWeighted 113 117 const DataLookup2D& data_; 114 bool centroids_nan_;115 118 }; 116 119 … … 125 128 template <typename Distance> 126 129 NCC<Distance>::NCC(const MatrixLookup& data, const Target& target) 127 : SupervisedClassifier(target), centroids_(0), data_(data), centroids_nan_(false)130 : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data) 128 131 { 129 132 } … … 131 134 template <typename Distance> 132 135 NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target) 133 : SupervisedClassifier(target), centroids_(0), data_(data), centroids_nan_(false)136 : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data) 134 137 { 135 138 } … … 231 234 prediction.clone(utility::matrix(centroids_->columns(), test.columns())); 232 235 233 // unweighted test data and no nan's in centroids 234 // Markus: Should test centroid_nan_ here!!! 236 // unweighted test data 235 237 if (const MatrixLookup* test_unweighted = 236 238 dynamic_cast<const MatrixLookup*>(&test)) { 237 MatrixLookup unweighted_centroids(*centroids_);238 for(size_t j=0; j<test.columns();j++) {239 DataLookup1D in(*test_unweighted,j,false);240 for(size_t k=0; k<centroids_->columns();k++) {241 DataLookup1D centroid(unweighted_centroids,k,false);242 utility::yat_assert<std::runtime_error>(in.size()==centroid.size());243 prediction(k,j)=statistics::244 distance(in.begin(),in.end(),centroid.begin(),245 typename statistics::distance_traits<Distance>::distance());246 }247 } 248 } 249 // weighted test data 239 // If weighted training data resulting in NaN in centroids: weighted calculations 240 if(centroids_nan_) { 241 // predict_weighted(MatrixLookupWeighted(*test_unweighted),prediction); 242 std::string str = 243 "Error in NCC<Distance>::predict: weighted training unweighted test not implemented yet"; 244 throw std::runtime_error(str); 245 } 246 // If unweighted training data: unweighted calculations 247 else { 248 predict_unweighted(*test_unweighted,prediction); 249 } 250 } 251 // weighted test data: weighted calculations 250 252 else if (const MatrixLookupWeighted* test_weighted = 251 dynamic_cast<const MatrixLookupWeighted*>(&test)) { 252 MatrixLookupWeighted weighted_centroids(*centroids_); 253 for(size_t j=0; j<test.columns();j++) { 254 DataLookupWeighted1D in(*test_weighted,j,false); 255 for(size_t k=0; k<centroids_->columns();k++) { 256 DataLookupWeighted1D centroid(weighted_centroids,k,false); 257 utility::yat_assert<std::runtime_error>(in.size()==centroid.size()); 258 prediction(k,j)=statistics:: 259 distance(in.begin(),in.end(),centroid.begin(), 260 typename statistics::distance_traits<Distance>::distance()); 261 } 262 } 253 dynamic_cast<const MatrixLookupWeighted*>(&test)) { 254 predict_weighted(*test_weighted,prediction); 263 255 } 264 256 else { … … 268 260 } 269 261 } 262 263 template <typename Distance> 264 void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 265 utility::matrix& prediction) const 266 { 267 MatrixLookup unweighted_centroids(*centroids_); 268 for(size_t j=0; j<test.columns();j++) { 269 DataLookup1D in(test,j,false); 270 for(size_t k=0; k<centroids_->columns();k++) { 271 DataLookup1D centroid(unweighted_centroids,k,false); 272 utility::yat_assert<std::runtime_error>(in.size()==centroid.size()); 273 prediction(k,j)=statistics:: 274 distance(in.begin(),in.end(),centroid.begin(), 275 typename statistics::distance_traits<Distance>::distance()); 276 } 277 } 278 } 279 280 template <typename Distance> 281 void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 282 utility::matrix& prediction) const 283 { 284 MatrixLookupWeighted weighted_centroids(*centroids_); 285 for(size_t j=0; j<test.columns();j++) { 286 DataLookupWeighted1D in(test,j,false); 287 for(size_t k=0; k<centroids_->columns();k++) { 288 DataLookupWeighted1D centroid(weighted_centroids,k,false); 289 utility::yat_assert<std::runtime_error>(in.size()==centroid.size()); 290 prediction(k,j)=statistics:: 291 distance(in.begin(),in.end(),centroid.begin(), 292 typename statistics::distance_traits<Distance>::distance()); 293 } 294 } 295 } 296 270 297 271 298 }}} // of namespace classifier, yat, and theplu
Note: See TracChangeset
for help on using the changeset viewer.