source: trunk/test/crossvalidation_test.cc @ 554

Last change on this file since 554 was 554, checked in by Markus Ringnér, 16 years ago

added small test

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 8.0 KB
Line 
1// $Id: crossvalidation_test.cc 554 2006-03-07 14:22:29Z markus $
2
3#include <c++_tools/classifier/CrossSplitter.h>
4#include <c++_tools/classifier/MatrixLookup.h>
5#include <c++_tools/classifier/Target.h>
6#include <c++_tools/gslapi/matrix.h>
7
8#include <cstdlib>
9#include <fstream>
10#include <iostream>
11#include <string>
12#include <vector>
13
14// forward declaration
15void class_count_test(const std::vector<size_t>&, std::ostream*, bool&);
16void sample_count_test(const std::vector<size_t>&, std::ostream*, bool&);
17
18
19int main(const int argc,const char* argv[])
20{ 
21  using namespace theplu;
22 
23  std::ostream* error;
24  if (argc>1 && argv[1]==std::string("-v"))
25    error = &std::cerr;
26  else {
27    error = new std::ofstream("/dev/null");
28    if (argc>1)
29      std::cout << "crossvalidation_test -v : for printing extra information\n";
30  }
31  *error << "testing crosssplitter" << std::endl;
32  bool ok = true;
33
34  std::vector<std::string> label(10,"default");
35  label[2]=label[7]="white";
36  label[4]=label[5]="black";
37  label[6]=label[3]="green";
38  label[8]=label[9]="red";
39                 
40  classifier::Target target(label);
41  gslapi::matrix raw_data(10,10);
42  classifier::MatrixLookup data(raw_data);
43  classifier::CrossSplitter cv(target,data,3,3);
44 
45  std::vector<size_t> sample_count(10,0);
46  for (cv.reset(); cv.more(); cv.next()){
47    std::vector<size_t> class_count(5,0);
48    if (cv.training_index().size()+cv.validation_index().size()!=target.size()){
49      ok = false;
50      *error << "ERROR: size of training samples plus " 
51             << "size of validation samples is invalid." << std::endl;
52    }
53    if (cv.validation_index().size()!=3 && cv.validation_index().size()!=4){
54      ok = false;
55      *error << "ERROR: size of validation samples is invalid." 
56             << "expected size to be 3 or 4" << std::endl;
57    }
58    for (size_t i=0; i<cv.validation_index().size(); i++) {
59      assert(cv.validation_index()[i]<sample_count.size());
60      sample_count[cv.validation_index()[i]]++;
61    }
62    for (size_t i=0; i<cv.training_index().size(); i++) {
63      class_count[target(cv.training_index()[i])]++;
64    }
65    class_count_test(class_count,error,ok);
66  }
67  sample_count_test(sample_count,error,ok);
68 
69  //
70  // Test two nested CrossSplitters
71  //
72
73  *error << "\ntesting two nested crossplitters" << std::endl;
74  label.resize(9);
75  label[0]=label[1]=label[2]="0";
76  label[3]=label[4]=label[5]="1";
77  label[6]=label[7]=label[8]="2";
78                 
79  target=classifier::Target(label);
80  gslapi::matrix raw_data2(2,9);
81  for(size_t i=0;i<raw_data2.rows();i++)
82    for(size_t j=0;j<raw_data2.columns();j++)
83      raw_data2(i,j)=i*10+10+j+1;
84   
85  classifier::MatrixLookup data2(raw_data2);
86  classifier::CrossSplitter cv_test(target,data2,3,3);
87
88  std::vector<size_t> test_sample_count(9,0);
89  std::vector<size_t> test_class_count(3,0);
90  std::vector<double> test_value1(4,0);
91  std::vector<double> test_value2(4,0);
92  std::vector<double> t_value(4,0);
93  std::vector<double> v_value(4,0); 
94  while(cv_test.more()) {
95   
96    const classifier::DataLookup2D& tv_view=cv_test.training_data();
97    const classifier::Target& tv_target=cv_test.training_target();
98    const std::vector<size_t>& tv_index=cv_test.training_index();
99    const classifier::DataLookup2D& test_view=cv_test.validation_data();
100    const classifier::Target& test_target=cv_test.validation_target();
101    const std::vector<size_t>& test_index=cv_test.validation_index();
102
103    for (size_t i=0; i<test_index.size(); i++) {
104      assert(test_index[i]<sample_count.size());
105      test_sample_count[test_index[i]]++;
106      test_class_count[target(test_index[i])]++;
107      test_value1[0]+=test_view(0,i);
108      test_value2[0]+=test_view(1,i);
109      test_value1[test_target(i)+1]+=test_view(0,i);
110      test_value2[test_target(i)+1]+=test_view(1,i);
111      if(test_target(i)!=target(test_index[i])) {
112        ok=false;
113        *error << "ERROR: incorrect mapping of test indices" << std:: endl;
114      }       
115    }
116   
117    classifier::CrossSplitter cv_training(tv_target,tv_view,2,2);
118    std::vector<size_t> v_sample_count(6,0);
119    std::vector<size_t> t_sample_count(6,0);
120    std::vector<size_t> v_class_count(3,0);
121    std::vector<size_t> t_class_count(3,0);
122    std::vector<size_t> t_class_count2(3,0);
123    while(cv_training.more()) {
124      const classifier::DataLookup2D& t_view=cv_training.training_data();
125      const classifier::Target& t_target=cv_training.training_target();
126      const std::vector<size_t>& t_index=cv_training.training_index();
127      const classifier::DataLookup2D& v_view=cv_training.validation_data();
128      const classifier::Target& v_target=cv_training.validation_target();
129      const std::vector<size_t>& v_index=cv_training.validation_index();
130     
131      if (test_index.size()+tv_index.size()!=target.size() 
132          || t_index.size()+v_index.size() != tv_target.size() 
133          || test_index.size()+v_index.size()+t_index.size() !=  target.size()){
134        ok = false;
135        *error << "ERROR: size of training samples, validation samples " 
136               << "and test samples in is invalid." 
137               << std::endl;
138      }
139      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
140          v_index.size()!=3){
141        ok = false;
142        *error << "ERROR: size of training, validation, and test samples"
143               << " is invalid." 
144               << " Expected sizes to be 3" << std::endl;
145      }     
146
147      std::vector<size_t> tv_sample_count(6,0);
148      for (size_t i=0; i<t_index.size(); i++) {
149        assert(t_index[i]<t_sample_count.size());
150        tv_sample_count[t_index[i]]++;
151        t_sample_count[t_index[i]]++;
152        t_class_count[t_target(i)]++;
153        t_class_count2[tv_target(t_index[i])]++;
154        t_value[0]+=t_view(0,i);
155        t_value[t_target(i)+1]+=t_view(0,i);       
156      }
157      for (size_t i=0; i<v_index.size(); i++) {
158        assert(v_index[i]<v_sample_count.size());
159        tv_sample_count[v_index[i]]++;
160        v_sample_count[v_index[i]]++;
161        v_class_count[v_target(i)]++;
162        v_value[0]+=v_view(0,i);
163        v_value[v_target(i)+1]+=v_view(0,i);
164      }
165 
166      sample_count_test(tv_sample_count,error,ok);     
167
168      cv_training.next();     
169    }
170    sample_count_test(v_sample_count,error,ok);
171    sample_count_test(t_sample_count,error,ok);
172   
173    class_count_test(t_class_count,error,ok);
174    class_count_test(t_class_count2,error,ok);
175    class_count_test(v_class_count,error,ok);
176
177
178    cv_test.next();
179  }
180  sample_count_test(test_sample_count,error,ok);
181  class_count_test(test_class_count,error,ok);
182 
183  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
184     test_value1[3]!=54) {
185    ok=false;
186    *error << "ERROR: incorrect sums of test values in row 1" 
187           << " found: " << test_value1[0] << ", "  << test_value1[1] 
188           << ", "  << test_value1[2] << " and "  << test_value1[3] 
189           << std::endl;
190  }
191
192 
193  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
194     test_value2[3]!=84) {
195    ok=false;
196    *error << "ERROR: incorrect sums of test values in row 2" 
197           << " found: " << test_value2[0] << ", "  << test_value2[1] 
198           << ", "  << test_value2[2] << " and "  << test_value2[3] 
199           << std::endl;
200  }
201
202  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108)  {
203    ok=false;
204    *error << "ERROR: incorrect sums of training values in row 1" 
205           << " found: " << t_value[0] << ", "  << t_value[1] 
206           << ", "  << t_value[2] << " and "  << t_value[3] 
207           << std::endl;   
208  }
209
210  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108)  {
211    ok=false;
212    *error << "ERROR: incorrect sums of validation values in row 1" 
213           << " found: " << v_value[0] << ", "  << v_value[1] 
214           << ", "  << v_value[2] << " and "  << v_value[3] 
215           << std::endl;   
216  }
217
218
219
220  if (error!=&std::cerr)
221    delete error;
222 
223  if (ok)
224    return 0;
225  return -1;
226}
227
228
229void class_count_test(const std::vector<size_t>& class_count, 
230                      std::ostream* error, bool& ok) 
231{
232  for (size_t i=0; i<class_count.size(); i++)
233    if (class_count[i]==0){
234      ok = false;
235      *error << "ERROR: class " << i << " was not in set." 
236             << " Expected at least one sample from each class." 
237             << std::endl;
238    }
239}
240
241void sample_count_test(const std::vector<size_t>& sample_count, 
242                       std::ostream* error, bool& ok) 
243{
244  for (size_t i=0; i<sample_count.size(); i++){
245    if (sample_count[i]!=1){
246      ok = false;
247      *error << "ERROR: sample " << i << " was in a group " << sample_count[i] 
248             << " times." << " Expected to be 1 time" << std::endl;
249    }
250  }
251}
Note: See TracBrowser for help on using the repository browser.