Changeset 3552 for trunk/yat/classifier/KNN.h
Timestamp: Jan 3, 2017, 8:48:46 AM
File: trunk/yat/classifier/KNN.h (1 edited)
trunk/yat/classifier/KNN.h
r2384 → r3552 (in the fragments shown, the two revisions differ only in whitespace)

#ifndef _theplu_yat_classifier_knn_
#define _theplu_yat_classifier_knn_

// $Id$
…
  /**
     \brief Nearest Neighbor Classifier

     A sample is predicted based on the classes of its k nearest
     neighbors among the training data samples. KNN supports using
…
     uniform vote a test sample gets a vote for each class which is the
     number of nearest neighbors belonging to the class.

     The template argument Distance should be a class modelling the
     concept \ref concept_distance. The template argument
…
  class KNN : public SupervisedClassifier
  {

  public:
    /**
       \brief Default constructor.

       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
…
    /**
       \brief Constructor using an initialized distance measure.

       The number of nearest neighbors (k) is set to
       3. NeighborWeighting is initialized using its default
…
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */
    KNN(const Distance&);

…
    */
    virtual ~KNN();


    /**
       \brief Get the number of nearest neighbors.
…
    /**
       \brief Set the number of nearest neighbors.

       Sets the number of neighbors to \a k_in.
    */
    void k(unsigned int k_in);
…

    KNN<Distance,NeighborWeighting>* make_classifier(void) const;

    /**
       \brief Make predictions for unweighted test data.

       Predictions are calculated and returned in \a results. For
       each sample in \a data, \a results contains the weighted number
…
    void predict(const MatrixLookup& data, utility::Matrix& results) const;

    /**
       \brief Make predictions for weighted test data.

       Predictions are calculated and returned in \a results. For
       each sample in \a data, \a results contains the weighted
…
       this case the distance between the two is set to infinity.
    */
    void predict(const MatrixLookupWeighted& data,
                 utility::Matrix& results) const;

…
    /**
       \brief Train the KNN using unweighted training data with known
       targets.

       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
…
       slow. If the number of training samples set is smaller than k,
       k is set to the number of training samples.

       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
…
    */
    void train(const MatrixLookup& data, const Target& targets);

    /**
       \brief Train the KNN using weighted training data with known targets.

       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);

  private:

    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
…
                            utility::Matrix*) const;

    void predict_common(const utility::Matrix& distances,
                        utility::Matrix& prediction) const;

  };


  /**
     \brief Concept check for a \ref concept_neighbor_weighting

     This class is intended to be used in a <a
     href="\boost_url/concept_check/using_concept_check.htm">
     BOOST_CONCEPT_ASSERT </a>
…
  */
  template <class T>
  class NeighborWeightingConcept
    : public boost::DefaultConstructible<T>, public boost::Assignable<T>
  {
…

  // template implementation

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN()
    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
  {
…

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
    : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3),
      distance_(dist)
  {
…
  }


  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::~KNN()
  {
  }


  template <typename Distance, typename NeighborWeighting>
…
    for(size_t i=0; i<training.columns(); i++) {
      for(size_t j=0; j<test.columns(); j++) {
        (*distances)(i,j) = distance_(training.begin_column(i),
                                      training.end_column(i),
                                      test.begin_column(j));
        YAT_ASSERT(!std::isnan((*distances)(i,j)));
…
  }


  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::calculate_weighted
  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
   utility::Matrix* distances) const
  {
    for(size_t i=0; i<training.columns(); i++) {
      for(size_t j=0; j<test.columns(); j++) {
        (*distances)(i,j) = distance_(training.begin_column(i),
                                      training.end_column(i),
                                      test.begin_column(j));
        // If the distance is NaN (no common variables with non-zero weights),
        // the distance is set to infinity to be sorted as a neighbor at the end
        if(std::isnan((*distances)(i,j)))
          (*distances)(i,j)=std::numeric_limits<double>::infinity();
      }
    }
  }


  template <typename Distance, typename NeighborWeighting>
  unsigned int KNN<Distance, NeighborWeighting>::k() const
…

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>*
  KNN<Distance, NeighborWeighting>::make_classifier() const
  {
    // All private members should be copied here to generate an
    // identical but untrained classifier
…
    return knn;
  }


  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
                                               const Target& target)
  {
    utility::yat_assert<utility::runtime_error>
      (data.columns()==target.size(),
       "KNN::train called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data.columns()<k_)
      k_=data.columns();
    data_ml_=&data;
…

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
                                               const Target& target)
  {
    utility::yat_assert<utility::runtime_error>
      (data.columns()==target.size(),
       "KNN::train called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data.columns()<k_)
      k_=data.columns();
    data_ml_=0;
…

  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
                                            utility::Matrix& prediction) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::Matrix* distances = 0;
…
      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
                         distances);
    }
    else {
…

  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
                                            utility::Matrix& prediction) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::Matrix* distances=0;
    // unweighted training data
    if(data_ml_ && !data_mlw_) {
      utility::yat_assert<utility::runtime_error>
        (data_ml_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_ml_->columns(),test.columns());
      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
    }
    // weighted training data
…
      utility::yat_assert<utility::runtime_error>
        (data_mlw_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
      calculate_weighted(*data_mlw_,test,distances);
    }
    else {
…
    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    predict_common(*distances,prediction);

    if(distances)
      delete distances;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::predict_common
  (const utility::Matrix& distances, utility::Matrix& prediction) const
  {
    for(size_t sample=0;sample<distances.columns();sample++) {
      std::vector<size_t> k_index;
…
      weighting_(dist,k_index,*target_,pred);
    }

    // classes for which there are no training samples should be set
    // to nan in the predictions
    for(size_t c=0;c<target_->nof_classes(); c++)
      if(!target_->size(c))
        for(size_t j=0;j<prediction.columns();j++)
          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
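To illustrate the interface declared above, here is a minimal usage sketch. It is not part of the changeset: the distance class (statistics::EuclideanDistance), the neighbor weighting class (classifier::KNN_Uniform), the header paths, and the exact constructor signatures of Target and MatrixLookup are assumed from other parts of yat and may differ between versions.

// Minimal usage sketch for the KNN class shown in this changeset.
// Assumptions: EuclideanDistance and KNN_Uniform exist elsewhere in yat
// under the header paths below; adjust to the installed version.
#include "yat/classifier/KNN.h"
#include "yat/classifier/KNN_Uniform.h"
#include "yat/classifier/MatrixLookup.h"
#include "yat/classifier/Target.h"
#include "yat/statistics/EuclideanDistance.h"
#include "yat/utility/Matrix.h"

#include <string>
#include <vector>

int main()
{
  using namespace theplu::yat;

  // training data: 2 variables (rows) x 4 samples (columns)
  utility::Matrix x(2, 4);
  x(0,0)=0.0; x(1,0)=0.1;   // sample 0, class "A"
  x(0,1)=0.2; x(1,1)=0.0;   // sample 1, class "A"
  x(0,2)=1.0; x(1,2)=0.9;   // sample 2, class "B"
  x(0,3)=0.8; x(1,3)=1.1;   // sample 3, class "B"
  classifier::MatrixLookup training(x);

  std::vector<std::string> labels;
  labels.push_back("A"); labels.push_back("A");
  labels.push_back("B"); labels.push_back("B");
  classifier::Target targets(labels);

  // k defaults to 3; set it explicitly to show the mutator
  classifier::KNN<statistics::EuclideanDistance, classifier::KNN_Uniform> knn;
  knn.k(3);

  // train() only stores references, so x, training, and targets must
  // outlive any use of knn (see the \note in the class documentation)
  knn.train(training, targets);

  // one test sample with the same number of rows (variables)
  utility::Matrix y(2, 1);
  y(0,0)=0.1; y(1,0)=0.2;
  classifier::MatrixLookup test(y);

  // predict() resizes results to (number of classes) x (test samples);
  // each entry is the (weighted) number of nearest neighbors of that class
  utility::Matrix results;
  knn.predict(test, results);

  return 0;
}

With a uniform weighting such as KNN_Uniform, each entry of results is simply the count of nearest neighbors belonging to that class; other weightings (for example reciprocal-distance schemes) produce weighted votes instead.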