source: trunk/test/subset_generator_test.cc @ 931

Last change on this file since 931 was 931, checked in by Markus Ringnér, 16 years ago

Working on ticket:259. Removed old Distance see ticket:250

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 11.3 KB
Line 
1// $Id: subset_generator_test.cc 931 2007-10-05 15:42:25Z markus $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "yat/classifier/BootstrapSampler.h"
26#include "yat/classifier/CrossValidationSampler.h"
27#include "yat/classifier/FeatureSelectorIR.h"
28#include "yat/classifier/Kernel_SEV.h"
29#include "yat/classifier/KernelLookup.h"
30#include "yat/classifier/MatrixLookup.h"
31#include "yat/classifier/PolynomialKernelFunction.h"
32#include "yat/classifier/SubsetGenerator.h"
33#include "yat/classifier/SVM.h"
34#include "yat/classifier/NCC.h"
35#include "yat/statistics/AUC.h"
36#include "yat/utility/matrix.h"
37
38#include <cassert>
39#include <fstream>
40#include <iostream>
41#include <string>
42
43using namespace theplu::yat;
44
45bool class_count_test(const std::vector<size_t>&, std::ostream*);
46bool sample_count_test(const std::vector<size_t>&, std::ostream*);
47bool test_nested(std::ostream* error);
48bool test_cv(std::ostream*);
49bool test_creation(std::ostream* error);
50bool test_bootstrap(std::ostream* error);
51
52
53int main(const int argc,const char* argv[])
54{ 
55  std::ostream* error;
56  if (argc>1 && argv[1]==std::string("-v"))
57    error = &std::cerr;
58  else {
59    error = new std::ofstream("/dev/null");
60    if (argc>1)
61      std::cout << "subset_generator -v : for printing extra information\n";
62  }
63  *error << "testing subset_generator" << std::endl;
64  bool ok = true;
65
66  ok = ok && test_creation(error);
67  ok = ok && test_nested(error);
68  ok = ok && test_cv(error);
69
70  if (ok)
71    return 0;
72  return -1;
73}
74
75
76bool test_creation(std::ostream* error)
77{
78  bool ok=true;
79  std::ifstream is("data/nm_target_bin.txt");
80  *error << "loading target " << std::endl;
81  classifier::Target target(is);
82  is.close();
83  *error << "number of targets: " << target.size() << std::endl;
84  *error << "number of classes: " << target.nof_classes() << std::endl;
85  is.open("data/nm_data_centralized.txt");
86  *error << "loading data " << std::endl;
87  utility::matrix m(is);
88  is.close();
89  classifier::MatrixLookup data(m);
90  *error << "number of samples: " << data.columns() << std::endl;
91  *error << "number of features: " << data.rows() << std::endl;
92  assert(data.columns()==target.size());
93
94  *error << "building kernel" << std::endl;
95  classifier::PolynomialKernelFunction kf(1);
96  classifier::Kernel_SEV kernel_core(data,kf);
97  classifier::KernelLookup kernel(kernel_core);
98  *error << "building Sampler" << std::endl;
99  classifier::CrossValidationSampler sampler(target, 30, 3);
100
101  statistics::AUC score;
102  classifier::FeatureSelectorIR fs(score, 96, 0);
103  *error << "building SubsetGenerator" << std::endl;
104  classifier::SubsetGenerator subset_data(sampler, data, fs);
105  classifier::SubsetGenerator subset_kernel(sampler, kernel, fs);
106  return ok;
107}
108
109bool test_nested(std::ostream* error)
110{
111  bool ok=true;
112  //
113  // Test two nested CrossSplitters
114  //
115
116  *error << "\ntesting two nested crossplitters" << std::endl;
117  std::vector<std::string> label(9);
118  label[0]=label[1]=label[2]="0";
119  label[3]=label[4]=label[5]="1";
120  label[6]=label[7]=label[8]="2";
121                 
122  classifier::Target target(label);
123  utility::matrix raw_data2(2,9);
124  for(size_t i=0;i<raw_data2.rows();i++)
125    for(size_t j=0;j<raw_data2.columns();j++)
126      raw_data2(i,j)=i*10+10+j+1;
127   
128  classifier::MatrixLookup data2(raw_data2);
129  classifier::CrossValidationSampler cv2(target,3,3);
130  classifier::SubsetGenerator cv_test(cv2,data2);
131
132  std::vector<size_t> sample_count(10,0);
133  std::vector<size_t> test_sample_count(9,0);
134  std::vector<size_t> test_class_count(3,0);
135  std::vector<double> test_value1(4,0);
136  std::vector<double> test_value2(4,0);
137  std::vector<double> t_value(4,0);
138  std::vector<double> v_value(4,0); 
139  for(u_long k=0;k<cv_test.size();k++) {
140   
141    const classifier::DataLookup2D& tv_view=cv_test.training_data(k);
142    const classifier::Target& tv_target=cv_test.training_target(k);
143    const std::vector<size_t>& tv_index=cv_test.training_index(k);
144    const classifier::DataLookup2D& test_view=cv_test.validation_data(k);
145    const classifier::Target& test_target=cv_test.validation_target(k);
146    const std::vector<size_t>& test_index=cv_test.validation_index(k);
147
148    for (size_t i=0; i<test_index.size(); i++) {
149      assert(test_index[i]<sample_count.size());
150      test_sample_count[test_index[i]]++;
151      test_class_count[target(test_index[i])]++;
152      test_value1[0]+=test_view(0,i);
153      test_value2[0]+=test_view(1,i);
154      test_value1[test_target(i)+1]+=test_view(0,i);
155      test_value2[test_target(i)+1]+=test_view(1,i);
156      if(test_target(i)!=target(test_index[i])) {
157        ok=false;
158        *error << "ERROR: incorrect mapping of test indices" << std:: endl;
159      }       
160    }
161   
162    classifier::CrossValidationSampler sampler_training(tv_target,2,2);
163    classifier::SubsetGenerator cv_training(sampler_training,tv_view);
164    std::vector<size_t> v_sample_count(6,0);
165    std::vector<size_t> t_sample_count(6,0);
166    std::vector<size_t> v_class_count(3,0);
167    std::vector<size_t> t_class_count(3,0);
168    std::vector<size_t> t_class_count2(3,0);
169    for(u_long l=0;l<cv_training.size();l++) {
170      const classifier::DataLookup2D& t_view=cv_training.training_data(l);
171      const classifier::Target& t_target=cv_training.training_target(l);
172      const std::vector<size_t>& t_index=cv_training.training_index(l);
173      const classifier::DataLookup2D& v_view=cv_training.validation_data(l);
174      const classifier::Target& v_target=cv_training.validation_target(l);
175      const std::vector<size_t>& v_index=cv_training.validation_index(l);
176     
177      if (test_index.size()+tv_index.size()!=target.size() 
178          || t_index.size()+v_index.size() != tv_target.size() 
179          || test_index.size()+v_index.size()+t_index.size() !=  target.size()){
180        ok = false;
181        *error << "ERROR: size of training samples, validation samples " 
182               << "and test samples in is invalid." 
183               << std::endl;
184      }
185      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
186          v_index.size()!=3){
187        ok = false;
188        *error << "ERROR: size of training, validation, and test samples"
189               << " is invalid." 
190               << " Expected sizes to be 3" << std::endl;
191      }     
192
193      std::vector<size_t> tv_sample_count(6,0);
194      for (size_t i=0; i<t_index.size(); i++) {
195        assert(t_index[i]<t_sample_count.size());
196        tv_sample_count[t_index[i]]++;
197        t_sample_count[t_index[i]]++;
198        t_class_count[t_target(i)]++;
199        t_class_count2[tv_target(t_index[i])]++;
200        t_value[0]+=t_view(0,i);
201        t_value[t_target(i)+1]+=t_view(0,i);       
202      }
203      for (size_t i=0; i<v_index.size(); i++) {
204        assert(v_index[i]<v_sample_count.size());
205        tv_sample_count[v_index[i]]++;
206        v_sample_count[v_index[i]]++;
207        v_class_count[v_target(i)]++;
208        v_value[0]+=v_view(0,i);
209        v_value[v_target(i)+1]+=v_view(0,i);
210      }
211 
212      ok = ok && sample_count_test(tv_sample_count,error);     
213
214    }
215    ok = ok && sample_count_test(v_sample_count,error);
216    ok = ok && sample_count_test(t_sample_count,error);
217   
218    ok = ok && class_count_test(t_class_count,error);
219    ok = ok && class_count_test(t_class_count2,error);
220    ok = ok && class_count_test(v_class_count,error);
221
222
223  }
224  ok = ok && sample_count_test(test_sample_count,error);
225  ok = ok && class_count_test(test_class_count,error);
226 
227  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
228     test_value1[3]!=54) {
229    ok=false;
230    *error << "ERROR: incorrect sums of test values in row 1" 
231           << " found: " << test_value1[0] << ", "  << test_value1[1] 
232           << ", "  << test_value1[2] << " and "  << test_value1[3] 
233           << std::endl;
234  }
235
236 
237  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
238     test_value2[3]!=84) {
239    ok=false;
240    *error << "ERROR: incorrect sums of test values in row 2" 
241           << " found: " << test_value2[0] << ", "  << test_value2[1] 
242           << ", "  << test_value2[2] << " and "  << test_value2[3] 
243           << std::endl;
244  }
245
246  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108)  {
247    ok=false;
248    *error << "ERROR: incorrect sums of training values in row 1" 
249           << " found: " << t_value[0] << ", "  << t_value[1] 
250           << ", "  << t_value[2] << " and "  << t_value[3] 
251           << std::endl;   
252  }
253
254  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108)  {
255    ok=false;
256    *error << "ERROR: incorrect sums of validation values in row 1" 
257           << " found: " << v_value[0] << ", "  << v_value[1] 
258           << ", "  << v_value[2] << " and "  << v_value[3] 
259           << std::endl;   
260  }
261  return ok;
262}
263
// Verifies that every class is represented at least once.
//
// class_count: number of samples counted for each class index.
// error: stream that receives one message per class with a zero count.
// Returns true if no entry of class_count is zero.
bool class_count_test(const std::vector<size_t>& class_count,
                      std::ostream* error)
{
  bool every_class_seen = true;
  size_t cls = 0;
  while (cls < class_count.size()) {
    if (class_count[cls] == 0) {
      *error << "ERROR: class " << cls << " was not in set."
             << " Expected at least one sample from each class."
             << std::endl;
      every_class_seen = false;
    }
    ++cls;
  }
  return every_class_seen;
}
277
// Verifies that every sample was counted exactly once.
//
// sample_count: number of times each sample index was encountered.
// error: stream that receives one message per sample whose count is not 1.
// Returns true if all entries of sample_count equal 1.
bool sample_count_test(const std::vector<size_t>& sample_count,
                       std::ostream* error)
{
  bool each_exactly_once = true;
  size_t sample = 0;
  for (std::vector<size_t>::const_iterator count = sample_count.begin();
       count != sample_count.end(); ++count, ++sample) {
    if (*count == 1)
      continue;
    each_exactly_once = false;
    *error << "ERROR: sample " << sample << " was in a group " << *count
           << " times." << " Expected to be 1 time" << std::endl;
  }
  return each_exactly_once;
}
291
292
293bool test_bootstrap(std::ostream* error)
294{
295  bool ok=true;
296  std::vector<std::string> label(10,"default");
297  label[2]=label[7]="white";
298  label[4]=label[5]="black";
299  label[6]=label[3]="green";
300  label[8]=label[9]="red";
301                 
302  classifier::Target target(label);
303  utility::matrix raw_data(10,10);
304  classifier::MatrixLookup data(raw_data);
305  classifier::BootstrapSampler cv(target,3);
306  return ok;
307}
308
309
310bool test_cv(std::ostream* error)
311{
312  bool ok=true;
313  std::vector<std::string> label(10,"default");
314  label[2]=label[7]="white";
315  label[4]=label[5]="black";
316  label[6]=label[3]="green";
317  label[8]=label[9]="red";
318                 
319  classifier::Target target(label);
320  utility::matrix raw_data(10,10);
321  classifier::MatrixLookup data(raw_data);
322  classifier::CrossValidationSampler cv(target,3,3);
323 
324  std::vector<size_t> sample_count(10,0);
325  for (size_t j=0; j<cv.size(); ++j){
326    std::vector<size_t> class_count(5,0);
327    assert(j<cv.size());
328    if (cv.training_index(j).size()+cv.validation_index(j).size()!=
329        target.size()){
330      ok = false;
331      *error << "ERROR: size of training samples plus " 
332             << "size of validation samples is invalid." << std::endl;
333    }
334    if (cv.validation_index(j).size()!=3 && cv.validation_index(j).size()!=4){
335      ok = false;
336      *error << "ERROR: size of validation samples is invalid." 
337             << "expected size to be 3 or 4" << std::endl;
338    }
339    for (size_t i=0; i<cv.validation_index(j).size(); i++) {
340      assert(cv.validation_index(j)[i]<sample_count.size());
341      sample_count[cv.validation_index(j)[i]]++;
342    }
343    for (size_t i=0; i<cv.training_index(j).size(); i++) {
344      class_count[target(cv.training_index(j)[i])]++;
345    }
346    ok = ok && class_count_test(class_count,error);
347  }
348  ok = ok && sample_count_test(sample_count,error);
349 
350  return ok;
351}
Note: See TracBrowser for help on using the repository browser.