Changeset 1107
- Timestamp:
- Feb 19, 2008, 4:23:52 PM (16 years ago)
- Location:
- trunk
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/test/knn_test.cc
r1052 r1107 23 23 24 24 #include "yat/classifier/KNN.h" 25 #include "yat/classifier/MatrixLookup.h" 25 26 #include "yat/classifier/MatrixLookupWeighted.h" 26 27 #include "yat/statistics/EuclideanDistance.h" 27 #include "yat/statistics/PearsonDistance.h"28 28 #include "yat/utility/matrix.h" 29 29 … … 38 38 39 39 using namespace theplu::yat; 40 41 double deviation(const utility::matrix& a, const utility::matrix& b) { 42 double sl=0; 43 for (size_t i=0; i<a.rows(); i++){ 44 for (size_t j=0; j<a.columns(); j++){ 45 sl += fabs(a(i,j)-b(i,j)); 46 } 47 } 48 sl /= (a.columns()*a.rows()); 49 return sl; 50 } 40 51 41 52 int main(const int argc,const char* argv[]) … … 52 63 *error << "testing knn" << std::endl; 53 64 bool ok = true; 65 66 //////////////////////////////////////////////////////////////// 67 // A test of training and predictions using unweighted data 68 //////////////////////////////////////////////////////////////// 69 *error << "test of predictions using unweighted training and test data\n"; 70 utility::matrix data1(3,4); 71 for(size_t i=0;i<3;i++) { 72 data1(i,0)=3-i; 73 data1(i,1)=5-i; 74 data1(i,2)=i+1; 75 data1(i,3)=i+3; 76 } 77 std::vector<std::string> vec1(4, "pos"); 78 vec1[0]="neg"; 79 vec1[1]="neg"; 54 80 55 std::ifstream is("data/sorlie_centroid_data.txt"); 56 utility::matrix data(is,'\t'); 57 is.close(); 81 classifier::MatrixLookup ml1(data1); 82 classifier::Target target1(vec1); 58 83 59 is.open("data/sorlie_centroid_classes.txt"); 60 classifier::Target targets(is); 61 is.close(); 84 classifier::KNN<statistics::EuclideanDistance> knn1(ml1,target1); 85 knn1.k(3); 86 knn1.train(); 87 utility::matrix prediction1; 88 knn1.predict(ml1,prediction1); 89 double slack_bound=2e-7; 90 utility::matrix result1(2,4); 91 result1(0,0)=result1(0,1)=result1(1,2)=result1(1,3)=2.0/3.0; 92 result1(0,2)=result1(0,3)=result1(1,0)=result1(1,1)=1.0/3.0; 93 double slack = deviation(prediction1,result1); 94 if (slack > slack_bound || std::isnan(slack)){ 95 *error << "Difference to expected prediction too large\n"; 96 *error << "slack: " << slack << std::endl; 97 *error << "expected less than " << slack_bound << std::endl; 98 ok = false; 99 } 100 62 101 63 // Generate weight matrix with 0 for missing values and 1 for others. 64 utility::matrix weights(data.rows(),data.columns(),0.0); 65 utility::nan(data,weights); 102 //////////////////////////////////////////////////////////////// 103 // A test of training unweighted and test weighted 104 //////////////////////////////////////////////////////////////// 105 *error << "test of predictions using unweighted training and weighted test data\n"; 106 utility::matrix weights1(3,4,1.0); 107 weights1(2,0)=0; 108 classifier::MatrixLookupWeighted mlw1(data1,weights1); 109 knn1.predict(mlw1,prediction1); 110 result1(0,0)=1.0/3.0; 111 result1(1,0)=2.0/3.0; 112 slack = deviation(prediction1,result1); 113 if (slack > slack_bound || std::isnan(slack)){ 114 *error << "Difference to expected prediction too large\n"; 115 *error << "slack: " << slack << std::endl; 116 *error << "expected less than " << slack_bound << std::endl; 117 ok = false; 118 } 66 119 67 classifier::MatrixLookupWeighted dataviewweighted(data,weights); 68 classifier::KNN<statistics::PearsonDistance> knn(dataviewweighted,targets); 69 *error << "training KNN" << std::endl; 70 knn.train(); 71 72 utility::matrix prediction; 73 knn.predict(dataviewweighted,prediction); 74 *error << prediction << std::endl; 75 120 //////////////////////////////////////////////////////////////// 121 // A test of training and test both weighted 122 //////////////////////////////////////////////////////////////// 123 *error << "test of predictions using weighted training and test data\n"; 124 weights1(0,1)=0; 125 utility::matrix weights2(3,4,1.0); 126 weights2(2,3)=0; 127 classifier::MatrixLookupWeighted mlw2(data1,weights2); 128 classifier::KNN<statistics::EuclideanDistance> knn2(mlw2,target1); 129 knn2.k(3); 130 knn2.train(); 131 knn2.predict(mlw1,prediction1); 132 result1(0,1)=1.0/3.0; 133 result1(1,1)=2.0/3.0; 134 slack = deviation(prediction1,result1); 135 if (slack > slack_bound || std::isnan(slack)){ 136 *error << "Difference to expected prediction too large\n"; 137 *error << "slack: " << slack << std::endl; 138 *error << "expected less than " << slack_bound << std::endl; 139 ok = false; 140 } 141 142 76 143 if(!ok) { 77 144 *error << "knn_test failed" << std::endl; -
trunk/yat/classifier/KNN.h
r1098 r1107 90 90 91 91 /// 92 /// Train the classifier using the training data. Centroids are93 /// calculated for each class.92 /// Train the classifier using the training data. 93 /// This function does nothing but is required by the interface. 94 94 /// 95 95 /// @return true if training succedeed. … … 99 99 100 100 /// 101 /// Calculate the distance to each centroid for test samples 101 /// For each sample, calculate the number of neighbours for each 102 /// class. 103 /// 102 104 /// 103 105 void predict(const DataLookup2D&, utility::matrix&) const; … … 121 123 /// 122 124 utility::matrix* calculate_distances(const DataLookup2D&) const; 125 void calculate_unweighted(const MatrixLookup&, 126 const MatrixLookup&, 127 utility::matrix*) const; 128 void calculate_weighted(const MatrixLookupWeighted&, 129 const MatrixLookupWeighted&, 130 utility::matrix*) const; 123 131 }; 124 132 … … 151 159 new utility::matrix(data_.columns(),test.columns()); 152 160 161 153 162 // unweighted test data 154 163 if(const MatrixLookup* test_unweighted = 155 164 dynamic_cast<const MatrixLookup*>(&test)) { 156 for(size_t i=0; i<data_.columns(); i++) { 157 for(size_t j=0; j<test.columns(); j++) { 158 classifier::DataLookup1D test(*test_unweighted,j,false); 159 classifier::DataLookup1D tmp(data_,i,false); 160 (*distances)(i,j) = distance_(tmp.begin(), tmp.end(), test.begin()); 161 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 162 } 163 } 165 // unweighted training data 166 if(const MatrixLookup* training_unweighted = 167 dynamic_cast<const MatrixLookup*>(&data_)) 168 calculate_unweighted(*training_unweighted,*test_unweighted,distances); 169 // weighted training data 170 else if(const MatrixLookupWeighted* training_weighted = 171 dynamic_cast<const MatrixLookupWeighted*>(&data_)) 172 calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted), 173 distances); 174 // Training data can not be of incorrect type 164 175 } 165 176 // weighted test data 177 else if (const MatrixLookupWeighted* test_weighted = 178 dynamic_cast<const MatrixLookupWeighted*>(&test)) { 179 // unweighted training data 180 if(const MatrixLookup* training_unweighted = 181 dynamic_cast<const MatrixLookup*>(&data_)) { 182 calculate_weighted(MatrixLookupWeighted(*training_unweighted), 183 *test_weighted,distances); 184 } 185 // weighted training data 186 else if(const MatrixLookupWeighted* training_weighted = 187 dynamic_cast<const MatrixLookupWeighted*>(&data_)) 188 calculate_weighted(*training_weighted,*test_weighted,distances); 189 // Training data can not be of incorrect type 190 } 166 191 else { 167 const MatrixLookupWeighted* data_weighted = 168 dynamic_cast<const MatrixLookupWeighted*>(&data_); 169 const MatrixLookupWeighted* test_weighted = 170 dynamic_cast<const MatrixLookupWeighted*>(&test); 171 if(data_weighted && test_weighted) { 172 for(size_t i=0; i<data_.columns(); i++) { 173 classifier::DataLookupWeighted1D training(*data_weighted,i,false); 174 for(size_t j=0; j<test.columns(); j++) { 175 classifier::DataLookupWeighted1D test(*test_weighted,j,false); 176 utility::yat_assert<std::runtime_error>(training.size()==test.size()); 177 (*distances)(i,j) = distance_(training.begin(), training.end(), 178 test.begin()); 179 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 180 } 181 } 182 } 183 else { 184 std::string str; 185 str = "Error in KNN::calculate_distances: Only support when training and test data both are either MatrixLookup or MatrixLookupWeighted"; 186 throw std::runtime_error(str); 187 } 192 std::string str; 193 str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted"; 194 throw std::runtime_error(str); 188 195 } 189 196 return distances; 190 197 } 198 199 template <typename Distance> 200 void KNN<Distance>:: calculate_unweighted(const MatrixLookup& training, 201 const MatrixLookup& test, 202 utility::matrix* distances) const 203 { 204 for(size_t i=0; i<training.columns(); i++) { 205 classifier::DataLookup1D training1(training,i,false); 206 for(size_t j=0; j<test.columns(); j++) { 207 classifier::DataLookup1D test1(test,j,false); 208 (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin()); 209 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 210 } 211 } 212 } 213 214 template <typename Distance> 215 void KNN<Distance>:: calculate_weighted(const MatrixLookupWeighted& training, 216 const MatrixLookupWeighted& test, 217 utility::matrix* distances) const 218 { 219 for(size_t i=0; i<training.columns(); i++) { 220 classifier::DataLookupWeighted1D training1(training,i,false); 221 for(size_t j=0; j<test.columns(); j++) { 222 classifier::DataLookupWeighted1D test1(test,j,false); 223 (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin()); 224 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 225 } 226 } 227 } 228 191 229 192 230 template <typename Distance>
Note: See TracChangeset
for help on using the changeset viewer.