source: trunk/test/crossvalidation_test.cc @ 615

Last change on this file since 615 was 615, checked in by Peter, 15 years ago

ref #60 NOTE: there is most likely a bug around. I have removed the ensemble.build() test in the ensemble_test to get the test go through. I will try to find and remove this bug asap.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 8.3 KB
Line 
1// $Id: crossvalidation_test.cc 615 2006-08-31 05:33:35Z peter $
2
3#include <c++_tools/classifier/CrossValidationSampler.h>
4#include <c++_tools/classifier/SubsetGenerator.h>
5#include <c++_tools/classifier/MatrixLookup.h>
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/gslapi/matrix.h>
8
9#include <cstdlib>
10#include <fstream>
11#include <iostream>
12#include <string>
13#include <vector>
14
15// forward declaration
16void class_count_test(const std::vector<size_t>&, std::ostream*, bool&);
17void sample_count_test(const std::vector<size_t>&, std::ostream*, bool&);
18
19
20int main(const int argc,const char* argv[])
21{ 
22  using namespace theplu;
23 
24  std::ostream* error;
25  if (argc>1 && argv[1]==std::string("-v"))
26    error = &std::cerr;
27  else {
28    error = new std::ofstream("/dev/null");
29    if (argc>1)
30      std::cout << "crossvalidation_test -v : for printing extra information\n";
31  }
32  *error << "testing crosssplitter" << std::endl;
33  bool ok = true;
34
35  std::vector<std::string> label(10,"default");
36  label[2]=label[7]="white";
37  label[4]=label[5]="black";
38  label[6]=label[3]="green";
39  label[8]=label[9]="red";
40                 
41  classifier::Target target(label);
42  gslapi::matrix raw_data(10,10);
43  classifier::MatrixLookup data(raw_data);
44  classifier::CrossValidationSampler cv(target,3,3);
45 
46  std::vector<size_t> sample_count(10,0);
47  for (size_t j=0; j<cv.size(); ++j){
48    std::vector<size_t> class_count(5,0);
49    assert(j<cv.size());
50    if (cv.training_index(j).size()+cv.validation_index(j).size()!=
51        target.size()){
52      ok = false;
53      *error << "ERROR: size of training samples plus " 
54             << "size of validation samples is invalid." << std::endl;
55    }
56    if (cv.validation_index(j).size()!=3 && cv.validation_index(j).size()!=4){
57      ok = false;
58      *error << "ERROR: size of validation samples is invalid." 
59             << "expected size to be 3 or 4" << std::endl;
60    }
61    for (size_t i=0; i<cv.validation_index(j).size(); i++) {
62      assert(cv.validation_index(j)[i]<sample_count.size());
63      sample_count[cv.validation_index(j)[i]]++;
64    }
65    for (size_t i=0; i<cv.training_index(j).size(); i++) {
66      class_count[target(cv.training_index(j)[i])]++;
67    }
68    class_count_test(class_count,error,ok);
69  }
70  sample_count_test(sample_count,error,ok);
71 
72  //
73  // Test two nested CrossSplitters
74  //
75
76  *error << "\ntesting two nested crossplitters" << std::endl;
77  label.resize(9);
78  label[0]=label[1]=label[2]="0";
79  label[3]=label[4]=label[5]="1";
80  label[6]=label[7]=label[8]="2";
81                 
82  target=classifier::Target(label);
83  gslapi::matrix raw_data2(2,9);
84  for(size_t i=0;i<raw_data2.rows();i++)
85    for(size_t j=0;j<raw_data2.columns();j++)
86      raw_data2(i,j)=i*10+10+j+1;
87   
88  classifier::MatrixLookup data2(raw_data2);
89  classifier::CrossValidationSampler cv2(target,3,3);
90  classifier::SubsetGenerator cv_test(cv2,data2);
91
92  std::vector<size_t> test_sample_count(9,0);
93  std::vector<size_t> test_class_count(3,0);
94  std::vector<double> test_value1(4,0);
95  std::vector<double> test_value2(4,0);
96  std::vector<double> t_value(4,0);
97  std::vector<double> v_value(4,0); 
98  cv_test.reset();
99  while(cv_test.more()) {
100   
101    const classifier::DataLookup2D& tv_view=cv_test.training_data();
102    const classifier::Target& tv_target=cv_test.training_target();
103    const std::vector<size_t>& tv_index=cv_test.training_index();
104    const classifier::DataLookup2D& test_view=cv_test.validation_data();
105    const classifier::Target& test_target=cv_test.validation_target();
106    const std::vector<size_t>& test_index=cv_test.validation_index();
107
108    for (size_t i=0; i<test_index.size(); i++) {
109      assert(test_index[i]<sample_count.size());
110      test_sample_count[test_index[i]]++;
111      test_class_count[target(test_index[i])]++;
112      test_value1[0]+=test_view(0,i);
113      test_value2[0]+=test_view(1,i);
114      test_value1[test_target(i)+1]+=test_view(0,i);
115      test_value2[test_target(i)+1]+=test_view(1,i);
116      if(test_target(i)!=target(test_index[i])) {
117        ok=false;
118        *error << "ERROR: incorrect mapping of test indices" << std:: endl;
119      }       
120    }
121   
122    classifier::CrossValidationSampler sampler_training(tv_target,2,2);
123    classifier::SubsetGenerator cv_training(sampler_training,tv_view);
124    std::vector<size_t> v_sample_count(6,0);
125    std::vector<size_t> t_sample_count(6,0);
126    std::vector<size_t> v_class_count(3,0);
127    std::vector<size_t> t_class_count(3,0);
128    std::vector<size_t> t_class_count2(3,0);
129    cv_training.reset();
130    while(cv_training.more()) {
131      const classifier::DataLookup2D& t_view=cv_training.training_data();
132      const classifier::Target& t_target=cv_training.training_target();
133      const std::vector<size_t>& t_index=cv_training.training_index();
134      const classifier::DataLookup2D& v_view=cv_training.validation_data();
135      const classifier::Target& v_target=cv_training.validation_target();
136      const std::vector<size_t>& v_index=cv_training.validation_index();
137     
138      if (test_index.size()+tv_index.size()!=target.size() 
139          || t_index.size()+v_index.size() != tv_target.size() 
140          || test_index.size()+v_index.size()+t_index.size() !=  target.size()){
141        ok = false;
142        *error << "ERROR: size of training samples, validation samples " 
143               << "and test samples in is invalid." 
144               << std::endl;
145      }
146      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
147          v_index.size()!=3){
148        ok = false;
149        *error << "ERROR: size of training, validation, and test samples"
150               << " is invalid." 
151               << " Expected sizes to be 3" << std::endl;
152      }     
153
154      std::vector<size_t> tv_sample_count(6,0);
155      for (size_t i=0; i<t_index.size(); i++) {
156        assert(t_index[i]<t_sample_count.size());
157        tv_sample_count[t_index[i]]++;
158        t_sample_count[t_index[i]]++;
159        t_class_count[t_target(i)]++;
160        t_class_count2[tv_target(t_index[i])]++;
161        t_value[0]+=t_view(0,i);
162        t_value[t_target(i)+1]+=t_view(0,i);       
163      }
164      for (size_t i=0; i<v_index.size(); i++) {
165        assert(v_index[i]<v_sample_count.size());
166        tv_sample_count[v_index[i]]++;
167        v_sample_count[v_index[i]]++;
168        v_class_count[v_target(i)]++;
169        v_value[0]+=v_view(0,i);
170        v_value[v_target(i)+1]+=v_view(0,i);
171      }
172 
173      sample_count_test(tv_sample_count,error,ok);     
174
175      cv_training.next();     
176    }
177    sample_count_test(v_sample_count,error,ok);
178    sample_count_test(t_sample_count,error,ok);
179   
180    class_count_test(t_class_count,error,ok);
181    class_count_test(t_class_count2,error,ok);
182    class_count_test(v_class_count,error,ok);
183
184
185    cv_test.next();
186  }
187  sample_count_test(test_sample_count,error,ok);
188  class_count_test(test_class_count,error,ok);
189 
190  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
191     test_value1[3]!=54) {
192    ok=false;
193    *error << "ERROR: incorrect sums of test values in row 1" 
194           << " found: " << test_value1[0] << ", "  << test_value1[1] 
195           << ", "  << test_value1[2] << " and "  << test_value1[3] 
196           << std::endl;
197  }
198
199 
200  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
201     test_value2[3]!=84) {
202    ok=false;
203    *error << "ERROR: incorrect sums of test values in row 2" 
204           << " found: " << test_value2[0] << ", "  << test_value2[1] 
205           << ", "  << test_value2[2] << " and "  << test_value2[3] 
206           << std::endl;
207  }
208
209  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108)  {
210    ok=false;
211    *error << "ERROR: incorrect sums of training values in row 1" 
212           << " found: " << t_value[0] << ", "  << t_value[1] 
213           << ", "  << t_value[2] << " and "  << t_value[3] 
214           << std::endl;   
215  }
216
217  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108)  {
218    ok=false;
219    *error << "ERROR: incorrect sums of validation values in row 1" 
220           << " found: " << v_value[0] << ", "  << v_value[1] 
221           << ", "  << v_value[2] << " and "  << v_value[3] 
222           << std::endl;   
223  }
224
225
226
227  if (error!=&std::cerr)
228    delete error;
229 
230  if (ok)
231    return 0;
232  return -1;
233}
234
235
236void class_count_test(const std::vector<size_t>& class_count, 
237                      std::ostream* error, bool& ok) 
238{
239  for (size_t i=0; i<class_count.size(); i++)
240    if (class_count[i]==0){
241      ok = false;
242      *error << "ERROR: class " << i << " was not in set." 
243             << " Expected at least one sample from each class." 
244             << std::endl;
245    }
246}
247
248void sample_count_test(const std::vector<size_t>& sample_count, 
249                       std::ostream* error, bool& ok) 
250{
251  for (size_t i=0; i<sample_count.size(); i++){
252    if (sample_count[i]!=1){
253      ok = false;
254      *error << "ERROR: sample " << i << " was in a group " << sample_count[i] 
255             << " times." << " Expected to be 1 time" << std::endl;
256    }
257  }
258}
Note: See TracBrowser for help on using the repository browser.