Changeset 1241 for trunk


Ignore:
Timestamp:
Mar 16, 2008, 6:47:42 AM (14 years ago)
Author:
Peter
Message:

working on #223

Location:
trunk/test
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/nbc_test.cc

    r1157 r1241  
    2222*/
    2323
     24#include "Suite.h"
     25
    2426#include "yat/classifier/MatrixLookup.h"
    2527#include "yat/classifier/MatrixLookupWeighted.h"
     
    3537using namespace theplu::yat;
    3638
    37 int main(const int argc,const char* argv[])
     39int main(int argc, char* argv[])
    3840
    39 
    40   std::ostream* error;
    41   if (argc>1 && argv[1]==std::string("-v"))
    42     error = &std::cerr;
    43   else {
    44     error = new std::ofstream("/dev/null");
    45     if (argc>1)
    46       *error << "nbc_test -v : for printing extra information\n";
    47   }
    48   *error << "testing ncc" << std::endl;
    49   bool ok = true;
     41  test::Suite suite(argc, argv);
     42  suite.err() << "testing ncc" << std::endl;
    5043
    5144  std::ifstream is("data/nm_data_centralized.txt");
     
    5851  is.close();
    5952
    60   *error << "Constructing NBC" << std::endl;
     53  suite.err() << "Constructing NBC" << std::endl;
    6154  classifier::NBC nbc;
    62   *error << "Training NBC" << std::endl;
     55  suite.err() << "Training NBC" << std::endl;
    6356  nbc.train(data,target);
    6457  utility::Matrix res;
    65   *error << "Predicting" << std::endl;
     58  suite.err() << "Predicting" << std::endl;
    6659  nbc.predict(data, res);
    6760
    68 
    69   if(ok)
    70     *error << "OK" << std::endl;
    71   else
    72     *error << "test failed" << std::endl;
    73 
    74   if (error!=&std::cerr)
    75     delete error;
    76 
    77   if(ok)
    78     return 0;
    79   return -1;
    80  
     61  return suite.return_value();
    8162}
  • trunk/test/ncc_test.cc

    r1210 r1241  
    2222*/
    2323
    24 #include "yat/classifier/Kernel_MEV.h"
    25 #include "yat/classifier/KernelLookup.h"
     24#include "Suite.h"
     25
    2626#include "yat/classifier/MatrixLookup.h"
    2727#include "yat/classifier/MatrixLookupWeighted.h"
    2828#include "yat/classifier/NCC.h"
    29 #include "yat/classifier/PolynomialKernelFunction.h"
    3029#include "yat/classifier/Target.h"
    3130#include "yat/utility/Matrix.h"
     
    4544using namespace theplu::yat;
    4645
    47 double deviation(const utility::Matrix& a, const utility::Matrix& b) {
    48   double sl=0;
    49   for (size_t i=0; i<a.rows(); i++){
    50     for (size_t j=0; j<a.columns(); j++){
    51       sl += std::abs(a(i,j)-b(i,j));
    52     }
    53   }
    54   sl /= (a.columns()*a.rows());
    55   return sl;
    56 }
    57 
    58 int main(const int argc,const char* argv[])
     46int main(int argc,char* argv[])
    5947
    60 
    61   std::ostream* error;
    62   if (argc>1 && argv[1]==std::string("-v"))
    63     error = &std::cerr;
    64   else {
    65     error = new std::ofstream("/dev/null");
    66     if (argc>1)
    67       std::cout << "ncc_test -v : for printing extra information\n";
    68   }
    69   *error << "testing ncc" << std::endl;
    70   bool ok = true;
     48  test::Suite suite(argc, argv);
     49  suite.err() << "testing ncc" << std::endl;
    7150
    7251  /////////////////////////////////////////////
     
    7857  classifier::Target target(vec);
    7958  classifier::NCC<statistics::EuclideanDistance> ncctmp;
    80   *error << "training...\n";
     59  suite.err() << "training...\n";
    8160  ncctmp.train(ml,target);
    82   *error << "done\n";
     61  suite.err() << "done\n";
    8362
    8463  /////////////////////////////////////////////
    8564  // A test of predictions using unweighted data
    8665  /////////////////////////////////////////////
    87   *error << "test of predictions using unweighted test data\n";
     66  suite.err() << "test of predictions using unweighted test data\n";
    8867  utility::Matrix data1(3,4);
    8968  for(size_t i=0;i<3;i++) {
     
    10483  utility::Matrix prediction1;
    10584  ncc1.predict(ml1,prediction1);
    106   double slack_bound=2e-7;
    10785  utility::Matrix result1(2,4);
    10886  result1(0,0)=result1(0,1)=result1(1,2)=result1(1,3)=sqrt(3.0);
    10987  result1(0,2)=result1(0,3)=result1(1,0)=result1(1,1)=sqrt(11.0);
    110   double slack = deviation(prediction1,result1);
    111   if (slack > slack_bound || std::isnan(slack)){
    112     *error << "Difference to expected prediction too large\n";
    113     *error << "slack: " << slack << std::endl;
    114     *error << "expected less than " << slack_bound << std::endl;
    115     ok = false;
     88  if (!suite.equal_range(prediction1.begin(), prediction1.end(),
     89                         result1.begin())) {
     90    suite.add(false);
     91    suite.err() << "Difference to expected prediction too large\n";
    11692  }
    11793
     
    11995  // A test of predictions using unweighted training and weighted test data
    12096  //////////////////////////////////////////////////////////////////////////
    121   *error << "test of predictions using unweighted training and weighted test data\n";
     97  suite.err() << "test of predictions using unweighted training and weighted test data\n";
    12298  utility::Matrix weights1(3,4,1.0);
    12399  weights1(0,0)=weights1(1,1)=weights1(2,2)=weights1(1,3)=0.0;
     
    125101  ncc1.predict(mlw1,prediction1);
    126102  result1(0,2)=result1(0,3)=result1(1,0)=result1(1,1)=sqrt(15.0);
    127   slack = deviation(prediction1,result1);
    128   if (slack > slack_bound || std::isnan(slack)){
    129     *error << "Difference to expected prediction too large\n";
    130     *error << "slack: " << slack << std::endl;
    131     *error << "expected less than " << slack_bound << std::endl;
    132     ok = false;
     103  if (!suite.equal_range(prediction1.begin(), prediction1.end(),
     104                         result1.begin())) {
     105    suite.add(false);
     106    suite.err() << "Difference to expected prediction too large\n";
    133107  }
    134108
     
    137111  // in centroids and unweighted test data
    138112  //////////////////////////////////////////////////////////////////////////
    139   *error << "test of predictions using nan centroids and unweighted test data\n";
     113  suite.err() << "test of predictions using nan centroids and unweighted test data\n";
    140114  utility::Matrix weights2(3,4,1.0);
    141115  weights2(1,0)=weights2(1,1)=0.0;
     
    147121  result1(1,0)=result1(1,1)=sqrt(11.0);
    148122  result1(0,2)=result1(0,3)=sqrt(15.0);
    149   slack = deviation(prediction1,result1);
    150123  if(!std::isnan(ncc2.centroids()(1,0)))
    151     ok=false;
    152   if (slack > slack_bound || std::isnan(slack)){
    153     *error << "Difference to expected prediction too large\n";
    154     *error << "slack: " << slack << std::endl;
    155     *error << "expected less than " << slack_bound << std::endl;
    156     ok = false;
     124    suite.add(false);
     125  if (!suite.equal_range(prediction1.begin(), prediction1.end(),
     126                         result1.begin())) {
     127    suite.add(false);
     128    suite.err() << "Difference to expected prediction too large\n";
    157129  }
    158130
     
    161133  // test sample has non-zero weights for.
    162134  //////////////////////////////////////////////////////////////////////////
    163   *error << "test of predictions using nan centroids and weighted test data\n";
    164   *error << "... using EuclideanDistance" << std::endl;
     135  suite.err() << "test of predictions using nan centroids and weighted test data\n";
     136  suite.err() << "... using EuclideanDistance" << std::endl;
    165137  weights1(0,0)=weights1(2,0)=0;
    166138  classifier::NCC<statistics::EuclideanDistance> ncc3;
     
    168140  ncc3.predict(mlw1,prediction1);
    169141  if(!std::isnan(ncc3.centroids()(1,0))) {
    170     ok=false;
    171     *error << "Training failed: expected nan in centroid" << std::endl;
     142    suite.add(false);
     143    suite.err() << "Training failed: expected nan in centroid" << std::endl;
    172144  }
    173145  if(!(std::isnan(prediction1(0,0)) &&
    174        std::abs(prediction1(1,0)-sqrt(3.0))<slack_bound &&
    175        std::abs(prediction1(0,1)-sqrt(3.0))<slack_bound &&
    176        std::abs(prediction1(1,1)-sqrt(15.0))<slack_bound &&
    177        std::abs(prediction1(0,2)-sqrt(27.0))<slack_bound)) { 
    178     ok=false;
    179     *error << "Test failed: predictions incorrect" << std::endl;
    180   }
    181   *error << "... using PearsonDistance" << std::endl;;
     146       suite.equal(prediction1(1,0),sqrt(3.0)) &&
     147       suite.equal(prediction1(0,1),sqrt(3.0)) &&
     148       suite.equal(prediction1(1,1),sqrt(15.0)) &&
     149       suite.equal(prediction1(0,2),sqrt(27.0)) )) { 
     150    suite.add(false);
     151    suite.err() << "Test failed: predictions incorrect" << std::endl;
     152  }
     153  suite.err() << "... using PearsonDistance" << std::endl;;
    182154  classifier::NCC<statistics::PearsonDistance> ncc4;
    183155  ncc4.train(mlw2,target1);
    184156  ncc4.predict(mlw1,prediction1);
    185157  if(!std::isnan(ncc4.centroids()(1,0))) {
    186     ok=false;
    187     *error << "Training failed: expected nan in centroid" << std::endl;
     158    suite.add(false);
     159    suite.err() << "Training failed: expected nan in centroid" << std::endl;
    188160  }
    189161  if(!(std::isnan(prediction1(0,0)) &&
    190162       std::isnan(prediction1(0,2)) &&
    191163       std::isnan(prediction1(1,0)) &&
    192        std::abs(prediction1(0,1))<slack_bound &&
    193        std::abs(prediction1(1,2))<slack_bound &&
    194        std::abs(prediction1(1,3))<slack_bound && 
    195        std::abs(prediction1(0,3)-2.0)<slack_bound &&
    196        std::abs(prediction1(1,1)-2.0)<slack_bound)) {
    197     ok=false;
    198     *error << "Test failed: predictions incorrect" << std::endl;
     164       suite.equal(prediction1(0,1), 0) &&
     165       suite.equal(prediction1(1,2), 0) &&
     166       suite.equal(prediction1(1,3), 0) && 
     167       suite.equal(prediction1(0,3), 2.0) &&
     168       suite.equal(prediction1(1,1), 2.0) )) {
     169    suite.add(false);
     170    suite.err() << "Test failed: predictions incorrect" << std::endl;
    199171  }
    200172
     
    214186        std::isnan(prediction1(0,2)) && std::isnan(prediction1(0,3)) &&
    215187        std::isnan(prediction1(1,0)) &&
    216         std::abs(prediction1(1,1)-2.0)<slack_bound &&
    217         std::abs(prediction1(1,2))<slack_bound &&
    218         std::abs(prediction1(1,3))<slack_bound)) {
    219     *error << "Difference to expected prediction too large\n";
    220     ok = false;
     188        suite.equal(prediction1(1,1), 2.0) &&
     189        suite.equal(prediction1(1,2),0) &&
     190        suite.equal(prediction1(1,3),0) )) {
     191    suite.err() << "Difference to expected prediction too large\n";
     192    suite.add(false);
    221193  }
    222194
     
    224196  // A test of predictions using Sorlie data
    225197  //////////////////////////////////////////////////////////////////////////
    226   *error << "test with Sorlie data\n";
     198  suite.err() << "test with Sorlie data\n";
    227199  std::ifstream is("data/sorlie_centroid_data.txt");
    228200  utility::Matrix data(is,'\t');
     
    239211  classifier::MatrixLookupWeighted dataviewweighted(data,weights);
    240212  classifier::NCC<statistics::PearsonDistance> ncc;
    241   *error << "training...\n";
     213  suite.err() << "training...\n";
    242214  ncc.train(dataviewweighted,targets);
    243215
     
    249221  if(centroids.rows() != ncc.centroids().rows() ||
    250222     centroids.columns() != ncc.centroids().columns()) {
    251     *error << "Error in the dimensionality of centroids\n";
    252     *error << "Nof rows: " << centroids.rows() << " expected: "
     223    suite.err() << "Error in the dimensionality of centroids\n";
     224    suite.err() << "Nof rows: " << centroids.rows() << " expected: "
    253225           << ncc.centroids().rows() << std::endl;
    254     *error << "Nof columns: " << centroids.columns() << " expected: "
     226    suite.err() << "Nof columns: " << centroids.columns() << " expected: "
    255227           << ncc.centroids().columns() << std::endl;
    256228  }
    257229
    258   slack = deviation(centroids,ncc.centroids());
    259   if (slack > slack_bound || std::isnan(slack)){
    260     *error << "Difference to stored centroids too large\n";
    261     *error << "slack: " << slack << std::endl;
    262     *error << "expected less than " << slack_bound << std::endl;
    263     ok = false;
    264   }
    265 
    266   *error << "...predicting...\n";
     230  if (!suite.equal_range(centroids.begin(), centroids.end(),
     231                         ncc.centroids().begin(), 100000)) {
     232    suite.add(false);
     233    suite.err() << "Difference to stored centroids too large\n";
     234  }
     235
     236  suite.err() << "...predicting...\n";
    267237  utility::Matrix prediction;
    268238  ncc.predict(dataviewweighted,prediction);
     
    273243  is.close();
    274244
    275   slack = deviation(result,prediction);
    276   if (slack > slack_bound || std::isnan(slack)){
    277     *error << "Difference to stored prediction too large\n";
    278     *error << "slack: " << slack << std::endl;
    279     *error << "expected less than " << slack_bound << std::endl;
    280     ok = false;
    281   }
    282   *error << "done\n";
    283 
    284   if(ok)
    285     *error << "OK" << std::endl;
    286   else
    287     *error << "FAILED" << std::endl;
    288 
    289   if (error!=&std::cerr)
    290     delete error;
    291 
    292   if(ok)
    293     return 0;
    294   return -1; 
     245  if (!suite.equal_range(result.begin(), result.end(),
     246                         prediction.begin(), 100000)) {
     247    suite.add(false);
     248    suite.err() << "Difference to stored prediction too large\n";
     249  }
     250
     251  return suite.return_value();
    295252}
Note: See TracChangeset for help on using the changeset viewer.