source: trunk/test/crossvalidation_test.cc @ 514

Last change on this file since 514 was 514, checked in by Peter, 17 years ago

generalised binary functionality in Target

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.3 KB
Line 
1// $Id: crossvalidation_test.cc 514 2006-02-20 09:45:34Z peter $
2
3#include <c++_tools/classifier/CrossSplitter.h>
4#include <c++_tools/classifier/MatrixLookup.h>
5#include <c++_tools/classifier/Target.h>
6#include <c++_tools/gslapi/matrix.h>
7
8#include <cstdlib>
9#include <fstream>
10#include <iostream>
11#include <string>
12#include <vector>
13
14int main(const int argc,const char* argv[])
15{ 
16  using namespace theplu;
17
18  std::ostream* error;
19  if (argc>1 && argv[1]==std::string("-v"))
20    error = &std::cerr;
21  else {
22    error = new std::ofstream("/dev/null");
23    if (argc>1)
24      std::cout << "crossvalidation_test -v : for printing extra information\n";
25  }
26  *error << "testing crosssplitter" << std::endl;
27  bool ok = true;
28
29  std::vector<std::string> label(10,"default");
30  label[2]=label[7]="white";
31  label[4]=label[5]="black";
32  label[6]=label[3]="green";
33  label[8]=label[9]="red";
34                 
35  classifier::Target target(label);
36  gslapi::matrix raw_data(10,10);
37  classifier::MatrixLookup data(raw_data);
38  classifier::CrossSplitter cv(target,data,3,3);
39 
40  std::vector<size_t> sample_count(10,0);
41  for (cv.reset(); cv.more(); cv.next()){
42    std::vector<size_t> class_count(5,0);
43    if (cv.training_index().size()+cv.validation_index().size()!=target.size()){
44      ok = false;
45      *error << "ERROR: size of training samples plus " 
46             << "size of validation samples is invalid." << std::endl;
47    }
48    if (cv.validation_index().size()!=3 && cv.validation_index().size()!=4){
49      ok = false;
50      *error << "ERROR: size of validation samples is invalid." 
51             << "expected size to be 3 or 4" << std::endl;
52    }
53    for (size_t i=0; i<cv.validation_index().size(); i++) {
54      assert(cv.validation_index()[i]<sample_count.size());
55      sample_count[cv.validation_index()[i]]++;
56    }
57    for (size_t i=0; i<cv.training_index().size(); i++) {
58      class_count[target(cv.training_index()[i])]++;
59    }
60    for (size_t i=0; i<class_count.size(); i++)
61      if (class_count[i]==0){
62        ok = false;
63        *error << "ERROR: class " << i << " was not in training set." 
64                  << " Expected at least one sample from each class." 
65                  << std::endl;
66      }
67  }
68  for (size_t i=0; i<sample_count.size(); i++){
69    if (sample_count[i]!=1){
70      ok = false;
71      *error << "ERROR: sample " << i << " was validated " << sample_count[i] 
72                << " times." << " Expected to be 1 time" << std::endl;
73    }
74  }
75 
76  if (error!=&std::cerr)
77    delete error;
78
79  if (ok)
80    return 0;
81  return -1;
82}
Note: See TracBrowser for help on using the repository browser.