source: trunk/test/crossvalidation_test.cc @ 704

Last change on this file since 704 was 704, checked in by Markus Ringnér, 16 years ago

Fixes #104. Also fixed inline bug in Averager.h

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 9.0 KB
Line 
1// $Id: crossvalidation_test.cc 704 2006-12-18 16:01:41Z markus $
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
#include "yat/classifier/CrossValidationSampler.h"
#include "yat/classifier/SubsetGenerator.h"
#include "yat/classifier/MatrixLookup.h"
#include "yat/classifier/Target.h"
#include "yat/utility/matrix.h"

#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
35
// Checking helpers, defined at the end of this file. Each writes a message to
// *error for every failed check and then sets ok to false; neither ever sets
// ok back to true, so the outcome of successive calls accumulates.
void class_count_test(const std::vector<size_t>& class_count,
                      std::ostream* error, bool& ok);
void sample_count_test(const std::vector<size_t>& sample_count,
                       std::ostream* error, bool& ok);
39
40
41int main(const int argc,const char* argv[])
42{ 
43  using namespace theplu::yat;
44 
45  std::ostream* error;
46  if (argc>1 && argv[1]==std::string("-v"))
47    error = &std::cerr;
48  else {
49    error = new std::ofstream("/dev/null");
50    if (argc>1)
51      std::cout << "crossvalidation_test -v : for printing extra information\n";
52  }
53  *error << "testing crosssplitter" << std::endl;
54  bool ok = true;
55
56  std::vector<std::string> label(10,"default");
57  label[2]=label[7]="white";
58  label[4]=label[5]="black";
59  label[6]=label[3]="green";
60  label[8]=label[9]="red";
61                 
62  classifier::Target target(label);
63  utility::matrix raw_data(10,10);
64  classifier::MatrixLookup data(raw_data);
65  classifier::CrossValidationSampler cv(target,3,3);
66 
67  std::vector<size_t> sample_count(10,0);
68  for (size_t j=0; j<cv.size(); ++j){
69    std::vector<size_t> class_count(5,0);
70    assert(j<cv.size());
71    if (cv.training_index(j).size()+cv.validation_index(j).size()!=
72        target.size()){
73      ok = false;
74      *error << "ERROR: size of training samples plus " 
75             << "size of validation samples is invalid." << std::endl;
76    }
77    if (cv.validation_index(j).size()!=3 && cv.validation_index(j).size()!=4){
78      ok = false;
79      *error << "ERROR: size of validation samples is invalid." 
80             << "expected size to be 3 or 4" << std::endl;
81    }
82    for (size_t i=0; i<cv.validation_index(j).size(); i++) {
83      assert(cv.validation_index(j)[i]<sample_count.size());
84      sample_count[cv.validation_index(j)[i]]++;
85    }
86    for (size_t i=0; i<cv.training_index(j).size(); i++) {
87      class_count[target(cv.training_index(j)[i])]++;
88    }
89    class_count_test(class_count,error,ok);
90  }
91  sample_count_test(sample_count,error,ok);
92 
93  //
94  // Test two nested CrossSplitters
95  //
96
97  *error << "\ntesting two nested crossplitters" << std::endl;
98  label.resize(9);
99  label[0]=label[1]=label[2]="0";
100  label[3]=label[4]=label[5]="1";
101  label[6]=label[7]=label[8]="2";
102                 
103  target=classifier::Target(label);
104  utility::matrix raw_data2(2,9);
105  for(size_t i=0;i<raw_data2.rows();i++)
106    for(size_t j=0;j<raw_data2.columns();j++)
107      raw_data2(i,j)=i*10+10+j+1;
108   
109  classifier::MatrixLookup data2(raw_data2);
110  classifier::CrossValidationSampler cv2(target,3,3);
111  classifier::SubsetGenerator cv_test(cv2,data2);
112
113  std::vector<size_t> test_sample_count(9,0);
114  std::vector<size_t> test_class_count(3,0);
115  std::vector<double> test_value1(4,0);
116  std::vector<double> test_value2(4,0);
117  std::vector<double> t_value(4,0);
118  std::vector<double> v_value(4,0); 
119  for(u_long k=0;k<cv_test.size();k++) {
120   
121    const classifier::DataLookup2D& tv_view=cv_test.training_data(k);
122    const classifier::Target& tv_target=cv_test.training_target(k);
123    const std::vector<size_t>& tv_index=cv_test.training_index(k);
124    const classifier::DataLookup2D& test_view=cv_test.validation_data(k);
125    const classifier::Target& test_target=cv_test.validation_target(k);
126    const std::vector<size_t>& test_index=cv_test.validation_index(k);
127
128    for (size_t i=0; i<test_index.size(); i++) {
129      assert(test_index[i]<sample_count.size());
130      test_sample_count[test_index[i]]++;
131      test_class_count[target(test_index[i])]++;
132      test_value1[0]+=test_view(0,i);
133      test_value2[0]+=test_view(1,i);
134      test_value1[test_target(i)+1]+=test_view(0,i);
135      test_value2[test_target(i)+1]+=test_view(1,i);
136      if(test_target(i)!=target(test_index[i])) {
137        ok=false;
138        *error << "ERROR: incorrect mapping of test indices" << std:: endl;
139      }       
140    }
141   
142    classifier::CrossValidationSampler sampler_training(tv_target,2,2);
143    classifier::SubsetGenerator cv_training(sampler_training,tv_view);
144    std::vector<size_t> v_sample_count(6,0);
145    std::vector<size_t> t_sample_count(6,0);
146    std::vector<size_t> v_class_count(3,0);
147    std::vector<size_t> t_class_count(3,0);
148    std::vector<size_t> t_class_count2(3,0);
149    for(u_long l=0;l<cv_training.size();l++) {
150      const classifier::DataLookup2D& t_view=cv_training.training_data(l);
151      const classifier::Target& t_target=cv_training.training_target(l);
152      const std::vector<size_t>& t_index=cv_training.training_index(l);
153      const classifier::DataLookup2D& v_view=cv_training.validation_data(l);
154      const classifier::Target& v_target=cv_training.validation_target(l);
155      const std::vector<size_t>& v_index=cv_training.validation_index(l);
156     
157      if (test_index.size()+tv_index.size()!=target.size() 
158          || t_index.size()+v_index.size() != tv_target.size() 
159          || test_index.size()+v_index.size()+t_index.size() !=  target.size()){
160        ok = false;
161        *error << "ERROR: size of training samples, validation samples " 
162               << "and test samples in is invalid." 
163               << std::endl;
164      }
165      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
166          v_index.size()!=3){
167        ok = false;
168        *error << "ERROR: size of training, validation, and test samples"
169               << " is invalid." 
170               << " Expected sizes to be 3" << std::endl;
171      }     
172
173      std::vector<size_t> tv_sample_count(6,0);
174      for (size_t i=0; i<t_index.size(); i++) {
175        assert(t_index[i]<t_sample_count.size());
176        tv_sample_count[t_index[i]]++;
177        t_sample_count[t_index[i]]++;
178        t_class_count[t_target(i)]++;
179        t_class_count2[tv_target(t_index[i])]++;
180        t_value[0]+=t_view(0,i);
181        t_value[t_target(i)+1]+=t_view(0,i);       
182      }
183      for (size_t i=0; i<v_index.size(); i++) {
184        assert(v_index[i]<v_sample_count.size());
185        tv_sample_count[v_index[i]]++;
186        v_sample_count[v_index[i]]++;
187        v_class_count[v_target(i)]++;
188        v_value[0]+=v_view(0,i);
189        v_value[v_target(i)+1]+=v_view(0,i);
190      }
191 
192      sample_count_test(tv_sample_count,error,ok);     
193
194    }
195    sample_count_test(v_sample_count,error,ok);
196    sample_count_test(t_sample_count,error,ok);
197   
198    class_count_test(t_class_count,error,ok);
199    class_count_test(t_class_count2,error,ok);
200    class_count_test(v_class_count,error,ok);
201
202
203  }
204  sample_count_test(test_sample_count,error,ok);
205  class_count_test(test_class_count,error,ok);
206 
207  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
208     test_value1[3]!=54) {
209    ok=false;
210    *error << "ERROR: incorrect sums of test values in row 1" 
211           << " found: " << test_value1[0] << ", "  << test_value1[1] 
212           << ", "  << test_value1[2] << " and "  << test_value1[3] 
213           << std::endl;
214  }
215
216 
217  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
218     test_value2[3]!=84) {
219    ok=false;
220    *error << "ERROR: incorrect sums of test values in row 2" 
221           << " found: " << test_value2[0] << ", "  << test_value2[1] 
222           << ", "  << test_value2[2] << " and "  << test_value2[3] 
223           << std::endl;
224  }
225
226  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108)  {
227    ok=false;
228    *error << "ERROR: incorrect sums of training values in row 1" 
229           << " found: " << t_value[0] << ", "  << t_value[1] 
230           << ", "  << t_value[2] << " and "  << t_value[3] 
231           << std::endl;   
232  }
233
234  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108)  {
235    ok=false;
236    *error << "ERROR: incorrect sums of validation values in row 1" 
237           << " found: " << v_value[0] << ", "  << v_value[1] 
238           << ", "  << v_value[2] << " and "  << v_value[3] 
239           << std::endl;   
240  }
241
242
243
244  if (error!=&std::cerr)
245    delete error;
246 
247  if (ok)
248    return 0;
249  return -1;
250}
251
252
// Verifies that every class is represented: class_count[c] is the number of
// samples of class c seen in some set. Each zero entry is reported on *error
// and sets ok to false. ok is never set to true here, so failures from
// earlier calls are kept.
void class_count_test(const std::vector<size_t>& class_count,
                      std::ostream* error, bool& ok)
{
  size_t cls = 0;
  for (std::vector<size_t>::const_iterator it = class_count.begin();
       it != class_count.end(); ++it, ++cls) {
    if (*it != 0)
      continue;
    ok = false;
    *error << "ERROR: class " << cls << " was not in set."
           << " Expected at least one sample from each class."
           << std::endl;
  }
}
264
// Verifies that every sample was placed in a group exactly once:
// sample_count[s] is the number of times sample s was seen. Each entry that
// differs from 1 is reported on *error with its observed count, and sets ok
// to false. ok is never set to true here.
void sample_count_test(const std::vector<size_t>& sample_count,
                       std::ostream* error, bool& ok)
{
  const size_t n_samples = sample_count.size();
  for (size_t idx = 0; idx != n_samples; ++idx) {
    const size_t times = sample_count[idx];
    if (times == 1)
      continue;
    ok = false;
    *error << "ERROR: sample " << idx << " was in a group " << times
           << " times." << " Expected to be 1 time" << std::endl;
  }
}
Note: See TracBrowser for help on using the repository browser.