Changeset 463


Ignore:
Timestamp:
Dec 16, 2005, 6:59:15 PM (16 years ago)
Author:
Peter
Message:

fixed bug in KernelView? constructor and updated tests

Location:
trunk
Files:
7 edited

Legend:

Unmodified
Added
Removed
  • trunk/lib/classifier/KernelView.cc

    r461 r463  
    88namespace classifier { 
    99
     10  KernelView::KernelView(const Kernel& kernel)
     11    : MatrixView(), kernel_(&kernel)
     12  {
     13    for(size_t i=0;i<(*kernel_).size();i++)
     14      column_index_.push_back(i);
     15    row_index_=column_index_;
     16  }
     17 
     18  KernelView::KernelView(const Kernel& kernel,
     19                         const std::vector<size_t>& index)
     20    : MatrixView(index,index), kernel_(&kernel)
     21  {
     22  }
     23 
    1024  KernelView::KernelView(const Kernel& kernel,
    1125                         const std::vector<size_t>& row,
  • trunk/lib/classifier/KernelView.h

    r461 r463  
    2121  public:
    2222   
     23    ///
     24    /// Constructor
     25    ///
     26    KernelView(const Kernel&);
     27
     28    ///
     29    /// Contructor taking the Kernel to view into, index
     30    /// vector. Equivalent to KernelView(kernel, index, index).
     31    ///
     32    /// @note For training usage row index shall always be equal to
     33    /// column index.
     34    ///
     35    KernelView(const Kernel& kernel, const std::vector<size_t>& index);
     36   
    2337    ///
    2438    /// Contructor taking the Kernel to view into, row index vector,
    25     /// and column vector.
     39    /// and column index vector.
    2640    ///
    2741    /// @note For training usage row index shall always be equal to
  • trunk/lib/classifier/SVM.cc

    r461 r463  
    1010#include <c++_tools/random/random.h>
    1111
    12 #include <iostream>
    1312#include <algorithm>
     13#include <cassert>
    1414#include <cmath>
    1515#include <limits>
  • trunk/test/crossvalidation_test.cc

    r453 r463  
    11// $Id$
    22
    3 #include <c++_tools/classifier/CrossValidation.h>
     3#include <c++_tools/classifier/CrossSplitting.h>
    44#include <c++_tools/gslapi/vector.h>
    55
     6#include <cstdlib>
     7#include <fstream>
     8#include <iostream>
    69#include <vector>
    7 #include <cstdlib>
    8 #include <iostream>
    910
    10 using namespace theplu;
    11 
    12 int main()
     11int main(const int argc,const char* argv[])
    1312
    1413
     14  using namespace theplu;
     15
     16  std::ostream* error;
     17  if (argc>1 && argv[1]==std::string("-v"))
     18    error = &std::cerr;
     19  else {
     20    error = new std::ofstream("/dev/null");
     21    if (argc>1)
     22      std::cout << "crossvalidation_test -v : for printing extra information\n";
     23  }
     24  *error << "testing crosssplitting" << std::endl;
     25  bool ok = true;
    1526  gslapi::vector target(10,1);
    1627  for (size_t i=0; i<5; i++)
    1728    target(i)=-1;
    1829
    19   classifier::CrossValidation cv(target,3);
     30  classifier::CrossSplitting cv(target,3);
    2031
    2132
     
    3647
    3748  for (unsigned int i=0; i<10 ; i++)
    38     if (count[i]!=2)
    39       return -1;
     49    ok = ok && (count[i]==2);
    4050
     51  if (!ok)
     52    *error << "crossvalidation failed" << std::endl;
     53
     54  if (error!=&std::cerr)
     55    delete error;
    4156  return 0;
    4257}
  • trunk/test/kernel_test.cc

    r453 r463  
    2222
    2323bool test_MEV(const gslapi::matrix& data, const classifier::KernelFunction* kf,
    24               const gslapi::matrix& control, const double error_bound)
     24              const gslapi::matrix& control, const double error_bound,
     25              std::ostream* error)
    2526{
    2627  classifier::Kernel_MEV kernel(data,*kf);
     
    3536  index[1]=2;
    3637  index[2]=3;
    37   classifier::KernelView(kernel,index);
     38  classifier::KernelView kv(kernel,index);
     39  if (kv.rows()!=index.size()){
     40    *error << "Error: KernelView(kernel, index)\n" << std::endl
     41           << "Size of KernelView is " << kv.rows() << std::endl
     42           << "expected " << index.size() << std::endl;
     43   
     44    return false;
     45  }
     46  classifier::KernelView kv2(kernel);
     47  if (kv2.rows()!=kernel.size()){
     48    *error << "Error: KernelView(kernel)\n" << std::endl
     49           << "Size of KernelView is " << kv.rows() << std::endl
     50           << "expected " << kernel.size() << std::endl;
     51   
     52    return false;
     53  }
    3854
    3955  return true;
     
    4157
    4258bool test_SEV(const gslapi::matrix& data, const classifier::KernelFunction* kf,
    43               const gslapi::matrix& control, const double error_bound)
     59              const gslapi::matrix& control, const double error_bound,
     60              std::ostream* error)
    4461{
    4562  classifier::Kernel_SEV kernel(data,*kf);
     
    5471  index[1]=2;
    5572  index[2]=3;
    56   classifier::KernelView(kernel,index);
     73  classifier::KernelView kv(kernel,index);
     74  if (kv.rows()!=index.size()){
     75    *error << "Error: KernelView(kernel, index)\n" << std::endl
     76           << "Size of KernelView is " << kv.rows() << std::endl
     77           << "expected " << index.size() << std::endl;
     78   
     79    return false;
     80  }
     81  classifier::KernelView kv2(kernel);
     82  if (kv2.rows()!=kernel.size()){
     83    *error << "Error: KernelView(kernel)\n" << std::endl
     84           << "Size of KernelView is " << kv.rows() << std::endl
     85           << "expected " << kernel.size() << std::endl;
     86   
     87    return false;
     88  }
    5789  return true;
    5890}
     
    85117  is.close();
    86118  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction();
    87   ok = (ok && test_MEV(data,kf,kernel_matlab,error_bound)
    88         & test_SEV(data,kf,kernel_matlab,error_bound));
     119  ok = (ok && test_MEV(data,kf,kernel_matlab,error_bound, error)
     120        & test_SEV(data,kf,kernel_matlab,error_bound, error));
    89121  delete kf;
    90122 
     
    93125  is.close();
    94126  kf = new classifier::PolynomialKernelFunction(2);
    95   ok = (ok && test_MEV(data,kf,kernel_matlab2,error_bound)
    96         & test_SEV(data,kf,kernel_matlab2,error_bound));
     127  ok = (ok && test_MEV(data,kf,kernel_matlab2,error_bound, error)
     128        & test_SEV(data,kf,kernel_matlab2,error_bound, error));
    97129  delete kf;
    98130
  • trunk/test/score_test.cc

    r447 r463  
    77#include <c++_tools/statistics/FoldChange.h>
    88#include <c++_tools/gslapi/vector.h>
     9#include <c++_tools/statistics/WilcoxonFoldChange.h>
    910
    1011#include <gsl/gsl_cdf.h>
     
    108109  statistics::Pearson pearson(true);
    109110
     111  *error << "testing WilcoxonFoldChange" << std::endl;
     112  statistics::WilcoxonFoldChange wfc(true);
     113
    110114
    111115  if (ok)
  • trunk/test/svm_test.cc

    r453 r463  
    2020
    2121
    22   bool print = (argc>1 && argv[1]==std::string("-p"));
     22  std::ostream* error;
     23  if (argc>1 && argv[1]==std::string("-v"))
     24    error = &std::cerr;
     25  else {
     26    error = new std::ofstream("/dev/null");
     27    if (argc>1)
     28      std::cout << "svm_test -v : for printing extra information\n";
     29  }
     30  *error << "testing svm" << std::endl;
    2331  bool ok = true;
    2432
     
    3745  classifier::Kernel_MEV kernel2(data2,*kf2);
    3846  assert(kernel2.size()==3);
    39   classifier::SVM classifier2(kernel2, target2);
     47  assert(target2.size()==3);
     48  classifier::KernelView kv2(kernel2);
     49  *error << "testing with linear kernel" << std::endl;
     50  assert(kv2.rows()==target2.size());
     51  classifier::SVM classifier2(kv2, target2);
     52  *error << "training...";
    4053  classifier2.train();
     54  *error << " done." << std::endl;
    4155
    4256  if (classifier2.alpha()*target2){
    43     std::cerr << "condition not fullfilled" << std::endl;
     57    *error << "condition not fullfilled" << std::endl;
    4458    return -1;
    4559  }
    4660
    4761  if (classifier2.alpha()(1)!=2 || classifier2.alpha()(2)!=2){
    48     std::cerr << "wrong alpha" << std::endl;
    49     std::cerr << "alpha: " << classifier2.alpha() <<  std::endl;
    50     std::cerr << "expected: 4 2 2" <<  std::endl;
     62    *error << "wrong alpha" << std::endl;
     63    *error << "alpha: " << classifier2.alpha() <<  std::endl;
     64    *error << "expected: 4 2 2" <<  std::endl;
    5165
    5266    return -1;
     
    7387  is.close();
    7488
    75   theplu::classifier::SVM classifier(kernel, target);
    76   if (!classifier.train()){
     89  classifier::KernelView kv(kernel);
     90  theplu::classifier::SVM svm(kv, target);
     91  if (!svm.train()){
    7792    ok=false;
    78     if (print)
    79       std::cerr << "Training failured" << std::endl;
     93    *error << "Training failured" << std::endl;
    8094  }
    8195
    82   theplu::gslapi::vector alpha = classifier.alpha();
     96  theplu::gslapi::vector alpha = svm.alpha();
    8397     
    8498  // Comparing alpha to alpha_matlab
     
    86100  diff_alpha-=alpha_matlab;
    87101  if (diff_alpha*diff_alpha> 1e-10 ){
    88     if (print)
    89       std::cerr << "Difference to matlab alphas too large\n";
     102    *error << "Difference to matlab alphas too large\n";
    90103    ok=false;
    91104  }
    92105
    93106  // Comparing output to target
    94   theplu::gslapi::vector output(classifier.output());
     107  theplu::gslapi::vector output(svm.output());
    95108  double slack = 0;
    96109  for (unsigned int i=0; i<target.size(); i++){
     
    101114  double slack_bound=2e-7;
    102115  if (slack > slack_bound){
    103     if (print){
    104       std::cerr << "Slack too large. Is the bias correct?\n";
    105       std::cerr << "slack: " << slack << std::endl;
    106       std::cerr << "expected less than " << slack_bound << std::endl;
    107     }
     116    *error << "Slack too large. Is the bias correct?\n";
     117    *error << "slack: " << slack << std::endl;
     118    *error << "expected less than " << slack_bound << std::endl;
    108119    ok = false;
    109120  }
     
    112123  delete kf2;
    113124
     125  if (error!=&std::cerr)
     126    delete error;
     127
    114128  if(ok)
    115129    return 0;
Note: See TracChangeset for help on using the changeset viewer.