Changeset 1157 for trunk/yat/classifier/KNN.h
 Timestamp:
 Feb 26, 2008, 2:25:19 PM (15 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/yat/classifier/KNN.h
r1156 r1157 57 57 public: 58 58 /// 59 /// Constructor taking the training data and the target 60 /// as input. 61 /// 62 KNN(const MatrixLookup&, const Target&); 63 64 65 /// 66 /// Constructor taking the training data with weights and the 67 /// target as input. 68 /// 69 KNN(const MatrixLookupWeighted&, const Target&); 70 59 /// @brief Constructor 60 /// 61 KNN(void); 62 63 64 /// 65 /// @brief Destructor 66 /// 71 67 virtual ~KNN(); 72 68 73 //74 // @return the training data75 //76 const DataLookup2D& data(void) const;77 78 69 79 70 /// … … 85 76 86 77 /// 87 /// @brief sets the number of neighbors, k. If the number of 88 /// training samples set is smaller than \a k_in, k is set to the number of 89 /// training samples. 78 /// @brief sets the number of neighbors, k. 90 79 /// 91 80 void k(u_int k_in); 92 81 93 82 94 KNN<Distance,NeighborWeighting>* make_classifier(const DataLookup2D&, 95 const Target&) const; 96 97 /// 98 /// Train the classifier using the training data. 99 /// This function does nothing but is required by the interface. 100 /// 101 void train(); 83 KNN<Distance,NeighborWeighting>* make_classifier(void) const; 84 85 /// 86 /// Train the classifier using training data and target. 87 /// 88 /// If the number of training samples set is smaller than \a k_in, 89 /// k is set to the number of training samples. 90 /// 91 void train(const MatrixLookup&, const Target&); 92 93 /// 94 /// Train the classifier using weighted training data and target. 95 /// 96 void train(const MatrixLookupWeighted&, const Target&); 102 97 103 98 … … 114 109 // data_ has to be of type DataLookup2D to accomodate both 115 110 // MatrixLookup and MatrixLookupWeighted 116 const DataLookup2D& data_; 111 const DataLookup2D* data_; 112 const Target* target_; 117 113 118 114 // The number of neighbors … … 143 139 144 140 template <typename Distance, typename NeighborWeighting> 145 KNN<Distance, NeighborWeighting>::KNN(const MatrixLookup& data, const Target& target) 146 : SupervisedClassifier(target), data_(data),k_(3) 147 { 148 utility::yat_assert<std::runtime_error> 149 (data.columns()==target.size(), 150 "KNN::KNN called with different sizes of target and data"); 151 // k has to be at most the number of training samples. 152 if(data_.columns()>k_) 153 k_=data_.columns(); 154 } 155 156 157 template <typename Distance, typename NeighborWeighting> 158 KNN<Distance, NeighborWeighting>::KNN 159 (const MatrixLookupWeighted& data, const Target& target) 160 : SupervisedClassifier(target), data_(data),k_(3) 161 { 162 utility::yat_assert<std::runtime_error> 163 (data.columns()==target.size(), 164 "KNN::KNN called with different sizes of target and data"); 165 if(data_.columns()>k_) 166 k_=data_.columns(); 167 } 141 KNN<Distance, NeighborWeighting>::KNN() 142 : SupervisedClassifier(),data_(0),target_(0),k_(3) 143 { 144 } 145 168 146 169 147 template <typename Distance, typename NeighborWeighting> … … 178 156 // matrix with training samples as rows and test samples as columns 179 157 utility::Matrix* distances = 180 new utility::Matrix(data_ .columns(),test.columns());158 new utility::Matrix(data_>columns(),test.columns()); 181 159 182 160 … … 186 164 // unweighted training data 187 165 if(const MatrixLookup* training_unweighted = 188 dynamic_cast<const MatrixLookup*>( &data_))166 dynamic_cast<const MatrixLookup*>(data_)) 189 167 calculate_unweighted(*training_unweighted,*test_unweighted,distances); 190 168 // weighted training data 191 169 else if(const MatrixLookupWeighted* training_weighted = 192 dynamic_cast<const MatrixLookupWeighted*>( &data_))170 dynamic_cast<const MatrixLookupWeighted*>(data_)) 193 171 calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted), 194 172 distances); … … 200 178 // unweighted training data 201 179 if(const MatrixLookup* training_unweighted = 202 dynamic_cast<const MatrixLookup*>( &data_)) {180 dynamic_cast<const MatrixLookup*>(data_)) { 203 181 calculate_weighted(MatrixLookupWeighted(*training_unweighted), 204 182 *test_weighted,distances); … … 206 184 // weighted training data 207 185 else if(const MatrixLookupWeighted* training_weighted = 208 dynamic_cast<const MatrixLookupWeighted*>( &data_))186 dynamic_cast<const MatrixLookupWeighted*>(data_)) 209 187 calculate_weighted(*training_weighted,*test_weighted,distances); 210 188 // Training data can not be of incorrect type … … 252 230 } 253 231 } 254 255 256 template <typename Distance, typename NeighborWeighting>257 const DataLookup2D& KNN<Distance, NeighborWeighting>::data(void) const258 {259 return data_;260 }261 232 262 233 … … 271 242 { 272 243 k_=k; 273 if(k_>data_.columns())274 k_=data_.columns();275 244 } 276 245 … … 278 247 template <typename Distance, typename NeighborWeighting> 279 248 KNN<Distance, NeighborWeighting>* 280 KNN<Distance, NeighborWeighting>::make_classifier(const DataLookup2D& data, 281 const Target& target) const 249 KNN<Distance, NeighborWeighting>::make_classifier() const 282 250 { 283 KNN* knn=0; 284 try { 285 if(data.weighted()) { 286 knn=new KNN<Distance, NeighborWeighting> 287 (dynamic_cast<const MatrixLookupWeighted&>(data),target); 288 } 289 else { 290 knn=new KNN<Distance, NeighborWeighting> 291 (dynamic_cast<const MatrixLookup&>(data),target); 292 } 293 knn>k(this>k()); 294 } 295 catch (std::bad_cast) { 296 std::string str = "Error in KNN<Distance, NeighborWeighting>"; 297 str += "::make_classifier: DataLookup2D of unexpected class."; 298 throw std::runtime_error(str); 299 } 251 KNN* knn=new KNN<Distance, NeighborWeighting>(); 252 knn>k(this>k()); 300 253 return knn; 301 254 } … … 303 256 304 257 template <typename Distance, typename NeighborWeighting> 305 void KNN<Distance, NeighborWeighting>::train() 258 void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 259 const Target& target) 306 260 { 261 utility::yat_assert<std::runtime_error> 262 (data.columns()==target.size(), 263 "KNN::train called with different sizes of target and data"); 264 // k has to be at most the number of training samples. 265 if(data.columns()<k_) 266 k_=data.columns(); 267 data_=&data; 268 target_=⌖ 269 trained_=true; 270 } 271 272 template <typename Distance, typename NeighborWeighting> 273 void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 274 const Target& target) 275 { 276 utility::yat_assert<std::runtime_error> 277 (data.columns()==target.size(), 278 "KNN::train called with different sizes of target and data"); 279 // k has to be at most the number of training samples. 280 if(data.columns()<k_) 281 k_=data.columns(); 282 data_=&data; 283 target_=⌖ 307 284 trained_=true; 308 285 } … … 313 290 utility::Matrix& prediction) const 314 291 { 315 utility::yat_assert<std::runtime_error>(data_ .rows()==test.rows(),"KNN::predict different number of rows in training and test data");292 utility::yat_assert<std::runtime_error>(data_>rows()==test.rows(),"KNN::predict different number of rows in training and test data"); 316 293 317 294 utility::Matrix* distances=calculate_distances(test); 318 295 319 prediction.resize(target_ .nof_classes(),test.columns(),0.0);296 prediction.resize(target_>nof_classes(),test.columns(),0.0); 320 297 for(size_t sample=0;sample<distances>columns();sample++) { 321 298 std::vector<size_t> k_index; … … 323 300 utility::sort_smallest_index(k_index,k_,dist); 324 301 utility::VectorView pred=prediction.column_view(sample); 325 weighting_(dist,k_index, target_,pred);302 weighting_(dist,k_index,*target_,pred); 326 303 } 327 304 delete distances; … … 329 306 // classes for which there are no training samples should be set 330 307 // to nan in the predictions 331 for(size_t c=0;c<target_ .nof_classes(); c++)332 if(!target_ .size(c))308 for(size_t c=0;c<target_>nof_classes(); c++) 309 if(!target_>size(c)) 333 310 for(size_t j=0;j<prediction.columns();j++) 334 311 prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
Note: See TracChangeset
for help on using the changeset viewer.