Changeset 526 for trunk/lib/classifier/NCC.cc
 Timestamp:
 Mar 1, 2006, 9:49:48 AM (17 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/lib/classifier/NCC.cc
r525 r526 5 5 #include <c++_tools/classifier/DataLookup1D.h> 6 6 #include <c++_tools/classifier/DataLookup2D.h> 7 #include <c++_tools/classifier/MatrixLookup.h> 7 8 #include <c++_tools/classifier/InputRanker.h> 8 9 #include <c++_tools/classifier/Target.h> … … 19 20 namespace classifier { 20 21 21 NCC::NCC(const DataLookup2D& data, const Target& target,22 NCC::NCC(const MatrixLookup& data, const Target& target, 22 23 const statistics::Distance& distance) 23 24 : SupervisedClassifier(target), distance_(distance), matrix_(data) … … 25 26 } 26 27 27 NCC::NCC(const DataLookup2D& data, const Target& target,28 NCC::NCC(const MatrixLookup& data, const Target& target, 28 29 const statistics::Distance& distance, 29 30 statistics::Score& score, size_t nof_inputs) … … 44 45 const Target& target) const 45 46 { 46 NCC* ncc= new NCC(data,target,this>distance_); 47 const MatrixLookup& tmp = dynamic_cast<const MatrixLookup&>(data); 48 49 NCC* ncc= new NCC(tmp,target,this>distance_); 47 50 ncc>score_=this>score_; 48 51 ncc>nof_inputs_=this>nof_inputs_; … … 53 56 bool NCC::train() 54 57 { 58 // If score is set calculate centroids only for nof_inputs_ number 59 // of top ranked inputs. Otherwise calculate centroids based on 60 // all inputs ( = all rows in data matrix). 55 61 if(ranker_) 56 62 delete ranker_; 57 if(score_) 58 ranker_=new InputRanker(matrix_, target_, *score_); 59 // Markus : ranker_ should be taken into account if used!!! 60 61 // Calculate the centroids for each class 62 centroids_=gslapi::matrix(matrix_.rows(),target_.nof_classes()); 63 gslapi::matrix nof_in_class(matrix_.rows(),target_.nof_classes()); 64 for(size_t i=0; i<matrix_.rows(); i++) { 63 size_t rows=matrix_.rows(); 64 if(score_) { 65 // Markus: missing values should not be handled here, but a weight matrix 66 // should be supported throughout the classifier class structure. 67 gslapi::matrix weight(matrix_.rows(),matrix_.columns(),0.0); 68 for(size_t i=0; i<matrix_.rows(); i++) 69 for(size_t j=0; j<matrix_.columns(); j++) 70 if(!std::isnan(matrix_(i,j))) 71 weight(i,j)=1.0; 72 MatrixLookup weightview(weight); 73 ranker_=new InputRanker(matrix_, target_, *score_, weightview); 74 rows=nof_inputs_; 75 } 76 centroids_=gslapi::matrix(rows, target_.nof_classes()); 77 gslapi::matrix nof_in_class(rows, target_.nof_classes()); 78 for(size_t i=0; i<rows; i++) { 65 79 for(size_t j=0; j<matrix_.columns(); j++) { 66 if(!std::isnan(matrix_(i,j))) { 67 centroids_(i,target_(j)) += matrix_(i,j); 80 double value=matrix_(i,j); 81 if(score_) 82 value=matrix_(ranker_>id(i),j); 83 if(!std::isnan(value)) { 84 centroids_(i,target_(j)) += value; 68 85 nof_in_class(i,target_(j))++; 69 86 } … … 75 92 } 76 93 94 77 95 void NCC::predict(const DataLookup1D& input, 78 96 gslapi::vector& prediction) const 79 97 { 80 // Markus : ranker_ should be taken into account if used!!!81 82 98 prediction=gslapi::vector(centroids_.columns()); 83 gslapi::vector w(input.size(),0); 84 for(size_t i=0; i<input.size(); i++) // take care of missing values 85 if(!std::isnan(input(i))) 99 size_t size=input.size(); 100 if(ranker_) 101 size=nof_inputs_; 102 gslapi::vector w(size,0); 103 gslapi::vector value(size,0); 104 for(size_t i=0; i<size; i++) { // take care of missing values 105 value(i)=input(i); 106 if(ranker_) 107 value(i)=input(ranker_>id(i)); 108 if(!std::isnan(value(i))) 86 109 w(i)=1.0; 110 } 87 111 for(size_t j=0; j<centroids_.columns(); j++) 88 prediction(j)=distance_(gslapi::vector(input), 89 gslapi::vector(centroids_,j,false),w, w); 112 prediction(j)=distance_(value,gslapi::vector(centroids_,j,false),w, w); 90 113 } 91 114 … … 94 117 gslapi::matrix& prediction) const 95 118 { 96 // Markus : ranker_ should be taken into account if used!!!97 98 119 prediction=gslapi::matrix(centroids_.columns(), input.columns()); 99 120 for(size_t j=0; j<input.columns();j++) { … … 108 129 // additional operators 109 130 110 std::ostream& operator<< (std::ostream& s, const NCC& ncc) {131 // std::ostream& operator<< (std::ostream& s, const NCC& ncc) { 111 132 // std::copy(ncc.classes().begin(), ncc.classes().end(), 112 133 // std::ostream_iterator<std::map<double, u_int>::value_type> 113 134 // (s, "\n")); 114 s << "\n" << ncc.centroids() << "\n";115 return s;116 }135 // s << "\n" << ncc.centroids() << "\n"; 136 // return s; 137 // } 117 138 118 139 }} // of namespace classifier and namespace theplu
Note: See TracChangeset
for help on using the changeset viewer.