Changeset 1112 for trunk/yat/classifier/KNN.h
 Timestamp:
 Feb 21, 2008, 3:59:30 PM (15 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/yat/classifier/KNN.h
r1107 r1112 27 27 #include "DataLookup1D.h" 28 28 #include "DataLookupWeighted1D.h" 29 #include "KNN_Uniform.h" 29 30 #include "MatrixLookup.h" 30 31 #include "MatrixLookupWeighted.h" … … 43 44 44 45 /// 45 /// @brief Class for Nearest Centroid Classification. 46 /// 47 48 49 template <typename Distance> 46 /// @brief Class for Nearest Neigbor Classification. 47 /// 48 /// The template argument Distance should be a class implementing 49 /// the concept \ref concept_distance. 50 /// The template argument NeigborWeighting should be a class implementing 51 /// the concept \ref concept_neighbor_weighting. 52 53 template <typename Distance, typename NeighborWeighting=KNN_Uniform> 50 54 class KNN : public SupervisedClassifier 51 55 { … … 74 78 75 79 /// 76 /// Default number of neighbo urs (k) is set to 3.77 /// 78 /// @return the number of neighbo urs80 /// Default number of neighbors (k) is set to 3. 81 /// 82 /// @return the number of neighbors 79 83 /// 80 84 u_int k() const; 81 85 82 86 /// 83 /// @brief sets the number of neighbo urs, k.87 /// @brief sets the number of neighbors, k. 84 88 /// 85 89 void k(u_int); … … 99 103 100 104 /// 101 /// For each sample, calculate the number of neighbo urs for each105 /// For each sample, calculate the number of neighbors for each 102 106 /// class. 103 107 /// … … 112 116 const DataLookup2D& data_; 113 117 114 // The number of neighbo urs118 // The number of neighbors 115 119 u_int k_; 116 120 117 121 Distance distance_; 122 123 NeighborWeighting weighting_; 124 118 125 /// 119 126 /// Calculates the distances between a data set and the training … … 123 130 /// 124 131 utility::matrix* calculate_distances(const DataLookup2D&) const; 132 125 133 void calculate_unweighted(const MatrixLookup&, 126 134 const MatrixLookup&, … … 134 142 // templates 135 143 136 template <typename Distance >137 KNN<Distance >::KNN(const MatrixLookup& data, const Target& target)144 template <typename Distance, typename NeighborWeighting> 145 KNN<Distance, NeighborWeighting>::KNN(const MatrixLookup& data, const Target& target) 138 146 : SupervisedClassifier(target), data_(data),k_(3) 139 147 { … … 141 149 142 150 143 template <typename Distance> 144 KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) 151 template <typename Distance, typename NeighborWeighting> 152 KNN<Distance, NeighborWeighting>::KNN 153 (const MatrixLookupWeighted& data, const Target& target) 145 154 : SupervisedClassifier(target), data_(data),k_(3) 146 155 { 147 156 } 148 157 149 template <typename Distance> 150 KNN<Distance>::~KNN() 151 { 152 } 153 154 template <typename Distance> 155 utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& test) const 158 template <typename Distance, typename NeighborWeighting> 159 KNN<Distance, NeighborWeighting>::~KNN() 160 { 161 } 162 163 template <typename Distance, typename NeighborWeighting> 164 utility::matrix* KNN<Distance, NeighborWeighting>::calculate_distances 165 (const DataLookup2D& test) const 156 166 { 157 167 // matrix with training samples as rows and test samples as columns … … 197 207 } 198 208 199 template <typename Distance >200 void KNN<Distance >:: calculate_unweighted(const MatrixLookup& training,201 202 209 template <typename Distance, typename NeighborWeighting> 210 void KNN<Distance, NeighborWeighting>::calculate_unweighted 211 (const MatrixLookup& training, const MatrixLookup& test, 212 utility::matrix* distances) const 203 213 { 204 214 for(size_t i=0; i<training.columns(); i++) { … … 212 222 } 213 223 214 template <typename Distance> 215 void KNN<Distance>:: calculate_weighted(const MatrixLookupWeighted& training, 216 const MatrixLookupWeighted& test, 217 utility::matrix* distances) const 224 template <typename Distance, typename NeighborWeighting> 225 void 226 KNN<Distance, NeighborWeighting>::calculate_weighted 227 (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test, 228 utility::matrix* distances) const 218 229 { 219 230 for(size_t i=0; i<training.columns(); i++) { … … 221 232 for(size_t j=0; j<test.columns(); j++) { 222 233 classifier::DataLookupWeighted1D test1(test,j,false); 223 (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin()); 234 (*distances)(i,j) = distance_(training1.begin(), training1.end(), 235 test1.begin()); 224 236 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 225 237 } … … 228 240 229 241 230 template <typename Distance >231 const DataLookup2D& KNN<Distance >::data(void) const242 template <typename Distance, typename NeighborWeighting> 243 const DataLookup2D& KNN<Distance, NeighborWeighting>::data(void) const 232 244 { 233 245 return data_; … … 235 247 236 248 237 template <typename Distance >238 u_int KNN<Distance >::k() const249 template <typename Distance, typename NeighborWeighting> 250 u_int KNN<Distance, NeighborWeighting>::k() const 239 251 { 240 252 return k_; 241 253 } 242 254 243 template <typename Distance >244 void KNN<Distance >::k(u_int k)255 template <typename Distance, typename NeighborWeighting> 256 void KNN<Distance, NeighborWeighting>::k(u_int k) 245 257 { 246 258 k_=k; … … 248 260 249 261 250 template <typename Distance >262 template <typename Distance, typename NeighborWeighting> 251 263 SupervisedClassifier* 252 KNN<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 264 KNN<Distance, NeighborWeighting>::make_classifier(const DataLookup2D& data, 265 const Target& target) const 253 266 { 254 267 KNN* knn=0; 255 268 try { 256 269 if(data.weighted()) { 257 knn=new KNN<Distance >(dynamic_cast<const MatrixLookupWeighted&>(data),258 270 knn=new KNN<Distance, NeighborWeighting> 271 (dynamic_cast<const MatrixLookupWeighted&>(data),target); 259 272 } 260 273 else { 261 knn=new KNN<Distance >(dynamic_cast<const MatrixLookup&>(data),262 274 knn=new KNN<Distance, NeighborWeighting> 275 (dynamic_cast<const MatrixLookup&>(data),target); 263 276 } 264 277 knn>k(this>k()); 265 278 } 266 279 catch (std::bad_cast) { 267 std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class."; 280 std::string str = "Error in KNN<Distance, NeighborWeighting>"; 281 str += "::make_classifier: DataLookup2D of unexpected class."; 268 282 throw std::runtime_error(str); 269 283 } … … 272 286 273 287 274 template <typename Distance >275 void KNN<Distance >::train()288 template <typename Distance, typename NeighborWeighting> 289 void KNN<Distance, NeighborWeighting>::train() 276 290 { 277 291 trained_=true; … … 279 293 280 294 281 template <typename Distance >282 void KNN<Distance >::predict(const DataLookup2D& test,283 utility::matrix& prediction) const295 template <typename Distance, typename NeighborWeighting> 296 void KNN<Distance, NeighborWeighting>::predict(const DataLookup2D& test, 297 utility::matrix& prediction) const 284 298 { 285 299 utility::yat_assert<std::runtime_error>(data_.rows()==test.rows()); … … 287 301 utility::matrix* distances=calculate_distances(test); 288 302 289 // for each test sample (column in distances) find the closest290 // training samples291 303 prediction.resize(target_.nof_classes(),test.columns(),0.0); 292 304 for(size_t sample=0;sample<distances>columns();sample++) { 293 305 std::vector<size_t> k_index; 294 utility::sort_smallest_index(k_index,k_, 295 distances>column_const_view(sample)); 296 for(size_t j=0;j<k_index.size();j++) { 297 prediction(target_(k_index[j]),sample)++; 298 } 299 } 300 prediction*=(1.0/k_); 306 utility::VectorConstView dist=distances>column_const_view(sample); 307 utility::sort_smallest_index(k_index,k_,dist); 308 utility::VectorView pred=prediction.column_view(sample); 309 weighting_(dist,k_index,target_,pred); 310 } 301 311 delete distances; 302 312 }
Note: See TracChangeset
for help on using the changeset viewer.