Changeset 1007 for trunk/yat/classifier
- Timestamp:
- Jan 29, 2008, 10:53:23 AM (16 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/yat/classifier/NCC.h
r1000 r1007 220 220 221 221 template <typename Distance> 222 void NCC<Distance>::predict(const DataLookup2D& input,222 void NCC<Distance>::predict(const DataLookup2D& test, 223 223 utility::matrix& prediction) const 224 224 { 225 prediction.clone(utility::matrix(centroids_->columns(), input.columns())); 226 // If both training and test are unweighted: unweighted 227 // calculations are used 228 const MatrixLookup* test_unweighted = 229 dynamic_cast<const MatrixLookup*>(&input); 230 if (test_unweighted && !data_.weighted()) { 225 utility::yat_assert<std::runtime_error>(data_.rows()==test.rows()); 226 utility::yat_assert<std::runtime_error>(test.rows()==centroids_->rows()); 227 228 prediction.clone(utility::matrix(centroids_->columns(), test.columns())); 229 230 // unweighted test data 231 if (const MatrixLookup* test_unweighted = 232 dynamic_cast<const MatrixLookup*>(&test)) { 231 233 MatrixLookup unweighted_centroids(*centroids_); 232 for(size_t j=0; j< input.columns();j++) {234 for(size_t j=0; j<test.columns();j++) { 233 235 DataLookup1D in(*test_unweighted,j,false); 234 236 for(size_t k=0; k<centroids_->columns();k++) { … … 241 243 } 242 244 } 243 // if either training or test is weighted: weighted 244 // calculations are used 245 else { 246 const MatrixLookupWeighted* test_weighted = 247 dynamic_cast<const MatrixLookupWeighted*>(&input); 248 if(test_weighted) { 249 MatrixLookupWeighted weighted_centroids(*centroids_); 250 for(size_t j=0; j<input.columns();j++) { 251 DataLookupWeighted1D in(*test_weighted,j,false); 252 for(size_t k=0; k<centroids_->columns();k++) { 253 DataLookupWeighted1D centroid(weighted_centroids,k,false); 254 utility::yat_assert<std::runtime_error>(in.size()==centroid.size()); 255 prediction(k,j)=statistics:: 256 vector_distance(in.begin(),in.end(),centroid.begin(), 257 typename statistics::vector_distance_traits<Distance>::distance()); 258 } 259 } 260 } 261 else if(data_.weighted() && test_unweighted) { 262 std::string str = "Error in NCC<Distance>::predict:"; 263 str += " predicting unweighted data when NCC"; 264 str += " is trained on weighted data is not yet supported"; 265 throw std::runtime_error(str); 266 } 267 else { 268 std::string str = 269 "Error in NCC<Distance>::predict: DataLookup2D of unexpected class."; 270 throw std::runtime_error(str); 271 } 245 // weighted test data 246 else if (const MatrixLookupWeighted* test_weighted = 247 dynamic_cast<const MatrixLookupWeighted*>(&test)) { 248 MatrixLookupWeighted weighted_centroids(*centroids_); 249 for(size_t j=0; j<test.columns();j++) { 250 DataLookupWeighted1D in(*test_weighted,j,false); 251 for(size_t k=0; k<centroids_->columns();k++) { 252 DataLookupWeighted1D centroid(weighted_centroids,k,false); 253 utility::yat_assert<std::runtime_error>(in.size()==centroid.size()); 254 prediction(k,j)=statistics:: 255 vector_distance(in.begin(),in.end(),centroid.begin(), 256 typename statistics::vector_distance_traits<Distance>::distance()); 257 } 258 } 259 } 260 else { 261 std::string str = 262 "Error in NCC<Distance>::predict: DataLookup2D of unexpected class."; 263 throw std::runtime_error(str); 272 264 } 273 265 }
Note: See TracChangeset
for help on using the changeset viewer.