source: trunk/test/subset_generator_test.cc @ 865

Last change on this file since 865 was 865, checked in by Peter, 16 years ago

changing URL to http://trac.thep.lu.se/trac/yat

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 11.4 KB
Line 
1// $Id: subset_generator_test.cc 865 2007-09-10 19:41:04Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "yat/classifier/BootstrapSampler.h"
26#include "yat/classifier/CrossValidationSampler.h"
27#include "yat/classifier/FeatureSelectorIR.h"
28#include "yat/classifier/Kernel_SEV.h"
29#include "yat/classifier/KernelLookup.h"
30#include "yat/classifier/MatrixLookup.h"
31#include "yat/classifier/PolynomialKernelFunction.h"
32#include "yat/classifier/SubsetGenerator.h"
33#include "yat/classifier/SVM.h"
34#include "yat/classifier/NCC.h"
35#include "yat/statistics/AUC.h"
36#include "yat/statistics/PearsonDistance.h"
37#include "yat/utility/matrix.h"
38
39#include <cassert>
40#include <fstream>
41#include <iostream>
42#include <string>
43
44using namespace theplu::yat;
45
// Counting helpers: each writes one "ERROR: ..." line per offending
// entry to *error and returns false if any entry was rejected.
bool class_count_test(const std::vector<size_t>& class_count,
                      std::ostream* error);
bool sample_count_test(const std::vector<size_t>& sample_count,
                       std::ostream* error);

// Test cases: each writes its diagnostics to *error and returns its
// pass/fail status.
bool test_nested(std::ostream* error);
bool test_cv(std::ostream* error);
bool test_creation(std::ostream* error);
bool test_bootstrap(std::ostream* error);
52
53
54int main(const int argc,const char* argv[])
55{ 
56  std::ostream* error;
57  if (argc>1 && argv[1]==std::string("-v"))
58    error = &std::cerr;
59  else {
60    error = new std::ofstream("/dev/null");
61    if (argc>1)
62      std::cout << "subset_generator -v : for printing extra information\n";
63  }
64  *error << "testing subset_generator" << std::endl;
65  bool ok = true;
66
67  ok = ok && test_creation(error);
68  ok = ok && test_nested(error);
69  ok = ok && test_cv(error);
70
71  if (ok)
72    return 0;
73  return -1;
74}
75
76
77bool test_creation(std::ostream* error)
78{
79  bool ok=true;
80  std::ifstream is("data/nm_target_bin.txt");
81  *error << "loading target " << std::endl;
82  classifier::Target target(is);
83  is.close();
84  *error << "number of targets: " << target.size() << std::endl;
85  *error << "number of classes: " << target.nof_classes() << std::endl;
86  is.open("data/nm_data_centralized.txt");
87  *error << "loading data " << std::endl;
88  utility::matrix m(is);
89  is.close();
90  classifier::MatrixLookup data(m);
91  *error << "number of samples: " << data.columns() << std::endl;
92  *error << "number of features: " << data.rows() << std::endl;
93  assert(data.columns()==target.size());
94
95  *error << "building kernel" << std::endl;
96  classifier::PolynomialKernelFunction kf(1);
97  classifier::Kernel_SEV kernel_core(data,kf);
98  classifier::KernelLookup kernel(kernel_core);
99  *error << "building Sampler" << std::endl;
100  classifier::CrossValidationSampler sampler(target, 30, 3);
101
102  statistics::AUC score;
103  classifier::FeatureSelectorIR fs(score, 96, 0);
104  *error << "building SubsetGenerator" << std::endl;
105  classifier::SubsetGenerator subset_data(sampler, data, fs);
106  classifier::SubsetGenerator subset_kernel(sampler, kernel, fs);
107  return ok;
108}
109
110bool test_nested(std::ostream* error)
111{
112  bool ok=true;
113  //
114  // Test two nested CrossSplitters
115  //
116
117  *error << "\ntesting two nested crossplitters" << std::endl;
118  std::vector<std::string> label(9);
119  label[0]=label[1]=label[2]="0";
120  label[3]=label[4]=label[5]="1";
121  label[6]=label[7]=label[8]="2";
122                 
123  classifier::Target target(label);
124  utility::matrix raw_data2(2,9);
125  for(size_t i=0;i<raw_data2.rows();i++)
126    for(size_t j=0;j<raw_data2.columns();j++)
127      raw_data2(i,j)=i*10+10+j+1;
128   
129  classifier::MatrixLookup data2(raw_data2);
130  classifier::CrossValidationSampler cv2(target,3,3);
131  classifier::SubsetGenerator cv_test(cv2,data2);
132
133  std::vector<size_t> sample_count(10,0);
134  std::vector<size_t> test_sample_count(9,0);
135  std::vector<size_t> test_class_count(3,0);
136  std::vector<double> test_value1(4,0);
137  std::vector<double> test_value2(4,0);
138  std::vector<double> t_value(4,0);
139  std::vector<double> v_value(4,0); 
140  for(u_long k=0;k<cv_test.size();k++) {
141   
142    const classifier::DataLookup2D& tv_view=cv_test.training_data(k);
143    const classifier::Target& tv_target=cv_test.training_target(k);
144    const std::vector<size_t>& tv_index=cv_test.training_index(k);
145    const classifier::DataLookup2D& test_view=cv_test.validation_data(k);
146    const classifier::Target& test_target=cv_test.validation_target(k);
147    const std::vector<size_t>& test_index=cv_test.validation_index(k);
148
149    for (size_t i=0; i<test_index.size(); i++) {
150      assert(test_index[i]<sample_count.size());
151      test_sample_count[test_index[i]]++;
152      test_class_count[target(test_index[i])]++;
153      test_value1[0]+=test_view(0,i);
154      test_value2[0]+=test_view(1,i);
155      test_value1[test_target(i)+1]+=test_view(0,i);
156      test_value2[test_target(i)+1]+=test_view(1,i);
157      if(test_target(i)!=target(test_index[i])) {
158        ok=false;
159        *error << "ERROR: incorrect mapping of test indices" << std:: endl;
160      }       
161    }
162   
163    classifier::CrossValidationSampler sampler_training(tv_target,2,2);
164    classifier::SubsetGenerator cv_training(sampler_training,tv_view);
165    std::vector<size_t> v_sample_count(6,0);
166    std::vector<size_t> t_sample_count(6,0);
167    std::vector<size_t> v_class_count(3,0);
168    std::vector<size_t> t_class_count(3,0);
169    std::vector<size_t> t_class_count2(3,0);
170    for(u_long l=0;l<cv_training.size();l++) {
171      const classifier::DataLookup2D& t_view=cv_training.training_data(l);
172      const classifier::Target& t_target=cv_training.training_target(l);
173      const std::vector<size_t>& t_index=cv_training.training_index(l);
174      const classifier::DataLookup2D& v_view=cv_training.validation_data(l);
175      const classifier::Target& v_target=cv_training.validation_target(l);
176      const std::vector<size_t>& v_index=cv_training.validation_index(l);
177     
178      if (test_index.size()+tv_index.size()!=target.size() 
179          || t_index.size()+v_index.size() != tv_target.size() 
180          || test_index.size()+v_index.size()+t_index.size() !=  target.size()){
181        ok = false;
182        *error << "ERROR: size of training samples, validation samples " 
183               << "and test samples in is invalid." 
184               << std::endl;
185      }
186      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
187          v_index.size()!=3){
188        ok = false;
189        *error << "ERROR: size of training, validation, and test samples"
190               << " is invalid." 
191               << " Expected sizes to be 3" << std::endl;
192      }     
193
194      std::vector<size_t> tv_sample_count(6,0);
195      for (size_t i=0; i<t_index.size(); i++) {
196        assert(t_index[i]<t_sample_count.size());
197        tv_sample_count[t_index[i]]++;
198        t_sample_count[t_index[i]]++;
199        t_class_count[t_target(i)]++;
200        t_class_count2[tv_target(t_index[i])]++;
201        t_value[0]+=t_view(0,i);
202        t_value[t_target(i)+1]+=t_view(0,i);       
203      }
204      for (size_t i=0; i<v_index.size(); i++) {
205        assert(v_index[i]<v_sample_count.size());
206        tv_sample_count[v_index[i]]++;
207        v_sample_count[v_index[i]]++;
208        v_class_count[v_target(i)]++;
209        v_value[0]+=v_view(0,i);
210        v_value[v_target(i)+1]+=v_view(0,i);
211      }
212 
213      ok = ok && sample_count_test(tv_sample_count,error);     
214
215    }
216    ok = ok && sample_count_test(v_sample_count,error);
217    ok = ok && sample_count_test(t_sample_count,error);
218   
219    ok = ok && class_count_test(t_class_count,error);
220    ok = ok && class_count_test(t_class_count2,error);
221    ok = ok && class_count_test(v_class_count,error);
222
223
224  }
225  ok = ok && sample_count_test(test_sample_count,error);
226  ok = ok && class_count_test(test_class_count,error);
227 
228  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
229     test_value1[3]!=54) {
230    ok=false;
231    *error << "ERROR: incorrect sums of test values in row 1" 
232           << " found: " << test_value1[0] << ", "  << test_value1[1] 
233           << ", "  << test_value1[2] << " and "  << test_value1[3] 
234           << std::endl;
235  }
236
237 
238  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
239     test_value2[3]!=84) {
240    ok=false;
241    *error << "ERROR: incorrect sums of test values in row 2" 
242           << " found: " << test_value2[0] << ", "  << test_value2[1] 
243           << ", "  << test_value2[2] << " and "  << test_value2[3] 
244           << std::endl;
245  }
246
247  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108)  {
248    ok=false;
249    *error << "ERROR: incorrect sums of training values in row 1" 
250           << " found: " << t_value[0] << ", "  << t_value[1] 
251           << ", "  << t_value[2] << " and "  << t_value[3] 
252           << std::endl;   
253  }
254
255  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108)  {
256    ok=false;
257    *error << "ERROR: incorrect sums of validation values in row 1" 
258           << " found: " << v_value[0] << ", "  << v_value[1] 
259           << ", "  << v_value[2] << " and "  << v_value[3] 
260           << std::endl;   
261  }
262  return ok;
263}
264
// Verifies that every class occurs at least once.
//
// class_count[c] is the number of samples counted for class c. For each
// class with a zero count one error line is written to *error.
//
// Returns true if no count is zero (also for an empty vector).
bool class_count_test(const std::vector<size_t>& class_count,
                      std::ostream* error)
{
  bool all_present=true;
  const size_t n_classes=class_count.size();
  for (size_t c=0; c!=n_classes; ++c) {
    if (class_count[c]!=0)
      continue;
    all_present = false;
    *error << "ERROR: class " << c << " was not in set."
           << " Expected at least one sample from each class."
           << std::endl;
  }
  return all_present;
}
278
// Verifies that every sample was counted exactly once.
//
// sample_count[s] is the number of times sample s was seen. For each
// sample whose count differs from one, an error line is written to
// *error.
//
// Returns true if all counts equal one (also for an empty vector).
bool sample_count_test(const std::vector<size_t>& sample_count,
                       std::ostream* error)
{
  bool each_once=true;
  const size_t n_samples=sample_count.size();
  for (size_t s=0; s!=n_samples; ++s) {
    const size_t times=sample_count[s];
    if (times==1)
      continue;
    each_once = false;
    *error << "ERROR: sample " << s << " was in a group " << times
           << " times." << " Expected to be 1 time" << std::endl;
  }
  return each_once;
}
292
293
294bool test_bootstrap(std::ostream* error)
295{
296  bool ok=true;
297  std::vector<std::string> label(10,"default");
298  label[2]=label[7]="white";
299  label[4]=label[5]="black";
300  label[6]=label[3]="green";
301  label[8]=label[9]="red";
302                 
303  classifier::Target target(label);
304  utility::matrix raw_data(10,10);
305  classifier::MatrixLookup data(raw_data);
306  classifier::BootstrapSampler cv(target,3);
307  return ok;
308}
309
310
311bool test_cv(std::ostream* error)
312{
313  bool ok=true;
314  std::vector<std::string> label(10,"default");
315  label[2]=label[7]="white";
316  label[4]=label[5]="black";
317  label[6]=label[3]="green";
318  label[8]=label[9]="red";
319                 
320  classifier::Target target(label);
321  utility::matrix raw_data(10,10);
322  classifier::MatrixLookup data(raw_data);
323  classifier::CrossValidationSampler cv(target,3,3);
324 
325  std::vector<size_t> sample_count(10,0);
326  for (size_t j=0; j<cv.size(); ++j){
327    std::vector<size_t> class_count(5,0);
328    assert(j<cv.size());
329    if (cv.training_index(j).size()+cv.validation_index(j).size()!=
330        target.size()){
331      ok = false;
332      *error << "ERROR: size of training samples plus " 
333             << "size of validation samples is invalid." << std::endl;
334    }
335    if (cv.validation_index(j).size()!=3 && cv.validation_index(j).size()!=4){
336      ok = false;
337      *error << "ERROR: size of validation samples is invalid." 
338             << "expected size to be 3 or 4" << std::endl;
339    }
340    for (size_t i=0; i<cv.validation_index(j).size(); i++) {
341      assert(cv.validation_index(j)[i]<sample_count.size());
342      sample_count[cv.validation_index(j)[i]]++;
343    }
344    for (size_t i=0; i<cv.training_index(j).size(); i++) {
345      class_count[target(cv.training_index(j)[i])]++;
346    }
347    ok = ok && class_count_test(class_count,error);
348  }
349  ok = ok && sample_count_test(sample_count,error);
350 
351  return ok;
352}
Note: See TracBrowser for help on using the repository browser.