Changeset 1864


Ignore:
Timestamp:
Mar 15, 2009, 3:49:53 AM (15 years ago)
Author:
Peter
Message:

adding test for train and predict. refs #269

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/svm_multi_class_test.cc

    r1862 r1864  
    2222#include "Suite.h"
    2323
     24#include "yat/classifier/Kernel_SEV.h"
     25#include "yat/classifier/KernelLookup.h"
     26#include "yat/classifier/MatrixLookupWeighted.h"
     27#include "yat/classifier/PolynomialKernelFunction.h"
    2428#include "yat/classifier/SvmMultiClass.h"
     29
     30#include <fstream>
     31#include <string>
    2532
    2633using namespace theplu::yat;
    2734
    2835void test_construction(test::Suite&);
     36void test_predict(test::Suite&);
    2937
    3038int main( int argc, char* argv[])
     
    3341  suite.err() << "testing SvmMultiClass" << std::endl;
    3442  test_construction(suite);
     43  test_predict(suite);
    3544
    3645  return suite.return_value();
     
    5362  delete svm3;
    5463}
     64
     65
     66void test_predict(test::Suite& suite)
     67{
     68  using namespace classifier;
     69  std::string file = test::filename("data/sorlie_centroid_data.txt");
     70  std::ifstream is(file.c_str());
     71  suite.err() << "load data `" << file << "'" << std::endl;
     72  MatrixLookupWeighted data(is, '\t');
     73  is.close();
     74  PolynomialKernelFunction linear;
     75  suite.err() << "calculating kernel" << std::endl;
     76  Kernel_SEV kernel_raw(data, linear);
     77  KernelLookup kernel(kernel_raw);
     78  file = test::filename("data/sorlie_centroid_classes.txt");
     79  suite.err() << "load classes `" << file << "'" << std::endl;
     80  is.open(file.c_str());
     81  Target target(is);
     82  is.close();
     83
     84  SvmMultiClass svm;
     85  suite.err() << "training svm" << std::endl;
     86  svm.train(kernel, target);
     87
     88  utility::Matrix result;
     89  svm.predict(kernel, result);
     90
     91  if (!suite.add(result.rows()==5 && result.columns()==79)) {
     92    suite.err() << "ERROR: incorrect dimension in result Matrix\n"
     93                << "found " << result.rows() << "x" << result.columns() << "\n"
     94                << "expected 5x79\n";
     95  }
     96
     97  // we expect perfect predictions on training data
     98  for (size_t i=0; i<79; ++i) {
     99    for (size_t j=0; j<5; ++j) {
     100      if (target(i)==j && result(j,i)<0) {
     101        suite.err() << "result(" << j << "," << i << ") is "
     102                    << result(j,i) << " expected greater than 0" << std::endl;
     103        suite.add(false);
     104      }
     105      else if (target(i)!=j && result(j,i)>0) {
     106        suite.err() << "result(" << j << "," << i << ") is "
     107                    << result(j,i) << " expected smaller than 0" << std::endl;
     108        suite.add(false);
     109      }
     110    }
     111  }
     112}
Note: See TracChangeset for help on using the changeset viewer.