Changeset 1865


Ignore:
Timestamp:
Mar 16, 2009, 12:36:36 PM (15 years ago)
Author:
Peter
Message:

adding another test and fixing prediction when having empty training classes. closes #269

Location:
trunk
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/svm_multi_class_test.cc

    r1864 r1865  
    3535void test_construction(test::Suite&);
    3636void test_predict(test::Suite&);
     37void test_predict2(test::Suite&);
    3738
    3839int main( int argc, char* argv[])
     
    4243  test_construction(suite);
    4344  test_predict(suite);
     45  test_predict2(suite);
    4446
    4547  return suite.return_value();
     
    111113  }
    112114}
     115
     116void test_predict2(test::Suite& suite)
     117{
     118  using namespace classifier;
     119  std::string file = test::filename("data/sorlie_centroid_data.txt");
     120  std::ifstream is(file.c_str());
     121  suite.err() << "load data `" << file << "'" << std::endl;
     122  MatrixLookupWeighted data(is, '\t');
     123  is.close();
     124  PolynomialKernelFunction linear;
     125  suite.err() << "calculating kernel" << std::endl;
     126  Kernel_SEV kernel_raw(data, linear);
     127  file = test::filename("data/sorlie_centroid_classes.txt");
     128  suite.err() << "load classes `" << file << "'" << std::endl;
     129  is.open(file.c_str());
     130  Target target(is);
     131  is.close();
     132
     133  std::vector<size_t> index;
     134  for (size_t i=0; i<50; ++i)
     135    index.push_back(i);
     136  for (size_t i=70; i<79; ++i)
     137    index.push_back(i);
     138  utility::Index train_index(index);
     139 
     140  Target target_train(target, train_index);
     141  KernelLookup kernel_train(kernel_raw, train_index, train_index);
     142  SvmMultiClass svm;
     143  suite.err() << "training svm" << std::endl;
     144  svm.train(kernel_train, target_train);
     145
     146  index.clear();
     147  for (size_t i=50; i<70; ++i)
     148    index.push_back(i);
     149  utility::Index test_index(index);
     150
     151  KernelLookup kernel_test(kernel_raw, train_index, test_index);
     152  utility::Matrix result;
     153  suite.err() << "Predicting on test data" << std::endl;
     154  svm.predict(kernel_test, result);
     155
     156  if (!suite.add(result.rows()==5 && result.columns()==20)) {
     157    suite.err() << "ERROR: incorrect dimension in result Matrix\n"
     158                << "found " << result.rows() << "x" << result.columns() << "\n"
     159                << "expected 5x79\n";
     160  }
     161  if (!suite.add(std::isnan(result(3, 0))) ) {
     162    suite.err() << "ERROR: expected result(4,0) to be nan\n"
     163                << "  found " << result(4,0) << std::endl;
     164  }
     165
     166}
     167
  • trunk/yat/classifier/SvmMultiClass.cc

    r1861 r1865  
    6969                      std::numeric_limits<double>::quiet_NaN());
    7070    for (size_t i=0; i<svm_.size(); ++i) {
    71       yat::utility::Matrix tmp;
    72       svm_[i].predict(input, tmp);
    73       prediction.row_view(i) = tmp.row_const_view(0);
     71      if (svm_[i].trained()) {
     72        yat::utility::Matrix tmp;
     73        svm_[i].predict(input, tmp);
     74        prediction.row_view(i) = tmp.row_const_view(0);
     75      }
    7476    }
    7577
Note: See TracChangeset for help on using the changeset viewer.