Changeset 930
 Timestamp:
 Oct 4, 2007, 3:30:02 PM (15 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/yat/classifier/NCC.h
r925 r930 105 105 private: 106 106 107 utility::matrix centroids_;107 utility::matrix* centroids_; 108 108 109 109 // data_ has to be of type DataLookup2D to accomodate both … … 123 123 template <typename Distance> 124 124 NCC<Distance>::NCC(const MatrixLookup& data, const Target& target) 125 : SupervisedClassifier(target), data_(data)125 : SupervisedClassifier(target), centroids_(0), data_(data) 126 126 { 127 127 } … … 129 129 template <typename Distance> 130 130 NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target) 131 : SupervisedClassifier(target), data_(data)131 : SupervisedClassifier(target), centroids_(0), data_(data) 132 132 { 133 133 } … … 136 136 NCC<Distance>::~NCC() 137 137 { 138 } 139 138 if(centroids_) 139 delete centroids_; 140 } 140 141 141 142 template <typename Distance> 142 143 const utility::matrix& NCC<Distance>::centroids(void) const 143 144 { 144 return centroids_;145 return *centroids_; 145 146 } 146 147 … … 158 159 NCC* ncc=0; 159 160 if(data.weighted()) { 160 ncc=new NCC<Distance>( dynamic_cast<const MatrixLookupWeighted&>(data),161 ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookupWeighted*>(&data), 161 162 target); 162 163 } 163 164 else { 164 ncc=new NCC<Distance>( dynamic_cast<const MatrixLookup&>(data),165 ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookup*>(&data), 165 166 target); 166 167 } 168 ncc>centroids_=0; 167 169 return ncc; 168 170 } … … 172 174 bool NCC<Distance>::train() 173 175 { 174 centroids_.clone(utility::matrix(data_.rows(), target_.nof_classes())); 176 if(centroids_) 177 delete centroids_; 178 centroids_= new utility::matrix(data_.rows(), target_.nof_classes()); 175 179 utility::matrix nof_in_class(data_.rows(), target_.nof_classes()); 176 180 const MatrixLookupWeighted* weighted_data = … … 180 184 for(size_t i=0; i<data_.rows(); i++) { 181 185 for(size_t j=0; j<data_.columns(); j++) { 182 centroids_(i,target_(j)) += data_(i,j);186 (*centroids_)(i,target_(j)) += data_(i,j); 183 187 if (weighted) 184 188 nof_in_class(i,target_(j))+= weighted_data>weight(i,j); … … 187 191 } 188 192 } 189 centroids_ .div(nof_in_class);193 centroids_>div(nof_in_class); 190 194 trained_=true; 191 195 return trained_; … … 194 198 template <typename Distance> 195 199 void NCC<Distance>::predict(const DataLookup2D& input, 196 utility::matrix& prediction) const200 utility::matrix& prediction) const 197 201 { 198 prediction.clone(utility::matrix(centroids_ .columns(), input.columns()));199 202 prediction.clone(utility::matrix(centroids_>columns(), input.columns())); 203 200 204 // Weighted case 201 205 const MatrixLookupWeighted* testdata = 202 206 dynamic_cast<const MatrixLookupWeighted*>(&input); 203 207 if (testdata) { 204 MatrixLookupWeighted weighted_centroids( centroids_);208 MatrixLookupWeighted weighted_centroids(*centroids_); 205 209 for(size_t j=0; j<input.columns();j++) { 206 210 DataLookupWeighted1D in(*testdata,j,false); 207 for(size_t k=0; k<centroids_ .columns();k++) {211 for(size_t k=0; k<centroids_>columns();k++) { 208 212 DataLookupWeighted1D centroid(weighted_centroids,k,false); 209 210 213 yat_assert(in.size()==centroid.size()); 211 214 prediction(k,j)=statistics:: 212 215 vector_distance(in.begin(),in.end(),centroid.begin(), 213 216 typename statistics::vector_distance_traits<Distance>::distance()); 214 217 } 215 218 } 216 219 } 220 // Nonweighted case 217 221 else { 218 std::string str; 219 str = "Error in NCC<Distance>::predict: DataLookup2D of unexpected class."; 220 throw std::runtime_error(str); 222 const MatrixLookup* testdata = 223 dynamic_cast<const MatrixLookup*>(&input); 224 if (testdata) { 225 MatrixLookup unweighted_centroids(*centroids_); 226 for(size_t j=0; j<input.columns();j++) { 227 DataLookup1D in(*testdata,j,false); 228 for(size_t k=0; k<centroids_>columns();k++) { 229 DataLookup1D centroid(unweighted_centroids,k,false); 230 yat_assert(in.size()==centroid.size()); 231 prediction(k,j)=statistics:: 232 vector_distance(in.begin(),in.end(),centroid.begin(), 233 typename statistics::vector_distance_traits<Distance>::distance()); 234 } 235 } 236 } 237 else { 238 std::string str; 239 str = "Error in NCC<Distance>::predict: DataLookup2D of unexpected class."; 240 throw std::runtime_error(str); 241 } 221 242 } 222 243 }
Note: See TracChangeset
for help on using the changeset viewer.