source: trunk/test/subset_generator_test.cc @ 1134

Last change on this file since 1134 was 1134, checked in by Peter, 15 years ago

using Index class instead of std::vector<size_t>

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 11.4 KB
Line 
1// $Id: subset_generator_test.cc 1134 2008-02-23 22:52:43Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "yat/classifier/BootstrapSampler.h"
26#include "yat/classifier/CrossValidationSampler.h"
27#include "yat/classifier/FeatureSelectorIR.h"
28#include "yat/classifier/Kernel_SEV.h"
29#include "yat/classifier/KernelLookup.h"
30#include "yat/classifier/MatrixLookup.h"
31#include "yat/classifier/PolynomialKernelFunction.h"
32#include "yat/classifier/SubsetGenerator.h"
33#include "yat/classifier/SVM.h"
34#include "yat/classifier/NCC.h"
35#include "yat/statistics/AUC.h"
36#include "yat/utility/Matrix.h"
37
38#include <cassert>
39#include <fstream>
40#include <iostream>
41#include <string>
42
43using namespace theplu::yat;
44
// Helpers that validate per-class and per-sample counters. Each returns
// true when the counters are acceptable and reports problems to the stream.
bool class_count_test(const std::vector<size_t>& class_count,
                      std::ostream* error);
bool sample_count_test(const std::vector<size_t>& sample_count,
                       std::ostream* error);

// Individual test cases. Each takes the stream used for diagnostics and
// returns true on success.
bool test_nested(std::ostream* error);
bool test_cv(std::ostream* error);
bool test_creation(std::ostream* error);
bool test_bootstrap(std::ostream* error);
51
52
53int main(const int argc,const char* argv[])
54{ 
55  std::ostream* error;
56  if (argc>1 && argv[1]==std::string("-v"))
57    error = &std::cerr;
58  else {
59    error = new std::ofstream("/dev/null");
60    if (argc>1)
61      std::cout << "subset_generator -v : for printing extra information\n";
62  }
63  *error << "testing subset_generator" << std::endl;
64  bool ok = true;
65
66  ok = ok && test_creation(error);
67  ok = ok && test_nested(error);
68  ok = ok && test_cv(error);
69
70  if (ok)
71    return 0;
72  return -1;
73}
74
75
76bool test_creation(std::ostream* error)
77{
78  bool ok=true;
79  std::ifstream is("data/nm_target_bin.txt");
80  *error << "loading target " << std::endl;
81  classifier::Target target(is);
82  is.close();
83  *error << "number of targets: " << target.size() << std::endl;
84  *error << "number of classes: " << target.nof_classes() << std::endl;
85  is.open("data/nm_data_centralized.txt");
86  *error << "loading data " << std::endl;
87  utility::Matrix m(is);
88  is.close();
89  classifier::MatrixLookup data(m);
90  *error << "number of samples: " << data.columns() << std::endl;
91  *error << "number of features: " << data.rows() << std::endl;
92  assert(data.columns()==target.size());
93
94  *error << "building kernel" << std::endl;
95  classifier::PolynomialKernelFunction kf(1);
96  classifier::Kernel_SEV kernel_core(data,kf);
97  classifier::KernelLookup kernel(kernel_core);
98  *error << "building Sampler" << std::endl;
99  classifier::CrossValidationSampler sampler(target, 30, 3);
100
101  statistics::AUC score;
102  classifier::FeatureSelectorIR fs(score, 96, 0);
103  *error << "building SubsetGenerator" << std::endl;
104  classifier::SubsetGenerator<classifier::MatrixLookup> 
105    subset_data(sampler, data, fs);
106  classifier::SubsetGenerator<classifier::KernelLookup> 
107    subset_kernel(sampler, kernel,fs);
108  return ok;
109}
110
111bool test_nested(std::ostream* error)
112{
113  bool ok=true;
114  //
115  // Test two nested CrossSplitters
116  //
117
118  *error << "\ntesting two nested crossplitters" << std::endl;
119  std::vector<std::string> label(9);
120  label[0]=label[1]=label[2]="0";
121  label[3]=label[4]=label[5]="1";
122  label[6]=label[7]=label[8]="2";
123                 
124  classifier::Target target(label);
125  utility::Matrix raw_data2(2,9);
126  for(size_t i=0;i<raw_data2.rows();i++)
127    for(size_t j=0;j<raw_data2.columns();j++)
128      raw_data2(i,j)=i*10+10+j+1;
129   
130  classifier::MatrixLookup data2(raw_data2);
131  classifier::CrossValidationSampler cv2(target,3,3);
132  classifier::SubsetGenerator<classifier::DataLookup2D> cv_test(cv2,data2);
133
134  std::vector<size_t> sample_count(10,0);
135  std::vector<size_t> test_sample_count(9,0);
136  std::vector<size_t> test_class_count(3,0);
137  std::vector<double> test_value1(4,0);
138  std::vector<double> test_value2(4,0);
139  std::vector<double> t_value(4,0);
140  std::vector<double> v_value(4,0); 
141  for(u_long k=0;k<cv_test.size();k++) {
142   
143    const classifier::DataLookup2D& tv_view=cv_test.training_data(k);
144    const classifier::Target& tv_target=cv_test.training_target(k);
145    const utility::Index& tv_index=cv_test.training_index(k);
146    const classifier::DataLookup2D& test_view=cv_test.validation_data(k);
147    const classifier::Target& test_target=cv_test.validation_target(k);
148    const utility::Index& test_index=cv_test.validation_index(k);
149
150    for (size_t i=0; i<test_index.size(); i++) {
151      assert(test_index[i]<sample_count.size());
152      test_sample_count[test_index[i]]++;
153      test_class_count[target(test_index[i])]++;
154      test_value1[0]+=test_view(0,i);
155      test_value2[0]+=test_view(1,i);
156      test_value1[test_target(i)+1]+=test_view(0,i);
157      test_value2[test_target(i)+1]+=test_view(1,i);
158      if(test_target(i)!=target(test_index[i])) {
159        ok=false;
160        *error << "ERROR: incorrect mapping of test indices" << std:: endl;
161      }       
162    }
163   
164    classifier::CrossValidationSampler sampler_training(tv_target,2,2);
165    classifier::SubsetGenerator<classifier::DataLookup2D> 
166      cv_training(sampler_training,tv_view);
167    std::vector<size_t> v_sample_count(6,0);
168    std::vector<size_t> t_sample_count(6,0);
169    std::vector<size_t> v_class_count(3,0);
170    std::vector<size_t> t_class_count(3,0);
171    std::vector<size_t> t_class_count2(3,0);
172    for(u_long l=0;l<cv_training.size();l++) {
173      const classifier::DataLookup2D& t_view=cv_training.training_data(l);
174      const classifier::Target& t_target=cv_training.training_target(l);
175      const utility::Index& t_index=cv_training.training_index(l);
176      const classifier::DataLookup2D& v_view=cv_training.validation_data(l);
177      const classifier::Target& v_target=cv_training.validation_target(l);
178      const utility::Index& v_index=cv_training.validation_index(l);
179     
180      if (test_index.size()+tv_index.size()!=target.size() 
181          || t_index.size()+v_index.size() != tv_target.size() 
182          || test_index.size()+v_index.size()+t_index.size() !=  target.size()){
183        ok = false;
184        *error << "ERROR: size of training samples, validation samples " 
185               << "and test samples in is invalid." 
186               << std::endl;
187      }
188      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
189          v_index.size()!=3){
190        ok = false;
191        *error << "ERROR: size of training, validation, and test samples"
192               << " is invalid." 
193               << " Expected sizes to be 3" << std::endl;
194      }     
195
196      std::vector<size_t> tv_sample_count(6,0);
197      for (size_t i=0; i<t_index.size(); i++) {
198        assert(t_index[i]<t_sample_count.size());
199        tv_sample_count[t_index[i]]++;
200        t_sample_count[t_index[i]]++;
201        t_class_count[t_target(i)]++;
202        t_class_count2[tv_target(t_index[i])]++;
203        t_value[0]+=t_view(0,i);
204        t_value[t_target(i)+1]+=t_view(0,i);       
205      }
206      for (size_t i=0; i<v_index.size(); i++) {
207        assert(v_index[i]<v_sample_count.size());
208        tv_sample_count[v_index[i]]++;
209        v_sample_count[v_index[i]]++;
210        v_class_count[v_target(i)]++;
211        v_value[0]+=v_view(0,i);
212        v_value[v_target(i)+1]+=v_view(0,i);
213      }
214 
215      ok = ok && sample_count_test(tv_sample_count,error);     
216
217    }
218    ok = ok && sample_count_test(v_sample_count,error);
219    ok = ok && sample_count_test(t_sample_count,error);
220   
221    ok = ok && class_count_test(t_class_count,error);
222    ok = ok && class_count_test(t_class_count2,error);
223    ok = ok && class_count_test(v_class_count,error);
224
225
226  }
227  ok = ok && sample_count_test(test_sample_count,error);
228  ok = ok && class_count_test(test_class_count,error);
229 
230  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
231     test_value1[3]!=54) {
232    ok=false;
233    *error << "ERROR: incorrect sums of test values in row 1" 
234           << " found: " << test_value1[0] << ", "  << test_value1[1] 
235           << ", "  << test_value1[2] << " and "  << test_value1[3] 
236           << std::endl;
237  }
238
239 
240  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
241     test_value2[3]!=84) {
242    ok=false;
243    *error << "ERROR: incorrect sums of test values in row 2" 
244           << " found: " << test_value2[0] << ", "  << test_value2[1] 
245           << ", "  << test_value2[2] << " and "  << test_value2[3] 
246           << std::endl;
247  }
248
249  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108)  {
250    ok=false;
251    *error << "ERROR: incorrect sums of training values in row 1" 
252           << " found: " << t_value[0] << ", "  << t_value[1] 
253           << ", "  << t_value[2] << " and "  << t_value[3] 
254           << std::endl;   
255  }
256
257  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108)  {
258    ok=false;
259    *error << "ERROR: incorrect sums of validation values in row 1" 
260           << " found: " << v_value[0] << ", "  << v_value[1] 
261           << ", "  << v_value[2] << " and "  << v_value[3] 
262           << std::endl;   
263  }
264  return ok;
265}
266
/// Verifies that no class is missing from a set.
///
/// \param class_count number of samples counted for each class, indexed by
///        class
/// \param error stream receiving one message per class whose count is zero
/// \return true if every entry of \a class_count is non-zero (also true for
///         an empty vector), false otherwise
bool class_count_test(const std::vector<size_t>& class_count,
                      std::ostream* error)
{
  bool all_present = true;
  size_t class_id = 0;
  for (std::vector<size_t>::const_iterator count = class_count.begin();
       count != class_count.end(); ++count, ++class_id) {
    if (*count != 0)
      continue;
    all_present = false;
    *error << "ERROR: class " << class_id << " was not in set."
           << " Expected at least one sample from each class."
           << std::endl;
  }
  return all_present;
}
280
/// Verifies that every sample was counted exactly once.
///
/// \param sample_count number of times each sample was seen, indexed by
///        sample
/// \param error stream receiving one message per sample whose count differs
///        from 1
/// \return true if every entry of \a sample_count equals 1 (also true for an
///         empty vector), false otherwise
bool sample_count_test(const std::vector<size_t>& sample_count,
                       std::ostream* error)
{
  bool every_sample_once = true;
  size_t sample = 0;
  while (sample < sample_count.size()) {
    const size_t times = sample_count[sample];
    if (times != 1) {
      every_sample_once = false;
      *error << "ERROR: sample " << sample << " was in a group " << times
             << " times." << " Expected to be 1 time" << std::endl;
    }
    ++sample;
  }
  return every_sample_once;
}
294
295
296bool test_bootstrap(std::ostream* error)
297{
298  bool ok=true;
299  std::vector<std::string> label(10,"default");
300  label[2]=label[7]="white";
301  label[4]=label[5]="black";
302  label[6]=label[3]="green";
303  label[8]=label[9]="red";
304                 
305  classifier::Target target(label);
306  utility::Matrix raw_data(10,10);
307  classifier::MatrixLookup data(raw_data);
308  classifier::BootstrapSampler cv(target,3);
309  return ok;
310}
311
312
313bool test_cv(std::ostream* error)
314{
315  bool ok=true;
316  std::vector<std::string> label(10,"default");
317  label[2]=label[7]="white";
318  label[4]=label[5]="black";
319  label[6]=label[3]="green";
320  label[8]=label[9]="red";
321                 
322  classifier::Target target(label);
323  utility::Matrix raw_data(10,10);
324  classifier::MatrixLookup data(raw_data);
325  classifier::CrossValidationSampler cv(target,3,3);
326 
327  std::vector<size_t> sample_count(10,0);
328  for (size_t j=0; j<cv.size(); ++j){
329    std::vector<size_t> class_count(5,0);
330    assert(j<cv.size());
331    if (cv.training_index(j).size()+cv.validation_index(j).size()!=
332        target.size()){
333      ok = false;
334      *error << "ERROR: size of training samples plus " 
335             << "size of validation samples is invalid." << std::endl;
336    }
337    if (cv.validation_index(j).size()!=3 && cv.validation_index(j).size()!=4){
338      ok = false;
339      *error << "ERROR: size of validation samples is invalid." 
340             << "expected size to be 3 or 4" << std::endl;
341    }
342    for (size_t i=0; i<cv.validation_index(j).size(); i++) {
343      assert(cv.validation_index(j)[i]<sample_count.size());
344      sample_count[cv.validation_index(j)[i]]++;
345    }
346    for (size_t i=0; i<cv.training_index(j).size(); i++) {
347      class_count[target(cv.training_index(j)[i])]++;
348    }
349    ok = ok && class_count_test(class_count,error);
350  }
351  ok = ok && sample_count_test(sample_count,error);
352 
353  return ok;
354}
Note: See TracBrowser for help on using the repository browser.