Changeset 1173
- Timestamp:
- Feb 27, 2008, 4:19:05 PM (16 years ago)
- Location:
- trunk
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/test/subset_generator_test.cc
r1170 r1173 31 31 #include "yat/classifier/PolynomialKernelFunction.h" 32 32 #include "yat/classifier/SubsetGenerator.h" 33 #include "yat/classifier/SVM.h"34 #include "yat/classifier/NCC.h"35 33 #include "yat/statistics/AUC.h" 36 34 #include "yat/utility/Matrix.h" -
trunk/yat/classifier/NCC.h
r1164 r1173 114 114 void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const; 115 115 116 utility::Matrix *centroids_;116 utility::Matrix centroids_; 117 117 bool centroids_nan_; 118 118 Distance distance_; … … 129 129 template <typename Distance> 130 130 NCC<Distance>::NCC() 131 : SupervisedClassifier(), centroids_ (0), centroids_nan_(false)131 : SupervisedClassifier(), centroids_nan_(false) 132 132 { 133 133 } … … 135 135 template <typename Distance> 136 136 NCC<Distance>::NCC(const Distance& dist) 137 : SupervisedClassifier(), centroids_ (0), centroids_nan_(false), distance_(dist)137 : SupervisedClassifier(), centroids_nan_(false), distance_(dist) 138 138 { 139 139 } … … 143 143 NCC<Distance>::~NCC() 144 144 { 145 if(centroids_)146 delete centroids_;147 145 } 148 146 … … 151 149 const utility::Matrix& NCC<Distance>::centroids(void) const 152 150 { 153 return *centroids_;151 return centroids_; 154 152 } 155 153 … … 167 165 void NCC<Distance>::train(const MatrixLookup& data, const Target& target) 168 166 { 169 if(centroids_) 170 delete centroids_; 171 centroids_= new utility::Matrix(data.rows(), target.nof_classes()); 167 centroids_.resize(data.rows(), target.nof_classes()); 172 168 for(size_t i=0; i<data.rows(); i++) { 173 169 std::vector<statistics::Averager> class_averager; … … 177 173 } 178 174 for(size_t c=0;c<target.nof_classes();c++) { 179 (*centroids_)(i,c) = class_averager[c].mean();175 centroids_(i,c) = class_averager[c].mean(); 180 176 } 181 177 } … … 186 182 void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target) 187 183 { 188 if(centroids_) 189 delete centroids_; 190 centroids_= new utility::Matrix(data.rows(), target.nof_classes()); 184 centroids_.resize(data.rows(), target.nof_classes()); 191 185 for(size_t i=0; i<data.rows(); i++) { 192 186 std::vector<statistics::AveragerWeighted> class_averager; … … 198 192 centroids_nan_=true; 199 193 } 200 (*centroids_)(i,c) = class_averager[c].mean();194 centroids_(i,c) = class_averager[c].mean(); 201 195 } 202 196 } … … 209 203 { 210 204 utility::yat_assert<std::runtime_error> 211 (centroids_,"NCC::predict called for untrained classifier"); 212 utility::yat_assert<std::runtime_error> 213 (centroids_->rows()==test.rows(), 205 (centroids_.rows()==test.rows(), 214 206 "NCC::predict test data with incorrect number of rows"); 215 207 216 prediction.resize(centroids_ ->columns(), test.columns());208 prediction.resize(centroids_.columns(), test.columns()); 217 209 218 210 // If weighted training data has resulted in NaN in centroids: weighted calculations … … 231 223 { 232 224 utility::yat_assert<std::runtime_error> 233 (centroids_,"NCC::predict called for untrained classifier"); 234 utility::yat_assert<std::runtime_error> 235 (centroids_->rows()==test.rows(), 225 (centroids_.rows()==test.rows(), 236 226 "NCC::predict test data with incorrect number of rows"); 237 227 238 prediction.resize(centroids_ ->columns(), test.columns());228 prediction.resize(centroids_.columns(), test.columns()); 239 229 predict_weighted(test,prediction); 240 230 } … … 245 235 utility::Matrix& prediction) const 246 236 { 247 MatrixLookup centroids( *centroids_);237 MatrixLookup centroids(centroids_); 248 238 for(size_t j=0; j<test.columns();j++) 249 for(size_t k=0; k<centroids_ ->columns();k++)239 for(size_t k=0; k<centroids_.columns();k++) 250 240 prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 251 241 centroids.begin_column(k)); … … 256 246 utility::Matrix& prediction) const 257 247 { 258 MatrixLookupWeighted weighted_centroids( *centroids_);248 MatrixLookupWeighted weighted_centroids(centroids_); 259 249 for(size_t j=0; j<test.columns();j++) 260 for(size_t k=0; k<centroids_ ->columns();k++)250 for(size_t k=0; k<centroids_.columns();k++) 261 251 prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 262 252 weighted_centroids.begin_column(k));
Note: See TracChangeset
for help on using the changeset viewer.