source: trunk/test/ncc_test.cc @ 865

Last change on this file since 865 was 865, checked in by Peter, 16 years ago

changing URL to http://trac.thep.lu.se/trac/yat

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.1 KB
Line 
1// $Id: ncc_test.cc 865 2007-09-10 19:41:04Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér
5
6  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "yat/classifier/IGP.h"
25#include "yat/classifier/MatrixLookup.h"
26#include "yat/classifier/MatrixLookupWeighted.h"
27#include "yat/classifier/NCC.h"
28#include "yat/classifier/Target.h"
29#include "yat/utility/matrix.h"
30#include "yat/statistics/PearsonDistance.h"
31#include "yat/utility/utility.h"
32
33#include <cassert>
34#include <fstream>
35#include <iostream>
36#include <sstream>
37#include <string>
38#include <limits>
39#include <cmath>
40
41using namespace theplu::yat;
42
43int main(const int argc,const char* argv[])
44{ 
45
46  std::ostream* error;
47  if (argc>1 && argv[1]==std::string("-v"))
48    error = &std::cerr;
49  else {
50    error = new std::ofstream("/dev/null");
51    if (argc>1)
52      std::cout << "ncc_test -v : for printing extra information\n";
53  }
54  *error << "testing ncc" << std::endl;
55  bool ok = true;
56
57  std::ifstream is("data/sorlie_centroid_data.txt");
58  utility::matrix data(is,'\t');
59  is.close();
60
61  is.open("data/sorlie_centroid_classes.txt");
62  classifier::Target targets(is);
63  is.close();
64
65  // Generate weight matrix with 0 for missing values and 1 for others.
66  utility::matrix weights(data.rows(),data.columns(),0.0);
67  for(size_t i=0;i<data.rows();++i)
68    for(size_t j=0;j<data.columns();++j)
69      if(!std::isnan(data(i,j)))
70        weights(i,j)=1.0;
71     
72  classifier::MatrixLookupWeighted dataviewweighted(data,weights);
73  statistics::PearsonDistance pearson; 
74  classifier::NCC ncc(dataviewweighted,targets,pearson);
75  ncc.train();
76
77  // Comparing the centroids to stored result
78  is.open("data/sorlie_centroids.txt");
79  utility::matrix centroids(is);
80  is.close();
81
82  if(centroids.rows() != ncc.centroids().rows() ||
83     centroids.columns() != ncc.centroids().columns()) {
84    *error << "Error in the dimensionality of centroids\n";
85    *error << "Nof rows: " << centroids.rows() << " expected: " 
86           << ncc.centroids().rows() << std::endl;
87    *error << "Nof columns: " << centroids.columns() << " expected: " 
88           << ncc.centroids().columns() << std::endl;
89  }
90
91  double slack = 0;
92  for (size_t i=0; i<centroids.rows(); i++){
93    for (size_t j=0; j<centroids.columns(); j++){
94      slack += fabs(centroids(i,j)-ncc.centroids()(i,j));
95    }
96  }
97  slack /= (centroids.columns()*centroids.rows());
98  double slack_bound=2e-7;
99  if (slack > slack_bound || std::isnan(slack)){
100    *error << "Difference to stored centroids too large\n";
101    *error << "slack: " << slack << std::endl;
102    *error << "expected less than " << slack_bound << std::endl;
103    ok = false;
104  }
105
106  utility::matrix prediction;
107  ncc.predict(dataviewweighted,prediction);
108 
109  // Comparing the prediction to stored result
110  is.open("data/sorlie_centroid_predictions.txt");
111  utility::matrix result(is,'\t');
112  is.close();
113
114  slack = 0;
115  for (size_t i=0; i<result.rows(); i++){
116    for (size_t j=0; j<result.columns(); j++){
117        slack += fabs(result(i,j)-prediction(i,j));
118    }
119  }
120  slack /= (result.columns()*result.rows());
121  if (slack > slack_bound || std::isnan(slack)){
122    *error << "Difference to stored prediction too large\n";
123    *error << "slack: " << slack << std::endl;
124    *error << "expected less than " << slack_bound << std::endl;
125    ok = false;
126  }
127 
128
129  // Testing IGP 
130  classifier:: MatrixLookup dataview(data);
131  *error << "testing igp" << std::endl;
132  classifier::IGP igp(dataview,targets,pearson);
133  *error << igp.score() << std::endl;
134
135  if(ok)
136    *error << "OK" << std::endl;
137
138
139  if (error!=&std::cerr)
140    delete error;
141
142  if(ok) 
143    return 0;
144  return -1;
145 
146}
Note: See TracBrowser for help on using the repository browser.