Changeset 1107 for trunk/yat/classifier
- Timestamp:
- Feb 19, 2008, 4:23:52 PM (16 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/yat/classifier/KNN.h
r1098 r1107 90 90 91 91 /// 92 /// Train the classifier using the training data. Centroids are93 /// calculated for each class.92 /// Train the classifier using the training data. 93 /// This function does nothing but is required by the interface. 94 94 /// 95 95 /// @return true if training succedeed. … … 99 99 100 100 /// 101 /// Calculate the distance to each centroid for test samples 101 /// For each sample, calculate the number of neighbours for each 102 /// class. 103 /// 102 104 /// 103 105 void predict(const DataLookup2D&, utility::matrix&) const; … … 121 123 /// 122 124 utility::matrix* calculate_distances(const DataLookup2D&) const; 125 void calculate_unweighted(const MatrixLookup&, 126 const MatrixLookup&, 127 utility::matrix*) const; 128 void calculate_weighted(const MatrixLookupWeighted&, 129 const MatrixLookupWeighted&, 130 utility::matrix*) const; 123 131 }; 124 132 … … 151 159 new utility::matrix(data_.columns(),test.columns()); 152 160 161 153 162 // unweighted test data 154 163 if(const MatrixLookup* test_unweighted = 155 164 dynamic_cast<const MatrixLookup*>(&test)) { 156 for(size_t i=0; i<data_.columns(); i++) { 157 for(size_t j=0; j<test.columns(); j++) { 158 classifier::DataLookup1D test(*test_unweighted,j,false); 159 classifier::DataLookup1D tmp(data_,i,false); 160 (*distances)(i,j) = distance_(tmp.begin(), tmp.end(), test.begin()); 161 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 162 } 163 } 165 // unweighted training data 166 if(const MatrixLookup* training_unweighted = 167 dynamic_cast<const MatrixLookup*>(&data_)) 168 calculate_unweighted(*training_unweighted,*test_unweighted,distances); 169 // weighted training data 170 else if(const MatrixLookupWeighted* training_weighted = 171 dynamic_cast<const MatrixLookupWeighted*>(&data_)) 172 calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted), 173 distances); 174 // Training data can not be of incorrect type 164 175 } 165 176 // weighted test data 177 else if (const MatrixLookupWeighted* test_weighted = 178 dynamic_cast<const MatrixLookupWeighted*>(&test)) { 179 // unweighted training data 180 if(const MatrixLookup* training_unweighted = 181 dynamic_cast<const MatrixLookup*>(&data_)) { 182 calculate_weighted(MatrixLookupWeighted(*training_unweighted), 183 *test_weighted,distances); 184 } 185 // weighted training data 186 else if(const MatrixLookupWeighted* training_weighted = 187 dynamic_cast<const MatrixLookupWeighted*>(&data_)) 188 calculate_weighted(*training_weighted,*test_weighted,distances); 189 // Training data can not be of incorrect type 190 } 166 191 else { 167 const MatrixLookupWeighted* data_weighted = 168 dynamic_cast<const MatrixLookupWeighted*>(&data_); 169 const MatrixLookupWeighted* test_weighted = 170 dynamic_cast<const MatrixLookupWeighted*>(&test); 171 if(data_weighted && test_weighted) { 172 for(size_t i=0; i<data_.columns(); i++) { 173 classifier::DataLookupWeighted1D training(*data_weighted,i,false); 174 for(size_t j=0; j<test.columns(); j++) { 175 classifier::DataLookupWeighted1D test(*test_weighted,j,false); 176 utility::yat_assert<std::runtime_error>(training.size()==test.size()); 177 (*distances)(i,j) = distance_(training.begin(), training.end(), 178 test.begin()); 179 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 180 } 181 } 182 } 183 else { 184 std::string str; 185 str = "Error in KNN::calculate_distances: Only support when training and test data both are either MatrixLookup or MatrixLookupWeighted"; 186 throw std::runtime_error(str); 187 } 192 std::string str; 193 str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted"; 194 throw std::runtime_error(str); 188 195 } 189 196 return distances; 190 197 } 198 199 template <typename Distance> 200 void KNN<Distance>:: calculate_unweighted(const MatrixLookup& training, 201 const MatrixLookup& test, 202 utility::matrix* distances) const 203 { 204 for(size_t i=0; i<training.columns(); i++) { 205 classifier::DataLookup1D training1(training,i,false); 206 for(size_t j=0; j<test.columns(); j++) { 207 classifier::DataLookup1D test1(test,j,false); 208 (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin()); 209 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 210 } 211 } 212 } 213 214 template <typename Distance> 215 void KNN<Distance>:: calculate_weighted(const MatrixLookupWeighted& training, 216 const MatrixLookupWeighted& test, 217 utility::matrix* distances) const 218 { 219 for(size_t i=0; i<training.columns(); i++) { 220 classifier::DataLookupWeighted1D training1(training,i,false); 221 for(size_t j=0; j<test.columns(); j++) { 222 classifier::DataLookupWeighted1D test1(test,j,false); 223 (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin()); 224 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 225 } 226 } 227 } 228 191 229 192 230 template <typename Distance>
Note: See TracChangeset
for help on using the changeset viewer.