Ignore:
Timestamp:
Mar 7, 2006, 2:39:51 PM (17 years ago)
Author:
Markus Ringnér
Message:

Added tests of two nested CrossSplitters? to crossvalidation_test

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/crossvalidation_test.cc

    r514 r551  
    1212#include <vector>
    1313
     14// forward declaration
     15void class_count_test(const std::vector<size_t>&, std::ostream*, bool&);
     16void sample_count_test(const std::vector<size_t>&, std::ostream*, bool&);
     17
     18
    1419int main(const int argc,const char* argv[])
    1520
    1621  using namespace theplu;
    17 
     22 
    1823  std::ostream* error;
    1924  if (argc>1 && argv[1]==std::string("-v"))
     
    5863      class_count[target(cv.training_index()[i])]++;
    5964    }
    60     for (size_t i=0; i<class_count.size(); i++)
    61       if (class_count[i]==0){
    62         ok = false;
    63         *error << "ERROR: class " << i << " was not in training set."
    64                   << " Expected at least one sample from each class."
    65                   << std::endl;
     65    class_count_test(class_count,error,ok);
     66  }
     67  sample_count_test(sample_count,error,ok);
     68 
     69  //
     70  // Test two nested CrossSplitters
     71  //
     72
     73  *error << "\ntesting two nested crossplitters" << std::endl;
     74  label.resize(9);
     75  label[0]=label[1]=label[2]="0";
     76  label[3]=label[4]=label[5]="1";
     77  label[6]=label[7]=label[8]="2";
     78                 
     79  target=classifier::Target(label);
     80  gslapi::matrix raw_data2(2,9);
     81  for(size_t i=0;i<raw_data2.rows();i++)
     82    for(size_t j=0;j<raw_data2.columns();j++)
     83      raw_data2(i,j)=i*10+10+j+1;
     84   
     85  classifier::MatrixLookup data2(raw_data2);
     86  classifier::CrossSplitter cv_test(target,data2,3,3);
     87
     88  std::vector<size_t> test_sample_count(9,0);
     89  std::vector<size_t> test_class_count(3,0);
     90  std::vector<double> test_value1(4,0);
     91  std::vector<double> test_value2(4,0);
     92  std::vector<double> t_value(4,0);
     93  std::vector<double> v_value(4,0);
     94  while(cv_test.more()) {
     95   
     96    const classifier::DataLookup2D& tv_view=cv_test.training_data();
     97    const classifier::Target& tv_target=cv_test.training_target();
     98    const std::vector<size_t>& tv_index=cv_test.training_index();
     99    const classifier::DataLookup2D& test_view=cv_test.validation_data();
     100    const classifier::Target& test_target=cv_test.validation_target();
     101    const std::vector<size_t>& test_index=cv_test.validation_index();
     102
     103    for (size_t i=0; i<test_index.size(); i++) {
     104      assert(test_index[i]<sample_count.size());
     105      test_sample_count[test_index[i]]++;
     106      test_class_count[target(test_index[i])]++;
     107      test_value1[0]+=test_view(0,i);
     108      test_value2[0]+=test_view(1,i);
     109      test_value1[test_target(i)+1]+=test_view(0,i);
     110      test_value2[test_target(i)+1]+=test_view(1,i);
     111      if(test_target(i)!=target(test_index[i])) {
     112        ok=false;
     113        *error << "ERROR: incorrect mapping of test indices" << std:: endl;
     114      }       
     115    }
     116   
     117    classifier::CrossSplitter cv_training(tv_target,tv_view,2,2);
     118    std::vector<size_t> v_sample_count(6,0);
     119    std::vector<size_t> t_sample_count(6,0);
     120    std::vector<size_t> v_class_count(3,0);
     121    std::vector<size_t> t_class_count(3,0);
     122    std::vector<size_t> t_class_count2(3,0);
     123    while(cv_training.more()) {
     124      const classifier::DataLookup2D& t_view=cv_training.training_data();
     125      const classifier::Target& t_target=cv_training.training_target();
     126      const std::vector<size_t>& t_index=cv_training.training_index();
     127      const classifier::DataLookup2D& v_view=cv_training.validation_data();
     128      const classifier::Target& v_target=cv_training.validation_target();
     129      const std::vector<size_t>& v_index=cv_training.validation_index();
     130     
     131      if (test_index.size()+tv_index.size()!=target.size()
     132          || t_index.size()+v_index.size() != tv_target.size()
     133          || test_index.size()+v_index.size()+t_index.size() !=  target.size()){
     134        ok = false;
     135        *error << "ERROR: size of training samples, validation samples "
     136               << "and test samples in is invalid."
     137               << std::endl;
     138      }
     139      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
     140          v_index.size()!=3){
     141        ok = false;
     142        *error << "ERROR: size of training, validation, and test samples"
     143               << " is invalid."
     144               << " Expected sizes to be 3" << std::endl;
     145      }     
     146
     147      for (size_t i=0; i<t_index.size(); i++) {
     148        assert(t_index[i]<t_sample_count.size());
     149        t_sample_count[t_index[i]]++;
     150        t_class_count[t_target(i)]++;
     151        t_class_count2[tv_target(t_index[i])]++;
     152        t_value[0]+=t_view(0,i);
     153        t_value[t_target(i)+1]+=t_view(0,i);
    66154      }
    67   }
    68   for (size_t i=0; i<sample_count.size(); i++){
    69     if (sample_count[i]!=1){
    70       ok = false;
    71       *error << "ERROR: sample " << i << " was validated " << sample_count[i]
    72                 << " times." << " Expected to be 1 time" << std::endl;
    73     }
    74   }
    75  
     155      for (size_t i=0; i<v_index.size(); i++) {
     156        assert(v_index[i]<v_sample_count.size());
     157        v_sample_count[v_index[i]]++;
     158        v_class_count[v_target(i)]++;
     159        v_value[0]+=v_view(0,i);
     160        v_value[v_target(i)+1]+=v_view(0,i);
     161      }
     162
     163      cv_training.next();     
     164    }
     165    sample_count_test(v_sample_count,error,ok);
     166    sample_count_test(t_sample_count,error,ok);
     167   
     168    class_count_test(t_class_count,error,ok);
     169    class_count_test(t_class_count2,error,ok);
     170    class_count_test(v_class_count,error,ok);
     171
     172
     173    cv_test.next();
     174  }
     175  sample_count_test(test_sample_count,error,ok);
     176  class_count_test(test_class_count,error,ok);
     177 
     178  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
     179     test_value1[3]!=54) {
     180    ok=false;
     181    *error << "ERROR: incorrect sums of test values in row 1"
     182           << " found: " << test_value1[0] << ", "  << test_value1[1]
     183           << ", "  << test_value1[2] << " and "  << test_value1[3]
     184           << std::endl;
     185  }
     186
     187 
     188  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
     189     test_value2[3]!=84) {
     190    ok=false;
     191    *error << "ERROR: incorrect sums of test values in row 2"
     192           << " found: " << test_value2[0] << ", "  << test_value2[1]
     193           << ", "  << test_value2[2] << " and "  << test_value2[3]
     194           << std::endl;
     195  }
     196
     197  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108)  {
     198    ok=false;
     199    *error << "ERROR: incorrect sums of training values in row 1"
     200           << " found: " << t_value[0] << ", "  << t_value[1]
     201           << ", "  << t_value[2] << " and "  << t_value[3]
     202           << std::endl;   
     203  }
     204
     205  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108)  {
     206    ok=false;
     207    *error << "ERROR: incorrect sums of validation values in row 1"
     208           << " found: " << v_value[0] << ", "  << v_value[1]
     209           << ", "  << v_value[2] << " and "  << v_value[3]
     210           << std::endl;   
     211  }
     212
     213
     214
    76215  if (error!=&std::cerr)
    77216    delete error;
    78 
     217 
    79218  if (ok)
    80219    return 0;
    81220  return -1;
    82221}
     222
     223
     224void class_count_test(const std::vector<size_t>& class_count,
     225                      std::ostream* error, bool& ok) 
     226{
     227  for (size_t i=0; i<class_count.size(); i++)
     228    if (class_count[i]==0){
     229      ok = false;
     230      *error << "ERROR: class " << i << " was not in set."
     231             << " Expected at least one sample from each class."
     232             << std::endl;
     233    }
     234}
     235
     236void sample_count_test(const std::vector<size_t>& sample_count,
     237                       std::ostream* error, bool& ok) 
     238{
     239  for (size_t i=0; i<sample_count.size(); i++){
     240    if (sample_count[i]!=1){
     241      ok = false;
     242      *error << "ERROR: sample " << i << " was validated " << sample_count[i]
     243             << " times." << " Expected to be 1 time" << std::endl;
     244    }
     245  }
     246}
Note: See TracChangeset for help on using the changeset viewer.