Changeset 1013


Ignore:
Timestamp:
Feb 1, 2008, 4:34:30 PM (13 years ago)
Author:
Markus Ringnér
Message:

Adding functionality tests for NCC

Location:
trunk
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/ncc_test.cc

    r1000 r1013  
    3131#include "yat/classifier/Target.h"
    3232#include "yat/utility/matrix.h"
     33#include "yat/statistics/euclidean_vector_distance.h"
    3334#include "yat/statistics/pearson_vector_distance.h"
    3435#include "yat/utility/utility.h"
     
    4546using namespace theplu::yat;
    4647
     48double deviation(const utility::matrix& a, const utility::matrix& b) {
     49  double sl=0;
     50  for (size_t i=0; i<a.rows(); i++){
     51    for (size_t j=0; j<a.columns(); j++){
     52      sl += fabs(a(i,j)-b(i,j));
     53    }
     54  }
     55  sl /= (a.columns()*a.rows());
     56  return sl;
     57}
     58
    4759int main(const int argc,const char* argv[])
    4860
     
    5971  bool ok = true;
    6072
     73  /////////////////////////////////////////////
     74  // First test of constructor and training 
     75  /////////////////////////////////////////////
    6176  classifier::MatrixLookup ml(4,4);
    6277  std::vector<std::string> vec(4, "pos");
     
    6681  *error << "training...\n";
    6782  ncctmp.train();
     83  *error << "done\n";
     84
     85  /////////////////////////////////////////////
     86  // A test of predictions using unweighted data
     87  /////////////////////////////////////////////
     88  utility::matrix data1(3,4);
     89  for(size_t i=0;i<3;i++) {
     90    data1(i,0)=3-i;
     91    data1(i,1)=5-i;
     92    data1(i,2)=i+1;
     93    data1(i,3)=i+3;
     94  }
     95  std::vector<std::string> vec1(4, "pos");
     96  vec1[0]="neg";
     97  vec1[1]="neg";
     98
     99  classifier::MatrixLookup ml1(data1);
     100  classifier::Target target1(vec1);
     101
     102  classifier::NCC<statistics::euclidean_vector_distance_tag> ncc1(ml1,target1);
     103  ncc1.train();
     104  utility::matrix prediction1;
     105  ncc1.predict(ml1,prediction1);
     106  double slack_bound=2e-7;
     107  utility::matrix result1(2,4);
     108  result1(0,0)=result1(0,1)=result1(1,2)=result1(1,3)=sqrt(3.0);
     109  result1(0,2)=result1(0,3)=result1(1,0)=result1(1,1)=sqrt(11.0);
     110  double slack = deviation(prediction1,result1);
     111  if (slack > slack_bound || std::isnan(slack)){
     112    *error << "Difference to expected prediction too large\n";
     113    *error << "slack: " << slack << std::endl;
     114    *error << "expected less than " << slack_bound << std::endl;
     115    ok = false;
     116  }
     117
     118  //////////////////////////////////////////////////////////////////////////
     119  // A test of predictions using unweighted training and weighted test data
     120  //////////////////////////////////////////////////////////////////////////
     121  utility::matrix weights1(3,4,1.0);
     122  weights1(0,0)=weights1(1,1)=weights1(2,2)=weights1(1,3)=0.0;
     123  classifier::MatrixLookupWeighted mlw1(data1,weights1);
     124  ncc1.predict(mlw1,prediction1);
     125  result1(0,2)=result1(0,3)=result1(1,0)=result1(1,1)=sqrt(15.0);
     126  slack = deviation(prediction1,result1);
     127  if (slack > slack_bound || std::isnan(slack)){
     128    *error << "Difference to expected prediction too large\n";
     129    *error << "slack: " << slack << std::endl;
     130    *error << "expected less than " << slack_bound << std::endl;
     131    ok = false;
     132  }
     133
    68134 
     135
     136  //////////////////////////////////////////////////////////////////////////
     137  // A test of predictions using Sorlie data
     138  //////////////////////////////////////////////////////////////////////////
    69139  std::ifstream is("data/sorlie_centroid_data.txt");
    70140  utility::matrix data(is,'\t');
     
    98168  }
    99169
    100   double slack = 0;
    101   for (size_t i=0; i<centroids.rows(); i++){
    102     for (size_t j=0; j<centroids.columns(); j++){
    103       slack += fabs(centroids(i,j)-ncc.centroids()(i,j));
    104     }
    105   }
    106   slack /= (centroids.columns()*centroids.rows());
    107   double slack_bound=2e-7;
     170  slack = deviation(centroids,ncc.centroids());
    108171  if (slack > slack_bound || std::isnan(slack)){
    109172    *error << "Difference to stored centroids too large\n";
     
    113176  }
    114177
    115   *error << "prediction...\n";
     178  *error << "...predicting...\n";
    116179  utility::matrix prediction;
    117180  ncc.predict(dataviewweighted,prediction);
     
    122185  is.close();
    123186
    124   slack = 0;
    125   for (size_t i=0; i<result.rows(); i++){
    126     for (size_t j=0; j<result.columns(); j++){
    127         slack += fabs(result(i,j)-prediction(i,j));
    128     }
    129   }
    130   slack /= (result.columns()*result.rows());
     187  slack = deviation(result,prediction);
    131188  if (slack > slack_bound || std::isnan(slack)){
    132189    *error << "Difference to stored prediction too large\n";
     
    135192    ok = false;
    136193  }
    137 
    138   // testing rejection of KernelLookups
     194  *error << "done\n";
     195
     196  //////////////////////////////////////////////////////////////////////////
     197  // Testing rejection of KernelLookups
     198  //////////////////////////////////////////////////////////////////////////
    139199  classifier::PolynomialKernelFunction kf;
    140200  classifier::Kernel_MEV kernel(ml,kf);
     
    161221    *error << "OK" << std::endl;
    162222
    163 
    164223  if (error!=&std::cerr)
    165224    delete error;
  • trunk/yat/classifier/NCC.h

    r1007 r1013  
    215215      }
    216216    }
    217     trained_=true;
    218     return trained_;
     217    return true;
    219218  }
    220219
     
    223222                              utility::matrix& prediction) const
    224223  {   
    225     utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
    226     utility::yat_assert<std::runtime_error>(test.rows()==centroids_->rows());
     224    utility::yat_assert<std::runtime_error>
     225      (centroids_,"NCC::predict called for untrained classifier");
     226    utility::yat_assert<std::runtime_error>
     227      (data_.rows()==test.rows(),
     228       "NCC::predict test data with incorrect number of rows");
    227229   
    228230    prediction.clone(utility::matrix(centroids_->columns(), test.columns()));       
  • trunk/yat/classifier/SVM.cc

    r1009 r1013  
    126126      sc->max_epochs(max_epochs());
    127127    }
    128     catch (std::bad_cast& e) {
     128    catch (std::bad_cast) {
    129129      std::string str =
    130130        "Error in SVM::make_classifier: DataLookup2D of unexpected class.";
     
    169169        prediction(1,i) = -prediction(0,i);
    170170    }
    171     catch (std::bad_cast& e) {
     171    catch (std::bad_cast) {
    172172      std::string str =
    173173        "Error in SVM::predict: DataLookup2D of unexpected class.";
Note: See TracChangeset for help on using the changeset viewer.