source: trunk/test/ncc_test.cc @ 1000

Last change on this file since 1000 was 1000, checked in by Jari Häkkinen, 15 years ago

trac moved to new location.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.9 KB
Line 
1// $Id: ncc_test.cc 1000 2007-12-23 20:09:15Z jari $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér
5
6  This file is part of the yat library, http://trac.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "yat/classifier/IGP.h"
25#include "yat/classifier/Kernel_MEV.h"
26#include "yat/classifier/KernelLookup.h"
27#include "yat/classifier/MatrixLookup.h"
28#include "yat/classifier/MatrixLookupWeighted.h"
29#include "yat/classifier/NCC.h"
30#include "yat/classifier/PolynomialKernelFunction.h"
31#include "yat/classifier/Target.h"
32#include "yat/utility/matrix.h"
33#include "yat/statistics/pearson_vector_distance.h"
34#include "yat/utility/utility.h"
35
36#include <cassert>
37#include <fstream>
38#include <iostream>
39#include <stdexcept>
40#include <sstream>
41#include <string>
42#include <limits>
43#include <cmath>
44
45using namespace theplu::yat;
46
47int main(const int argc,const char* argv[])
48{ 
49
50  std::ostream* error;
51  if (argc>1 && argv[1]==std::string("-v"))
52    error = &std::cerr;
53  else {
54    error = new std::ofstream("/dev/null");
55    if (argc>1)
56      std::cout << "ncc_test -v : for printing extra information\n";
57  }
58  *error << "testing ncc" << std::endl;
59  bool ok = true;
60
61  classifier::MatrixLookup ml(4,4);
62  std::vector<std::string> vec(4, "pos");
63  vec[3]="bjds";
64  classifier::Target target(vec);
65  classifier::NCC<statistics::pearson_vector_distance_tag> ncctmp(ml,target);
66  *error << "training...\n";
67  ncctmp.train();
68 
69  std::ifstream is("data/sorlie_centroid_data.txt");
70  utility::matrix data(is,'\t');
71  is.close();
72
73  is.open("data/sorlie_centroid_classes.txt");
74  classifier::Target targets(is);
75  is.close();
76
77  // Generate weight matrix with 0 for missing values and 1 for others.
78  utility::matrix weights(data.rows(),data.columns(),0.0);
79  utility::nan(data,weights);
80     
81  classifier::MatrixLookupWeighted dataviewweighted(data,weights);
82  classifier::NCC<statistics::pearson_vector_distance_tag> ncc(dataviewweighted,targets);
83  *error << "training...\n";
84  ncc.train();
85
86  // Comparing the centroids to stored result
87  is.open("data/sorlie_centroids.txt");
88  utility::matrix centroids(is);
89  is.close();
90
91  if(centroids.rows() != ncc.centroids().rows() ||
92     centroids.columns() != ncc.centroids().columns()) {
93    *error << "Error in the dimensionality of centroids\n";
94    *error << "Nof rows: " << centroids.rows() << " expected: " 
95           << ncc.centroids().rows() << std::endl;
96    *error << "Nof columns: " << centroids.columns() << " expected: " 
97           << ncc.centroids().columns() << std::endl;
98  }
99
100  double slack = 0;
101  for (size_t i=0; i<centroids.rows(); i++){
102    for (size_t j=0; j<centroids.columns(); j++){
103      slack += fabs(centroids(i,j)-ncc.centroids()(i,j));
104    }
105  }
106  slack /= (centroids.columns()*centroids.rows());
107  double slack_bound=2e-7;
108  if (slack > slack_bound || std::isnan(slack)){
109    *error << "Difference to stored centroids too large\n";
110    *error << "slack: " << slack << std::endl;
111    *error << "expected less than " << slack_bound << std::endl;
112    ok = false;
113  }
114
115  *error << "prediction...\n";
116  utility::matrix prediction;
117  ncc.predict(dataviewweighted,prediction);
118 
119  // Comparing the prediction to stored result
120  is.open("data/sorlie_centroid_predictions.txt");
121  utility::matrix result(is,'\t');
122  is.close();
123
124  slack = 0;
125  for (size_t i=0; i<result.rows(); i++){
126    for (size_t j=0; j<result.columns(); j++){
127        slack += fabs(result(i,j)-prediction(i,j));
128    }
129  }
130  slack /= (result.columns()*result.rows());
131  if (slack > slack_bound || std::isnan(slack)){
132    *error << "Difference to stored prediction too large\n";
133    *error << "slack: " << slack << std::endl;
134    *error << "expected less than " << slack_bound << std::endl;
135    ok = false;
136  }
137
138  // testing rejection of KernelLookups
139  classifier::PolynomialKernelFunction kf;
140  classifier::Kernel_MEV kernel(ml,kf);
141  classifier::DataLookup2D* dl_kernel = new classifier::KernelLookup(kernel); 
142  try {
143    ok=0; // should catch error here
144    *error << ncc.make_classifier(*dl_kernel,target) << std::endl;
145  }
146  catch (std::runtime_error) {   
147    *error << "caught expected bad cast runtime_error" << std::endl;
148    ok=1;
149  }
150  try {
151    ok=0; // should catch error here
152    ncc.predict(*dl_kernel,prediction); 
153  }
154  catch (std::runtime_error) {   
155    *error << "caught expected bad cast runtime_error" << std::endl;
156    ok=1;
157  }
158  delete dl_kernel;
159 
160  if(ok)
161    *error << "OK" << std::endl;
162
163
164  if (error!=&std::cerr)
165    delete error;
166
167  if(ok) 
168    return 0;
169  return -1; 
170}
Note: See TracBrowser for help on using the repository browser.