source: trunk/test/crossvalidation_test.cc @ 551

Last change on this file since 551 was 551, checked in by Markus Ringnér, 17 years ago

Added tests of two nested CrossSplitters? to crossvalidation_test

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 7.8 KB
Line 
1// $Id: crossvalidation_test.cc 551 2006-03-07 13:39:51Z markus $
2
3#include <c++_tools/classifier/CrossSplitter.h>
4#include <c++_tools/classifier/MatrixLookup.h>
5#include <c++_tools/classifier/Target.h>
6#include <c++_tools/gslapi/matrix.h>
7
8#include <cstdlib>
9#include <fstream>
10#include <iostream>
11#include <string>
12#include <vector>
13
14// forward declaration
15void class_count_test(const std::vector<size_t>&, std::ostream*, bool&);
16void sample_count_test(const std::vector<size_t>&, std::ostream*, bool&);
17
18
19int main(const int argc,const char* argv[])
20{ 
21  using namespace theplu;
22 
23  std::ostream* error;
24  if (argc>1 && argv[1]==std::string("-v"))
25    error = &std::cerr;
26  else {
27    error = new std::ofstream("/dev/null");
28    if (argc>1)
29      std::cout << "crossvalidation_test -v : for printing extra information\n";
30  }
31  *error << "testing crosssplitter" << std::endl;
32  bool ok = true;
33
34  std::vector<std::string> label(10,"default");
35  label[2]=label[7]="white";
36  label[4]=label[5]="black";
37  label[6]=label[3]="green";
38  label[8]=label[9]="red";
39                 
40  classifier::Target target(label);
41  gslapi::matrix raw_data(10,10);
42  classifier::MatrixLookup data(raw_data);
43  classifier::CrossSplitter cv(target,data,3,3);
44 
45  std::vector<size_t> sample_count(10,0);
46  for (cv.reset(); cv.more(); cv.next()){
47    std::vector<size_t> class_count(5,0);
48    if (cv.training_index().size()+cv.validation_index().size()!=target.size()){
49      ok = false;
50      *error << "ERROR: size of training samples plus " 
51             << "size of validation samples is invalid." << std::endl;
52    }
53    if (cv.validation_index().size()!=3 && cv.validation_index().size()!=4){
54      ok = false;
55      *error << "ERROR: size of validation samples is invalid." 
56             << "expected size to be 3 or 4" << std::endl;
57    }
58    for (size_t i=0; i<cv.validation_index().size(); i++) {
59      assert(cv.validation_index()[i]<sample_count.size());
60      sample_count[cv.validation_index()[i]]++;
61    }
62    for (size_t i=0; i<cv.training_index().size(); i++) {
63      class_count[target(cv.training_index()[i])]++;
64    }
65    class_count_test(class_count,error,ok);
66  }
67  sample_count_test(sample_count,error,ok);
68 
69  //
70  // Test two nested CrossSplitters
71  //
72
73  *error << "\ntesting two nested crossplitters" << std::endl;
74  label.resize(9);
75  label[0]=label[1]=label[2]="0";
76  label[3]=label[4]=label[5]="1";
77  label[6]=label[7]=label[8]="2";
78                 
79  target=classifier::Target(label);
80  gslapi::matrix raw_data2(2,9);
81  for(size_t i=0;i<raw_data2.rows();i++)
82    for(size_t j=0;j<raw_data2.columns();j++)
83      raw_data2(i,j)=i*10+10+j+1;
84   
85  classifier::MatrixLookup data2(raw_data2);
86  classifier::CrossSplitter cv_test(target,data2,3,3);
87
88  std::vector<size_t> test_sample_count(9,0);
89  std::vector<size_t> test_class_count(3,0);
90  std::vector<double> test_value1(4,0);
91  std::vector<double> test_value2(4,0);
92  std::vector<double> t_value(4,0);
93  std::vector<double> v_value(4,0); 
94  while(cv_test.more()) {
95   
96    const classifier::DataLookup2D& tv_view=cv_test.training_data();
97    const classifier::Target& tv_target=cv_test.training_target();
98    const std::vector<size_t>& tv_index=cv_test.training_index();
99    const classifier::DataLookup2D& test_view=cv_test.validation_data();
100    const classifier::Target& test_target=cv_test.validation_target();
101    const std::vector<size_t>& test_index=cv_test.validation_index();
102
103    for (size_t i=0; i<test_index.size(); i++) {
104      assert(test_index[i]<sample_count.size());
105      test_sample_count[test_index[i]]++;
106      test_class_count[target(test_index[i])]++;
107      test_value1[0]+=test_view(0,i);
108      test_value2[0]+=test_view(1,i);
109      test_value1[test_target(i)+1]+=test_view(0,i);
110      test_value2[test_target(i)+1]+=test_view(1,i);
111      if(test_target(i)!=target(test_index[i])) {
112        ok=false;
113        *error << "ERROR: incorrect mapping of test indices" << std:: endl;
114      }       
115    }
116   
117    classifier::CrossSplitter cv_training(tv_target,tv_view,2,2);
118    std::vector<size_t> v_sample_count(6,0);
119    std::vector<size_t> t_sample_count(6,0);
120    std::vector<size_t> v_class_count(3,0);
121    std::vector<size_t> t_class_count(3,0);
122    std::vector<size_t> t_class_count2(3,0);
123    while(cv_training.more()) {
124      const classifier::DataLookup2D& t_view=cv_training.training_data();
125      const classifier::Target& t_target=cv_training.training_target();
126      const std::vector<size_t>& t_index=cv_training.training_index();
127      const classifier::DataLookup2D& v_view=cv_training.validation_data();
128      const classifier::Target& v_target=cv_training.validation_target();
129      const std::vector<size_t>& v_index=cv_training.validation_index();
130     
131      if (test_index.size()+tv_index.size()!=target.size() 
132          || t_index.size()+v_index.size() != tv_target.size() 
133          || test_index.size()+v_index.size()+t_index.size() !=  target.size()){
134        ok = false;
135        *error << "ERROR: size of training samples, validation samples " 
136               << "and test samples in is invalid." 
137               << std::endl;
138      }
139      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
140          v_index.size()!=3){
141        ok = false;
142        *error << "ERROR: size of training, validation, and test samples"
143               << " is invalid." 
144               << " Expected sizes to be 3" << std::endl;
145      }     
146
147      for (size_t i=0; i<t_index.size(); i++) {
148        assert(t_index[i]<t_sample_count.size());
149        t_sample_count[t_index[i]]++;
150        t_class_count[t_target(i)]++;
151        t_class_count2[tv_target(t_index[i])]++;
152        t_value[0]+=t_view(0,i);
153        t_value[t_target(i)+1]+=t_view(0,i);
154      }
155      for (size_t i=0; i<v_index.size(); i++) {
156        assert(v_index[i]<v_sample_count.size());
157        v_sample_count[v_index[i]]++;
158        v_class_count[v_target(i)]++;
159        v_value[0]+=v_view(0,i);
160        v_value[v_target(i)+1]+=v_view(0,i);
161      }
162
163      cv_training.next();     
164    }
165    sample_count_test(v_sample_count,error,ok);
166    sample_count_test(t_sample_count,error,ok);
167   
168    class_count_test(t_class_count,error,ok);
169    class_count_test(t_class_count2,error,ok);
170    class_count_test(v_class_count,error,ok);
171
172
173    cv_test.next();
174  }
175  sample_count_test(test_sample_count,error,ok);
176  class_count_test(test_class_count,error,ok);
177 
178  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
179     test_value1[3]!=54) {
180    ok=false;
181    *error << "ERROR: incorrect sums of test values in row 1" 
182           << " found: " << test_value1[0] << ", "  << test_value1[1] 
183           << ", "  << test_value1[2] << " and "  << test_value1[3] 
184           << std::endl;
185  }
186
187 
188  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
189     test_value2[3]!=84) {
190    ok=false;
191    *error << "ERROR: incorrect sums of test values in row 2" 
192           << " found: " << test_value2[0] << ", "  << test_value2[1] 
193           << ", "  << test_value2[2] << " and "  << test_value2[3] 
194           << std::endl;
195  }
196
197  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108)  {
198    ok=false;
199    *error << "ERROR: incorrect sums of training values in row 1" 
200           << " found: " << t_value[0] << ", "  << t_value[1] 
201           << ", "  << t_value[2] << " and "  << t_value[3] 
202           << std::endl;   
203  }
204
205  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108)  {
206    ok=false;
207    *error << "ERROR: incorrect sums of validation values in row 1" 
208           << " found: " << v_value[0] << ", "  << v_value[1] 
209           << ", "  << v_value[2] << " and "  << v_value[3] 
210           << std::endl;   
211  }
212
213
214
215  if (error!=&std::cerr)
216    delete error;
217 
218  if (ok)
219    return 0;
220  return -1;
221}
222
223
224void class_count_test(const std::vector<size_t>& class_count, 
225                      std::ostream* error, bool& ok) 
226{
227  for (size_t i=0; i<class_count.size(); i++)
228    if (class_count[i]==0){
229      ok = false;
230      *error << "ERROR: class " << i << " was not in set." 
231             << " Expected at least one sample from each class." 
232             << std::endl;
233    }
234}
235
236void sample_count_test(const std::vector<size_t>& sample_count, 
237                       std::ostream* error, bool& ok) 
238{
239  for (size_t i=0; i<sample_count.size(); i++){
240    if (sample_count[i]!=1){
241      ok = false;
242      *error << "ERROR: sample " << i << " was validated " << sample_count[i] 
243             << " times." << " Expected to be 1 time" << std::endl;
244    }
245  }
246}
Note: See TracBrowser for help on using the repository browser.