source: trunk/test/knn_test.cc @ 1107

Last change on this file since 1107 was 1107, checked in by Markus Ringnér, 15 years ago

Ticket #259 fixed for KNN

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.6 KB
Line 
1// $Id: knn_test.cc 1107 2008-02-19 15:23:52Z markus $
2
3/*
4  Copyright (C) 2007 Peter Johansson, Markus Ringnér
5
6  This file is part of the yat library, http://trac.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
#include "yat/classifier/KNN.h"
#include "yat/classifier/MatrixLookup.h"
#include "yat/classifier/MatrixLookupWeighted.h"
#include "yat/statistics/EuclideanDistance.h"
#include "yat/utility/matrix.h"

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <list>
#include <string>
#include <vector>
37
38
39using namespace theplu::yat;
40
// Mean absolute element-wise difference between two matrices:
//   sum_ij |a(i,j) - b(i,j)| / (rows * columns)
//
// Matrix is any type providing rows(), columns() and a const
// operator()(row, column) yielding a value convertible to double
// (e.g. utility::matrix).
//
// Returns 0 for two empty matrices. Returns NaN when the dimensions
// differ; a caller comparing the result against a bound (and checking
// std::isnan) then reports a failure instead of indexing b out of range.
template <class Matrix>
double deviation(const Matrix& a, const Matrix& b)
{
  if (a.rows() != b.rows() || a.columns() != b.columns())
    return std::numeric_limits<double>::quiet_NaN();

  const std::size_t n_elements = a.rows() * a.columns();
  if (n_elements == 0)
    return 0.0;  // avoid 0/0 for empty input

  double sum = 0.0;
  for (std::size_t i = 0; i < a.rows(); ++i)
    for (std::size_t j = 0; j < a.columns(); ++j)
      sum += std::fabs(a(i, j) - b(i, j));
  return sum / static_cast<double>(n_elements);
}
51
52int main(const int argc,const char* argv[])
53
54{ 
55  std::ostream* error;
56  if (argc>1 && argv[1]==std::string("-v"))
57    error = &std::cerr;
58  else {
59    error = new std::ofstream("/dev/null");
60    if (argc>1)
61      std::cout << "knn_test -v : for printing extra information\n";
62  }
63  *error << "testing knn" << std::endl;
64  bool ok = true;
65
66  ////////////////////////////////////////////////////////////////
67  // A test of training and predictions using unweighted data
68  ////////////////////////////////////////////////////////////////
69  *error << "test of predictions using unweighted training and test data\n";
70  utility::matrix data1(3,4);
71  for(size_t i=0;i<3;i++) {
72    data1(i,0)=3-i;
73    data1(i,1)=5-i;
74    data1(i,2)=i+1;
75    data1(i,3)=i+3;
76  }
77  std::vector<std::string> vec1(4, "pos");
78  vec1[0]="neg";
79  vec1[1]="neg";
80 
81  classifier::MatrixLookup ml1(data1);
82  classifier::Target target1(vec1);
83 
84  classifier::KNN<statistics::EuclideanDistance> knn1(ml1,target1);
85  knn1.k(3);
86  knn1.train();
87  utility::matrix prediction1;
88  knn1.predict(ml1,prediction1);
89  double slack_bound=2e-7;
90  utility::matrix result1(2,4);
91  result1(0,0)=result1(0,1)=result1(1,2)=result1(1,3)=2.0/3.0;
92  result1(0,2)=result1(0,3)=result1(1,0)=result1(1,1)=1.0/3.0;
93  double slack = deviation(prediction1,result1); 
94  if (slack > slack_bound || std::isnan(slack)){
95    *error << "Difference to expected prediction too large\n";
96    *error << "slack: " << slack << std::endl;
97    *error << "expected less than " << slack_bound << std::endl;
98    ok = false;
99  }
100 
101
102  ////////////////////////////////////////////////////////////////
103  // A test of training unweighted and test weighted
104  ////////////////////////////////////////////////////////////////
105  *error << "test of predictions using unweighted training and weighted test data\n";
106  utility::matrix weights1(3,4,1.0);
107  weights1(2,0)=0;
108  classifier::MatrixLookupWeighted mlw1(data1,weights1);
109  knn1.predict(mlw1,prediction1); 
110  result1(0,0)=1.0/3.0;
111  result1(1,0)=2.0/3.0;
112  slack = deviation(prediction1,result1);
113  if (slack > slack_bound || std::isnan(slack)){
114    *error << "Difference to expected prediction too large\n";
115    *error << "slack: " << slack << std::endl;
116    *error << "expected less than " << slack_bound << std::endl;
117    ok = false;
118  } 
119
120  ////////////////////////////////////////////////////////////////
121  // A test of training and test both weighted
122  ////////////////////////////////////////////////////////////////
123  *error << "test of predictions using weighted training and test data\n";
124  weights1(0,1)=0;
125  utility::matrix weights2(3,4,1.0);
126  weights2(2,3)=0;
127  classifier::MatrixLookupWeighted mlw2(data1,weights2);
128  classifier::KNN<statistics::EuclideanDistance> knn2(mlw2,target1);
129  knn2.k(3);
130  knn2.train();
131  knn2.predict(mlw1,prediction1); 
132  result1(0,1)=1.0/3.0;
133  result1(1,1)=2.0/3.0;
134  slack = deviation(prediction1,result1);
135  if (slack > slack_bound || std::isnan(slack)){
136    *error << "Difference to expected prediction too large\n";
137    *error << "slack: " << slack << std::endl;
138    *error << "expected less than " << slack_bound << std::endl;
139    ok = false;
140  } 
141
142
143  if(!ok) {
144    *error << "knn_test failed" << std::endl;
145  }
146  else {
147    *error << "OK" << std::endl;
148  }
149  if (error!=&std::cerr)
150    delete error;
151  if (ok=true) 
152    return 0;
153  return -1;
154}
155
156
Note: See TracBrowser for help on using the repository browser.