Changeset 898 for trunk/yat/classifier/NCC.cc
 Timestamp:
 Sep 26, 2007, 3:44:19 PM (15 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/yat/classifier/NCC.cc
r874 r898 27 27 #include "DataLookup1D.h" 28 28 #include "DataLookup2D.h" 29 #include "DataLookupWeighted1D.h" 29 30 #include "MatrixLookup.h" 30 31 #include "MatrixLookupWeighted.h" 31 32 #include "Target.h" 33 #include "yat/statistics/vector_distance.h" 34 #include "yat/statistics/euclidean_vector_distance.h" 35 #include "yat/utility/Iterator.h" 36 #include "yat/utility/IteratorWeighted.h" 32 37 #include "yat/utility/matrix.h" 33 38 #include "yat/utility/vector.h" 34 #include "yat/statistics/Distance.h"35 39 #include "yat/utility/stl_utility.h" 36 40 … … 45 49 46 50 NCC::NCC(const MatrixLookup& data, const Target& target, 47 const statistics:: Distance&distance)51 const statistics::vector_distance_lookup_weighted_ptr distance) 48 52 : SupervisedClassifier(target), distance_(distance), data_(data) 49 53 { … … 51 55 52 56 NCC::NCC(const MatrixLookupWeighted& data, const Target& target, 53 const statistics:: Distance&distance)57 const statistics::vector_distance_lookup_weighted_ptr distance) 54 58 : SupervisedClassifier(target), distance_(distance), data_(data) 55 59 { … … 65 69 return centroids_; 66 70 } 71 67 72 68 69 73 const DataLookup2D& NCC::data(void) const 74 { 70 75 return data_; 71 72 76 } 77 73 78 SupervisedClassifier* 74 79 NCC::make_classifier(const DataLookup2D& data, const Target& target) const … … 109 114 } 110 115 111 112 void NCC::predict(const utility::vector& input, const utility::vector& weights,113 utility::vector& prediction) const114 {115 prediction.clone(utility::vector(centroids_.columns()));116 117 // take care of nan's in centroids118 for(size_t j=0; j<centroids_.columns(); j++) {119 const utility::vector centroid(utility::vector(centroids_,j,false));120 utility::vector wc(centroid.size(),0);121 for(size_t i=0; i<centroid.size(); i++) {122 if(!std::isnan(centroid(i)))123 wc(i)=1.0;124 }125 prediction(j)=distance_(input,centroid,weights,wc);126 }127 }128 129 130 116 void NCC::predict(const DataLookup2D& input, 131 117 utility::matrix& prediction) const 132 118 { 133 prediction.clone(utility::matrix(centroids_.columns(), input.columns())); 134 // weighted case 135 const MatrixLookupWeighted* data =136 dynamic_cast<const MatrixLookupWeighted*>(&input);137 if (data) {138 for(size_t j=0; j<input.columns();j++) {139 utility::vector in(input.rows(),0);140 for(size_t i=0; i<in.size();i++)141 in(i)=data>data(i,j);142 utility::vector weights(in.size(),0);143 for(size_t i=0; i<in.size();i++)144 weights(i)=data>weight(i,j);145 utility::vector out;146 predict(in,weights,out);147 prediction.column(j,out);119 prediction.clone(utility::matrix(centroids_.columns(), input.columns())); 120 121 // Weighted case 122 const MatrixLookupWeighted* testdata = 123 dynamic_cast<const MatrixLookupWeighted*>(&input); 124 if (testdata) { 125 utility::matrix centroid_weights; 126 utility::nan(centroids_,centroid_weights); 127 MatrixLookupWeighted weighted_centroids(centroids_,centroid_weights); 128 for(size_t j=0; j<input.columns();j++) { 129 DataLookupWeighted1D in(*testdata,j,false); 130 for(size_t k=0; k<centroids_.columns();k++) { 131 DataLookupWeighted1D centroid(weighted_centroids,k,false); 132 prediction(k,j)=(*distance_)(in.begin(),in.end(),centroid.begin()); 133 } 148 134 } 149 return;150 135 } 151 // nonweighted case 152 const MatrixLookup* x = dynamic_cast<const MatrixLookup*>(&input); 153 if (!x){ 136 else { 154 137 std::string str; 155 138 str = "Error in NCC::predict: DataLookup2D of unexpected class."; 156 139 throw std::runtime_error(str); 157 140 } 158 for(size_t j=0; j<input.columns();j++) {159 utility::vector in(input.rows(),0);160 for(size_t i=0; i<in.size();i++)161 in(i)=(*data)(i,j);162 utility::vector weights(in.size(),1.0);163 utility::vector out;164 predict(in,weights,out);165 prediction.column(j,out);166 }167 141 } 168 169 170 // additional operators 171 172 // std::ostream& operator<< (std::ostream& s, const NCC& ncc) { 173 // std::copy(ncc.classes().begin(), ncc.classes().end(), 174 // std::ostream_iterator<std::map<double, u_int>::value_type> 175 // (s, "\n")); 176 // s << "\n" << ncc.centroids() << "\n"; 177 // return s; 178 // } 179 142 180 143 }}} // of namespace classifier, yat, and theplu
Note: See TracChangeset
for help on using the changeset viewer.