source: trunk/test/crossvalidation_test.cc @ 781

Last change on this file since 781 was 781, checked in by Peter, 16 years ago

changing name to ROCScore and also added some cassert includes

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 9.0 KB
Line 
1// $Id: crossvalidation_test.cc 781 2007-03-05 19:44:03Z peter $
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "yat/classifier/CrossValidationSampler.h"
25#include "yat/classifier/SubsetGenerator.h"
26#include "yat/classifier/MatrixLookup.h"
27#include "yat/classifier/Target.h"
28#include "yat/utility/matrix.h"
29
30#include <cassert>
31#include <cstdlib>
32#include <fstream>
33#include <iostream>
34#include <string>
35#include <vector>
36
// Checkers defined after main(). Each inspects a vector of counts; on a
// failed check it writes an "ERROR: ..." line to *error and sets ok to false.
void class_count_test(const std::vector<size_t>& class_count,
                      std::ostream* error, bool& ok);
void sample_count_test(const std::vector<size_t>& sample_count,
                       std::ostream* error, bool& ok);
40
41
42int main(const int argc,const char* argv[])
43{ 
44  using namespace theplu::yat;
45 
46  std::ostream* error;
47  if (argc>1 && argv[1]==std::string("-v"))
48    error = &std::cerr;
49  else {
50    error = new std::ofstream("/dev/null");
51    if (argc>1)
52      std::cout << "crossvalidation_test -v : for printing extra information\n";
53  }
54  *error << "testing crosssplitter" << std::endl;
55  bool ok = true;
56
57  std::vector<std::string> label(10,"default");
58  label[2]=label[7]="white";
59  label[4]=label[5]="black";
60  label[6]=label[3]="green";
61  label[8]=label[9]="red";
62                 
63  classifier::Target target(label);
64  utility::matrix raw_data(10,10);
65  classifier::MatrixLookup data(raw_data);
66  classifier::CrossValidationSampler cv(target,3,3);
67 
68  std::vector<size_t> sample_count(10,0);
69  for (size_t j=0; j<cv.size(); ++j){
70    std::vector<size_t> class_count(5,0);
71    assert(j<cv.size());
72    if (cv.training_index(j).size()+cv.validation_index(j).size()!=
73        target.size()){
74      ok = false;
75      *error << "ERROR: size of training samples plus " 
76             << "size of validation samples is invalid." << std::endl;
77    }
78    if (cv.validation_index(j).size()!=3 && cv.validation_index(j).size()!=4){
79      ok = false;
80      *error << "ERROR: size of validation samples is invalid." 
81             << "expected size to be 3 or 4" << std::endl;
82    }
83    for (size_t i=0; i<cv.validation_index(j).size(); i++) {
84      assert(cv.validation_index(j)[i]<sample_count.size());
85      sample_count[cv.validation_index(j)[i]]++;
86    }
87    for (size_t i=0; i<cv.training_index(j).size(); i++) {
88      class_count[target(cv.training_index(j)[i])]++;
89    }
90    class_count_test(class_count,error,ok);
91  }
92  sample_count_test(sample_count,error,ok);
93 
94  //
95  // Test two nested CrossSplitters
96  //
97
98  *error << "\ntesting two nested crossplitters" << std::endl;
99  label.resize(9);
100  label[0]=label[1]=label[2]="0";
101  label[3]=label[4]=label[5]="1";
102  label[6]=label[7]=label[8]="2";
103                 
104  target=classifier::Target(label);
105  utility::matrix raw_data2(2,9);
106  for(size_t i=0;i<raw_data2.rows();i++)
107    for(size_t j=0;j<raw_data2.columns();j++)
108      raw_data2(i,j)=i*10+10+j+1;
109   
110  classifier::MatrixLookup data2(raw_data2);
111  classifier::CrossValidationSampler cv2(target,3,3);
112  classifier::SubsetGenerator cv_test(cv2,data2);
113
114  std::vector<size_t> test_sample_count(9,0);
115  std::vector<size_t> test_class_count(3,0);
116  std::vector<double> test_value1(4,0);
117  std::vector<double> test_value2(4,0);
118  std::vector<double> t_value(4,0);
119  std::vector<double> v_value(4,0); 
120  for(u_long k=0;k<cv_test.size();k++) {
121   
122    const classifier::DataLookup2D& tv_view=cv_test.training_data(k);
123    const classifier::Target& tv_target=cv_test.training_target(k);
124    const std::vector<size_t>& tv_index=cv_test.training_index(k);
125    const classifier::DataLookup2D& test_view=cv_test.validation_data(k);
126    const classifier::Target& test_target=cv_test.validation_target(k);
127    const std::vector<size_t>& test_index=cv_test.validation_index(k);
128
129    for (size_t i=0; i<test_index.size(); i++) {
130      assert(test_index[i]<sample_count.size());
131      test_sample_count[test_index[i]]++;
132      test_class_count[target(test_index[i])]++;
133      test_value1[0]+=test_view(0,i);
134      test_value2[0]+=test_view(1,i);
135      test_value1[test_target(i)+1]+=test_view(0,i);
136      test_value2[test_target(i)+1]+=test_view(1,i);
137      if(test_target(i)!=target(test_index[i])) {
138        ok=false;
139        *error << "ERROR: incorrect mapping of test indices" << std:: endl;
140      }       
141    }
142   
143    classifier::CrossValidationSampler sampler_training(tv_target,2,2);
144    classifier::SubsetGenerator cv_training(sampler_training,tv_view);
145    std::vector<size_t> v_sample_count(6,0);
146    std::vector<size_t> t_sample_count(6,0);
147    std::vector<size_t> v_class_count(3,0);
148    std::vector<size_t> t_class_count(3,0);
149    std::vector<size_t> t_class_count2(3,0);
150    for(u_long l=0;l<cv_training.size();l++) {
151      const classifier::DataLookup2D& t_view=cv_training.training_data(l);
152      const classifier::Target& t_target=cv_training.training_target(l);
153      const std::vector<size_t>& t_index=cv_training.training_index(l);
154      const classifier::DataLookup2D& v_view=cv_training.validation_data(l);
155      const classifier::Target& v_target=cv_training.validation_target(l);
156      const std::vector<size_t>& v_index=cv_training.validation_index(l);
157     
158      if (test_index.size()+tv_index.size()!=target.size() 
159          || t_index.size()+v_index.size() != tv_target.size() 
160          || test_index.size()+v_index.size()+t_index.size() !=  target.size()){
161        ok = false;
162        *error << "ERROR: size of training samples, validation samples " 
163               << "and test samples in is invalid." 
164               << std::endl;
165      }
166      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
167          v_index.size()!=3){
168        ok = false;
169        *error << "ERROR: size of training, validation, and test samples"
170               << " is invalid." 
171               << " Expected sizes to be 3" << std::endl;
172      }     
173
174      std::vector<size_t> tv_sample_count(6,0);
175      for (size_t i=0; i<t_index.size(); i++) {
176        assert(t_index[i]<t_sample_count.size());
177        tv_sample_count[t_index[i]]++;
178        t_sample_count[t_index[i]]++;
179        t_class_count[t_target(i)]++;
180        t_class_count2[tv_target(t_index[i])]++;
181        t_value[0]+=t_view(0,i);
182        t_value[t_target(i)+1]+=t_view(0,i);       
183      }
184      for (size_t i=0; i<v_index.size(); i++) {
185        assert(v_index[i]<v_sample_count.size());
186        tv_sample_count[v_index[i]]++;
187        v_sample_count[v_index[i]]++;
188        v_class_count[v_target(i)]++;
189        v_value[0]+=v_view(0,i);
190        v_value[v_target(i)+1]+=v_view(0,i);
191      }
192 
193      sample_count_test(tv_sample_count,error,ok);     
194
195    }
196    sample_count_test(v_sample_count,error,ok);
197    sample_count_test(t_sample_count,error,ok);
198   
199    class_count_test(t_class_count,error,ok);
200    class_count_test(t_class_count2,error,ok);
201    class_count_test(v_class_count,error,ok);
202
203
204  }
205  sample_count_test(test_sample_count,error,ok);
206  class_count_test(test_class_count,error,ok);
207 
208  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
209     test_value1[3]!=54) {
210    ok=false;
211    *error << "ERROR: incorrect sums of test values in row 1" 
212           << " found: " << test_value1[0] << ", "  << test_value1[1] 
213           << ", "  << test_value1[2] << " and "  << test_value1[3] 
214           << std::endl;
215  }
216
217 
218  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
219     test_value2[3]!=84) {
220    ok=false;
221    *error << "ERROR: incorrect sums of test values in row 2" 
222           << " found: " << test_value2[0] << ", "  << test_value2[1] 
223           << ", "  << test_value2[2] << " and "  << test_value2[3] 
224           << std::endl;
225  }
226
227  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108)  {
228    ok=false;
229    *error << "ERROR: incorrect sums of training values in row 1" 
230           << " found: " << t_value[0] << ", "  << t_value[1] 
231           << ", "  << t_value[2] << " and "  << t_value[3] 
232           << std::endl;   
233  }
234
235  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108)  {
236    ok=false;
237    *error << "ERROR: incorrect sums of validation values in row 1" 
238           << " found: " << v_value[0] << ", "  << v_value[1] 
239           << ", "  << v_value[2] << " and "  << v_value[3] 
240           << std::endl;   
241  }
242
243
244
245  if (error!=&std::cerr)
246    delete error;
247 
248  if (ok)
249    return 0;
250  return -1;
251}
252
253
// Checks that no class is missing from a set: class_count[c] is the number
// of samples of class c that were seen (as tallied by the callers in main).
// Every zero entry yields one "ERROR: ..." line on *error and sets ok to
// false. ok is never set back to true, so failures accumulate across calls.
void class_count_test(const std::vector<size_t>& class_count,
                      std::ostream* error, bool& ok)
{
  size_t cls = 0;
  for (std::vector<size_t>::const_iterator n = class_count.begin();
       n != class_count.end(); ++n, ++cls) {
    if (*n != 0)
      continue;  // class represented at least once
    ok = false;
    *error << "ERROR: class " << cls << " was not in set."
           << " Expected at least one sample from each class."
           << std::endl;
  }
}
265
// Checks that every sample was placed in a group exactly once:
// sample_count[s] is how many times sample s was seen. Each entry other
// than 1 yields one "ERROR: ..." line on *error and sets ok to false.
// ok is never set back to true, so failures accumulate across calls.
void sample_count_test(const std::vector<size_t>& sample_count,
                       std::ostream* error, bool& ok)
{
  const size_t n_samples = sample_count.size();
  size_t s = 0;
  while (s < n_samples) {
    const size_t times = sample_count[s];
    if (times != 1) {
      ok = false;
      *error << "ERROR: sample " << s << " was in a group " << times
             << " times." << " Expected to be 1 time" << std::endl;
    }
    ++s;
  }
}
Note: See TracBrowser for help on using the repository browser.