Changeset 551 for trunk/test/crossvalidation_test.cc
- Timestamp:
- Mar 7, 2006, 2:39:51 PM (17 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/test/crossvalidation_test.cc
r514 r551 12 12 #include <vector> 13 13 14 // forward declaration 15 void class_count_test(const std::vector<size_t>&, std::ostream*, bool&); 16 void sample_count_test(const std::vector<size_t>&, std::ostream*, bool&); 17 18 14 19 int main(const int argc,const char* argv[]) 15 20 { 16 21 using namespace theplu; 17 22 18 23 std::ostream* error; 19 24 if (argc>1 && argv[1]==std::string("-v")) … … 58 63 class_count[target(cv.training_index()[i])]++; 59 64 } 60 for (size_t i=0; i<class_count.size(); i++) 61 if (class_count[i]==0){ 62 ok = false; 63 *error << "ERROR: class " << i << " was not in training set." 64 << " Expected at least one sample from each class." 65 << std::endl; 65 class_count_test(class_count,error,ok); 66 } 67 sample_count_test(sample_count,error,ok); 68 69 // 70 // Test two nested CrossSplitters 71 // 72 73 *error << "\ntesting two nested crossplitters" << std::endl; 74 label.resize(9); 75 label[0]=label[1]=label[2]="0"; 76 label[3]=label[4]=label[5]="1"; 77 label[6]=label[7]=label[8]="2"; 78 79 target=classifier::Target(label); 80 gslapi::matrix raw_data2(2,9); 81 for(size_t i=0;i<raw_data2.rows();i++) 82 for(size_t j=0;j<raw_data2.columns();j++) 83 raw_data2(i,j)=i*10+10+j+1; 84 85 classifier::MatrixLookup data2(raw_data2); 86 classifier::CrossSplitter cv_test(target,data2,3,3); 87 88 std::vector<size_t> test_sample_count(9,0); 89 std::vector<size_t> test_class_count(3,0); 90 std::vector<double> test_value1(4,0); 91 std::vector<double> test_value2(4,0); 92 std::vector<double> t_value(4,0); 93 std::vector<double> v_value(4,0); 94 while(cv_test.more()) { 95 96 const classifier::DataLookup2D& tv_view=cv_test.training_data(); 97 const classifier::Target& tv_target=cv_test.training_target(); 98 const std::vector<size_t>& tv_index=cv_test.training_index(); 99 const classifier::DataLookup2D& test_view=cv_test.validation_data(); 100 const classifier::Target& test_target=cv_test.validation_target(); 101 const std::vector<size_t>& test_index=cv_test.validation_index(); 102 103 for (size_t i=0; i<test_index.size(); i++) { 104 assert(test_index[i]<sample_count.size()); 105 test_sample_count[test_index[i]]++; 106 test_class_count[target(test_index[i])]++; 107 test_value1[0]+=test_view(0,i); 108 test_value2[0]+=test_view(1,i); 109 test_value1[test_target(i)+1]+=test_view(0,i); 110 test_value2[test_target(i)+1]+=test_view(1,i); 111 if(test_target(i)!=target(test_index[i])) { 112 ok=false; 113 *error << "ERROR: incorrect mapping of test indices" << std:: endl; 114 } 115 } 116 117 classifier::CrossSplitter cv_training(tv_target,tv_view,2,2); 118 std::vector<size_t> v_sample_count(6,0); 119 std::vector<size_t> t_sample_count(6,0); 120 std::vector<size_t> v_class_count(3,0); 121 std::vector<size_t> t_class_count(3,0); 122 std::vector<size_t> t_class_count2(3,0); 123 while(cv_training.more()) { 124 const classifier::DataLookup2D& t_view=cv_training.training_data(); 125 const classifier::Target& t_target=cv_training.training_target(); 126 const std::vector<size_t>& t_index=cv_training.training_index(); 127 const classifier::DataLookup2D& v_view=cv_training.validation_data(); 128 const classifier::Target& v_target=cv_training.validation_target(); 129 const std::vector<size_t>& v_index=cv_training.validation_index(); 130 131 if (test_index.size()+tv_index.size()!=target.size() 132 || t_index.size()+v_index.size() != tv_target.size() 133 || test_index.size()+v_index.size()+t_index.size() != target.size()){ 134 ok = false; 135 *error << "ERROR: size of training samples, validation samples " 136 << "and test samples in is invalid." 137 << std::endl; 138 } 139 if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 || 140 v_index.size()!=3){ 141 ok = false; 142 *error << "ERROR: size of training, validation, and test samples" 143 << " is invalid." 144 << " Expected sizes to be 3" << std::endl; 145 } 146 147 for (size_t i=0; i<t_index.size(); i++) { 148 assert(t_index[i]<t_sample_count.size()); 149 t_sample_count[t_index[i]]++; 150 t_class_count[t_target(i)]++; 151 t_class_count2[tv_target(t_index[i])]++; 152 t_value[0]+=t_view(0,i); 153 t_value[t_target(i)+1]+=t_view(0,i); 66 154 } 67 } 68 for (size_t i=0; i<sample_count.size(); i++){ 69 if (sample_count[i]!=1){ 70 ok = false; 71 *error << "ERROR: sample " << i << " was validated " << sample_count[i] 72 << " times." << " Expected to be 1 time" << std::endl; 73 } 74 } 75 155 for (size_t i=0; i<v_index.size(); i++) { 156 assert(v_index[i]<v_sample_count.size()); 157 v_sample_count[v_index[i]]++; 158 v_class_count[v_target(i)]++; 159 v_value[0]+=v_view(0,i); 160 v_value[v_target(i)+1]+=v_view(0,i); 161 } 162 163 cv_training.next(); 164 } 165 sample_count_test(v_sample_count,error,ok); 166 sample_count_test(t_sample_count,error,ok); 167 168 class_count_test(t_class_count,error,ok); 169 class_count_test(t_class_count2,error,ok); 170 class_count_test(v_class_count,error,ok); 171 172 173 cv_test.next(); 174 } 175 sample_count_test(test_sample_count,error,ok); 176 class_count_test(test_class_count,error,ok); 177 178 if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 || 179 test_value1[3]!=54) { 180 ok=false; 181 *error << "ERROR: incorrect sums of test values in row 1" 182 << " found: " << test_value1[0] << ", " << test_value1[1] 183 << ", " << test_value1[2] << " and " << test_value1[3] 184 << std::endl; 185 } 186 187 188 if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 || 189 test_value2[3]!=84) { 190 ok=false; 191 *error << "ERROR: incorrect sums of test values in row 2" 192 << " found: " << test_value2[0] << ", " << test_value2[1] 193 << ", " << test_value2[2] << " and " << test_value2[3] 194 << std::endl; 195 } 196 197 if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108) { 198 ok=false; 199 *error << "ERROR: incorrect sums of training values in row 1" 200 << " found: " << t_value[0] << ", " << t_value[1] 201 << ", " << t_value[2] << " and " << t_value[3] 202 << std::endl; 203 } 204 205 if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108) { 206 ok=false; 207 *error << "ERROR: incorrect sums of validation values in row 1" 208 << " found: " << v_value[0] << ", " << v_value[1] 209 << ", " << v_value[2] << " and " << v_value[3] 210 << std::endl; 211 } 212 213 214 76 215 if (error!=&std::cerr) 77 216 delete error; 78 217 79 218 if (ok) 80 219 return 0; 81 220 return -1; 82 221 } 222 223 224 void class_count_test(const std::vector<size_t>& class_count, 225 std::ostream* error, bool& ok) 226 { 227 for (size_t i=0; i<class_count.size(); i++) 228 if (class_count[i]==0){ 229 ok = false; 230 *error << "ERROR: class " << i << " was not in set." 231 << " Expected at least one sample from each class." 232 << std::endl; 233 } 234 } 235 236 void sample_count_test(const std::vector<size_t>& sample_count, 237 std::ostream* error, bool& ok) 238 { 239 for (size_t i=0; i<sample_count.size(); i++){ 240 if (sample_count[i]!=1){ 241 ok = false; 242 *error << "ERROR: sample " << i << " was validated " << sample_count[i] 243 << " times." << " Expected to be 1 time" << std::endl; 244 } 245 } 246 }
Note: See TracChangeset
for help on using the changeset viewer.