 Timestamp:
 Oct 8, 2007, 4:06:53 PM (15 years ago)
 Location:
 trunk/yat/classifier
 Files:

 3 edited
Legend:
 Unmodified
 Added
 Removed

trunk/yat/classifier/KNN.h
r936 r948 5 5 6 6 #include "DataLookupWeighted1D.h" 7 #include "MatrixLookup.h" 7 8 #include "MatrixLookupWeighted.h" 8 9 #include "SupervisedClassifier.h" … … 31 32 public: 32 33 /// 33 /// Constructor taking the training data with weights, the target 34 /// vector, the distance measure, and a weight matrix for the 35 /// training data as input. 34 /// Constructor taking the training data and the target 35 /// as input. 36 /// 37 KNN(const MatrixLookup&, const Target&); 38 39 40 /// 41 /// Constructor taking the training data with weights and the 42 /// target as input. 36 43 /// 37 44 KNN(const MatrixLookupWeighted&, const Target&); … … 93 100 utility::matrix* calculate_distances(const DataLookup2D&) const; 94 101 }; 95 96 ///97 /// The output operator for the KNN class.98 ///99 // std::ostream& operator<< (std::ostream&, const KNN&);100 102 101 103 102 104 // templates 103 105 106 template <typename Distance> 107 KNN<Distance>::KNN(const MatrixLookup& data, const Target& target) 108 : SupervisedClassifier(target), data_(data),k_(3) 109 { 110 } 111 112 104 113 template <typename Distance> 105 114 KNN<Distance>::KNN(const MatrixLookupWeighted& data, const Target& target) … … 116 125 utility::matrix* KNN<Distance>::calculate_distances(const DataLookup2D& input) const 117 126 { 118 const MatrixLookupWeighted* weighted_data =119 dynamic_cast<const MatrixLookupWeighted*>(&data_);120 const MatrixLookupWeighted* weighted_input =121 dynamic_cast<const MatrixLookupWeighted*>(&input);122 123 127 // matrix with training samples as rows and test samples as columns 124 128 utility::matrix* distances = 125 129 new utility::matrix(data_.columns(),input.columns()); 126 127 if(weighted_data && weighted_input) { 130 131 // if both training and test are unweighted: unweighted 132 // calculations are used. 133 const MatrixLookup* test_unweighted = 134 dynamic_cast<const MatrixLookup*>(&input); 135 if(test_unweighted && !data_.weighted()) { 136 const MatrixLookup* data_unweighted = 137 dynamic_cast<const MatrixLookup*>(&data_); 128 138 for(size_t i=0; i<data_.columns(); i++) { 129 classifier::DataLookup Weighted1D training(*weighted_data,i,false);139 classifier::DataLookup1D training(*data_unweighted,i,false); 130 140 for(size_t j=0; j<input.columns(); j++) { 131 classifier::DataLookup Weighted1D test(*weighted_input,j,false);141 classifier::DataLookup1D test(*test_unweighted,j,false); 132 142 utility::yat_assert<std::runtime_error>(training.size()==test.size()); 133 143 (*distances)(i,j) = … … 138 148 } 139 149 } 150 // if either training or test is weighted: weighted calculations 151 // are used. 140 152 else { 141 std::string str; 142 str = "Error in KNN::calculate_distances: Only MatrixLookupWeighted supported still."; 143 throw std::runtime_error(str); 153 const MatrixLookupWeighted* data_weighted = 154 dynamic_cast<const MatrixLookupWeighted*>(&data_); 155 const MatrixLookupWeighted* test_weighted = 156 dynamic_cast<const MatrixLookupWeighted*>(&input); 157 if(data_weighted && test_weighted) { 158 for(size_t i=0; i<data_.columns(); i++) { 159 classifier::DataLookupWeighted1D training(*data_weighted,i,false); 160 for(size_t j=0; j<input.columns(); j++) { 161 classifier::DataLookupWeighted1D test(*test_weighted,j,false); 162 utility::yat_assert<std::runtime_error>(training.size()==test.size()); 163 (*distances)(i,j) = 164 statistics::vector_distance(training.begin(),training.end(), 165 test.begin(), typename statistics::vector_distance_traits<Distance>::distance()); 166 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j))); 167 } 168 } 169 } 170 else { 171 std::string str; 172 str = "Error in KNN::calculate_distances: Only support when training and test data both are either MatrixLookup or MatrixLookupWeighted"; 173 throw std::runtime_error(str); 174 } 144 175 } 145 176 return distances; … … 176 207 target); 177 208 } 209 else { 210 knn=new KNN<Distance>(dynamic_cast<const MatrixLookup&>(data), 211 target); 212 } 178 213 knn>k(this>k()); 179 214 } … … 210 245 } 211 246 } 247 prediction*=(1.0/k_); 212 248 delete distances; 213 249 } 
trunk/yat/classifier/NBC.cc
r865 r948 63 63 NBC::make_classifier(const DataLookup2D& data, const Target& target) const 64 64 { 65 NBC* ncc=0; 66 if(data.weighted()) { 67 ncc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target); 65 NBC* nbc=0; 66 try { 67 if(data.weighted()) { 68 nbc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target); 69 } 70 else { 71 nbc=new NBC(dynamic_cast<const MatrixLookup&>(data),target); 72 } 68 73 } 69 else { 70 ncc=new NBC(dynamic_cast<const MatrixLookup&>(data),target); 74 catch (std::bad_cast) { 75 std::string str = "Error in NBC::make_classifier: DataLookup2D of unexpected class."; 76 throw std::runtime_error(str); 71 77 } 72 return n cc;78 return nbc; 73 79 } 74 80 
trunk/yat/classifier/NCC.h
r936 r948 246 246 const MatrixLookupWeighted* test_weighted = 247 247 dynamic_cast<const MatrixLookupWeighted*>(&input); 248 MatrixLookupWeighted weighted_centroids(*centroids_);249 248 if(test_weighted) { 249 MatrixLookupWeighted weighted_centroids(*centroids_); 250 250 for(size_t j=0; j<input.columns();j++) { 251 251 DataLookupWeighted1D in(*test_weighted,j,false); … … 260 260 } 261 261 else if(data_.weighted() && test_unweighted) { 262 // MatrixLookupWeighted test2weighted(*test_unweighted); 263 // Need to convert MatrixLookup to MatrixLookupWeighted here 264 // and use it in the code below 265 for(size_t j=0; j<input.columns();j++) { 266 DataLookupWeighted1D in(*test_weighted,j,false); 267 for(size_t k=0; k<centroids_>columns();k++) { 268 DataLookupWeighted1D centroid(weighted_centroids,k,false); 269 utility::yat_assert<std::runtime_error>(in.size()==centroid.size()); 270 prediction(k,j)=statistics:: 271 vector_distance(in.begin(),in.end(),centroid.begin(), 272 typename statistics::vector_distance_traits<Distance>::distance()); 273 } 274 } 262 std::string str = "Error in NCC<Distance>::predict:"; 263 str += " predicting unweighted data when NCC"; 264 str += " is trained on weighted data is not yet supported"; 265 throw std::runtime_error(str); 275 266 } 276 267 else { 277 std::string str ;278 str ="Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";268 std::string str = 269 "Error in NCC<Distance>::predict: DataLookup2D of unexpected class."; 279 270 throw std::runtime_error(str); 280 271 }
Note: See TracChangeset
for help on using the changeset viewer.