source: trunk/test/ncc_test.cc @ 916

Last change on this file since 916 was 916, checked in by Peter, 16 years ago

Sorry, this commit is a bit too big.

Adding a yat_assert. The yat asserts are turned on by passing a
'-DYAT_DEBUG' flag to the preprocessor when the normal cassert is turned
on. This flag is activated for developers running configure with
--enable-debug. The motivation is that we can use these yat_asserts in
header files, and the yat_asserts will be invisible to normal users
even if they use C asserts.

Added output operator in DataLookup2D and removed output operator in
MatrixLookup.

Removed template function add_values in Averager and its weighted version (AveragerWeighted).

Added a function to AveragerWeighted taking iterators to four ranges.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.4 KB
Line 
1// $Id: ncc_test.cc 916 2007-09-30 00:50:10Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér
5
6  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "yat/classifier/IGP.h"
25#include "yat/classifier/MatrixLookup.h"
26#include "yat/classifier/MatrixLookupWeighted.h"
27#include "yat/classifier/NCC.h"
28#include "yat/classifier/Target.h"
29#include "yat/utility/matrix.h"
30#include "yat/statistics/vector_distance_ptr.h"
31#include "yat/statistics/pearson_vector_distance.h"
32#include "yat/utility/utility.h"
33
34#include <cassert>
35#include <fstream>
36#include <iostream>
37#include <sstream>
38#include <string>
39#include <limits>
40#include <cmath>
41
42using namespace theplu::yat;
43
44int main(const int argc,const char* argv[])
45{ 
46
47  std::ostream* error;
48  if (argc>1 && argv[1]==std::string("-v"))
49    error = &std::cerr;
50  else {
51    error = new std::ofstream("/dev/null");
52    if (argc>1)
53      std::cout << "ncc_test -v : for printing extra information\n";
54  }
55  *error << "testing ncc" << std::endl;
56  bool ok = true;
57
58  classifier::MatrixLookupWeighted ml(4,4);
59  std::vector<std::string> vec(4, "pos");
60  vec[3]="bjds";
61  classifier::Target target(vec);
62  statistics::vector_distance_lookup_weighted_ptr dist=
63    statistics::vector_distance<statistics::pearson_vector_distance_tag>;
64  classifier::NCC ncctmp(ml,target,dist);
65  *error << "training...\n";
66  ncctmp.train();
67 
68  std::ifstream is("data/sorlie_centroid_data.txt");
69  utility::matrix data(is,'\t');
70  is.close();
71
72  is.open("data/sorlie_centroid_classes.txt");
73  classifier::Target targets(is);
74  is.close();
75
76  // Generate weight matrix with 0 for missing values and 1 for others.
77  utility::matrix weights(data.rows(),data.columns(),0.0);
78  utility::nan(data,weights);
79     
80  classifier::MatrixLookupWeighted dataviewweighted(data,weights);
81  statistics::vector_distance_lookup_weighted_ptr distance=
82    statistics::vector_distance<statistics::pearson_vector_distance_tag>;
83  classifier::NCC ncc(dataviewweighted,targets,distance);
84  *error << "training...\n";
85  ncc.train();
86
87  // Comparing the centroids to stored result
88  is.open("data/sorlie_centroids.txt");
89  utility::matrix centroids(is);
90  is.close();
91
92  if(centroids.rows() != ncc.centroids().rows() ||
93     centroids.columns() != ncc.centroids().columns()) {
94    *error << "Error in the dimensionality of centroids\n";
95    *error << "Nof rows: " << centroids.rows() << " expected: " 
96           << ncc.centroids().rows() << std::endl;
97    *error << "Nof columns: " << centroids.columns() << " expected: " 
98           << ncc.centroids().columns() << std::endl;
99  }
100
101  double slack = 0;
102  for (size_t i=0; i<centroids.rows(); i++){
103    for (size_t j=0; j<centroids.columns(); j++){
104      slack += fabs(centroids(i,j)-ncc.centroids()(i,j));
105    }
106  }
107  slack /= (centroids.columns()*centroids.rows());
108  double slack_bound=2e-7;
109  if (slack > slack_bound || std::isnan(slack)){
110    *error << "Difference to stored centroids too large\n";
111    *error << "slack: " << slack << std::endl;
112    *error << "expected less than " << slack_bound << std::endl;
113    ok = false;
114  }
115
116  *error << "prediction...\n";
117  utility::matrix prediction;
118  ncc.predict(dataviewweighted,prediction);
119 
120  // Comparing the prediction to stored result
121  is.open("data/sorlie_centroid_predictions.txt");
122  utility::matrix result(is,'\t');
123  is.close();
124
125  slack = 0;
126  for (size_t i=0; i<result.rows(); i++){
127    for (size_t j=0; j<result.columns(); j++){
128        slack += fabs(result(i,j)-prediction(i,j));
129    }
130  }
131  slack /= (result.columns()*result.rows());
132  if (slack > slack_bound || std::isnan(slack)){
133    *error << "Difference to stored prediction too large\n";
134    *error << "slack: " << slack << std::endl;
135    *error << "expected less than " << slack_bound << std::endl;
136    ok = false;
137  }
138 
139
140  if(ok)
141    *error << "OK" << std::endl;
142
143
144  if (error!=&std::cerr)
145    delete error;
146
147  if(ok) 
148    return 0;
149  return -1;
150 
151}
Note: See TracBrowser for help on using the repository browser.