source: branches/0.13-stable/test/roc.cc @ 3437

Last change on this file since 3437 was 3437, checked in by Peter, 7 years ago

update copyright years

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 11.8 KB
Line 
1// $Id: roc.cc 3437 2015-11-20 04:44:02Z peter $
2
3/*
4  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2011, 2012, 2013, 2015 Peter Johansson
6
7  This file is part of the yat library, http://dev.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 3 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with yat. If not, see <http://www.gnu.org/licenses/>.
21*/
22
23#include <config.h>
24
25#include "Suite.h"
26
27#include "yat/classifier/DataLookupWeighted1D.h"
28#include "yat/random/random.h"
29#include "yat/classifier/Target.h"
30#include "yat/statistics/Averager.h"
31#include "yat/statistics/Fisher.h"
32#include "yat/statistics/ROC.h"
33#include "yat/statistics/utility.h"
34#include "yat/utility/Vector.h"
35
36#include <gsl/gsl_cdf.h>
37
38#include <cassert>
39#include <cmath>
40#include <deque>
41#include <fstream>
42#include <iostream>
43#include <set>
44#include <vector>
45
46using namespace theplu::yat;
47
48void test_empty(test::Suite&);
49void test_p_approx_weighted(test::Suite& suite);
50void test_p_approx_with_ties(test::Suite& suite);
51void test_p_approx(test::Suite& suite);
52void test_p_double_weights(test::Suite& suite);
53void test_p_exact(test::Suite& suite);
54void test_p_exact_weighted(test::Suite& suite);
55void test_p_exact_with_ties(test::Suite& suite);
56void test_p_with_weights(test::Suite& suite);
57void test_p_with_weights_and_ties(test::Suite& suite);
58void test_remove(test::Suite& suite);
59void test_ties(test::Suite& suite);
60
61int main(int argc, char* argv[])
62{
63  test::Suite suite(argc, argv);
64
65  suite.err() << "testing ROC" << std::endl;
66  utility::Vector value(31);
67  std::vector<std::string> label(31,"negative");
68  for (size_t i=0; i<16; i++)
69    label[i] = "positive";
70  classifier::Target target(label);
71  for (size_t i=0; i<value.size(); i++)
72    value(i)=i;
73  statistics::ROC roc;
74  add(roc, value.begin(), value.end(), target);
75  double area = roc.area();
76  if (!suite.equal(area,0.0)){
77    suite.err() << "test_roc: area is " << area << " should be 0.0"
78                << std::endl;
79    suite.add(false);
80  }
81  target.set_binary(0,false);
82  target.set_binary(1,true);
83  roc.reset();
84  add(roc, value.begin(), value.end(), target);
85  area = roc.area();
86  if (!suite.equal(area,1.0)){
87    suite.err() << "test_roc: area is " << area << " should be 1.0"
88                << std::endl;
89    suite.add(false);
90  }
91
92  double p = roc.p_right();
93  double p2 = roc.p_value();
94  double p_matlab = 0.00000115;
95  if (!(p/p_matlab < 1.01 && p/p_matlab > 0.99)){
96    suite.err() << "get_p_approx: p-value not correct" << std::endl;
97    suite.err() << p << " expected " << p_matlab << std::endl;
98    suite.add(false);
99  }
100  if (!(p2==2*p)) {
101    suite.add(false);
102    suite.err() << "Two-sided P-value should equal 2 * p_right.\n";
103  }
104  roc.minimum_size() = 20;
105  p = roc.p_right();
106  p2 = roc.p_value();
107  if (!( p < 1e-8 && p > 1e-9) ){
108    suite.err() << "get_p_exact: p-value not correct" << std::endl;
109    suite.add(false);
110  }
111  if (!( p2==2*p)) {
112    suite.add(false);
113    suite.err() << "Two-sided P-value should equal 2 * p_right.\n";
114  }
115
116  classifier::DataLookupWeighted1D dlw(target.size(),1.3);
117  add(roc, dlw.begin(), dlw.end(), target);
118  test_ties(suite);
119  test_p_approx_with_ties(suite);
120  test_p_exact_with_ties(suite);
121  test_p_approx(suite);
122  test_p_exact(suite);
123  test_empty(suite);
124  test_p_with_weights(suite);
125  test_remove(suite);
126  return suite.return_value();
127}
128
129
130void test_p_exact_with_ties(test::Suite& suite)
131{
132  suite.out() << "test p exact with ties\n";
133  statistics::ROC roc;
134  /*
135    +++-- 6
136    ++-+- 5 4.5 *** our case ***
137    +-++- 4 4.5
138    ++--+ 4 3.5
139    +-+-+ 3 3.5
140    +--++ 2 2
141    -+++- 3 3
142    -++-+ 2 2
143    -+-++ 1 0.5 *** our second case ***
144    --+++ 0 0.5
145   */
146  roc.add(2, true);
147  roc.add(1, true);
148  roc.add(1, false);
149  roc.add(0, true);
150  roc.add(-1, false);
151  roc.area();
152  if (!suite.equal(roc.p_right(), 3.0/10.0)) {
153    suite.add(false);
154    suite.out() << "  p_right: expected 0.3\n";
155  }
156  else
157    suite.add(true);
158  if (!suite.equal(roc.p_value(), 5.0/10.0)) {
159    suite.add(false);
160    suite.out() << "  (two-sided) p_value: expected 0.5\n";
161  }
162  else
163    suite.add(true);
164
165  suite.out() << "test p exact with ties II\n";
166  roc.reset();
167  roc.add(2, false);
168  roc.add(1, true);
169  roc.add(1, false);
170  roc.add(0, true);
171  roc.add(-1, true);
172  suite.add(suite.equal(roc.area(), 0.5/6));
173  if (!suite.add(suite.equal(roc.p_right(), 10.0/10.0)))
174    suite.out() << "  p_right: expected 0.3\n";
175  if (!suite.add(suite.equal(roc.p_value(), 3.0/10.0)))
176    suite.out() << "  (two-sided) p_value: expected 0.5\n";
177}
178
179
180void test_p_approx_with_ties(test::Suite& suite)
181{
182  suite.out() << "test p approx with ties\n";
183  statistics::ROC roc;
184  for (size_t i=0; i<100; ++i) {
185    roc.add(1, i<60);
186    roc.add(0, i<40);
187  }
188  suite.add(suite.equal(roc.area(), 0.6));
189  // Having only two data values, 0 and 1, data can be represented as
190  // a 2x2 contigency table, and ROC test is same as Fisher's exact
191  // test.
192  statistics::Fisher fisher;
193  fisher.oddsratio(60, 40, 40, 60);
194  suite.add(suite.equal_fix(roc.p_value(), fisher.p_value(), 0.0002));
195}
196
197void test_ties(test::Suite& suite)
198{
199  suite.out() << "test ties\n";
200  statistics::ROC roc;
201  for (size_t i=0; i<20; ++i)
202    roc.add(10.0, i<10);
203  if (!suite.add(suite.equal(roc.area(), 0.5))) {
204    suite.err() << "error: roc with ties: area: " << roc.area() << "\n";
205  }
206}
207
208void test_p_exact(test::Suite& suite)
209{
210  suite.out() << "test_p_exact\n";
211  statistics::ROC roc;
212  for (size_t i=0; i<9; ++i)
213    roc.add(i, i<5);
214  if (roc.p_right()<0.5) {
215    suite.add(false);
216    suite.err() << "error: expected p-value>0.5\n  found: "
217                << roc.p_right() << "\n";
218  }
219}
220
221
222void test_p_approx(test::Suite& suite)
223{
224  suite.out() << "test_p_approx\n";
225  statistics::ROC roc;
226  for (size_t i=0; i<100; ++i)
227    roc.add(i, i<50);
228  if (roc.p_right()<0.5) {
229    suite.add(false);
230    suite.err() << "error: expected p-value>0.5\n  found: "
231                << roc.p_right() << "\n";
232  }
233  if (roc.p_value() > 1.0) {
234    suite.err() << "error: expected p-value <= 1\n    found: "
235                << roc.p_value() << "\n";
236    suite.add(false);
237  }
238}
239
240
241void test_empty(test::Suite& suite)
242{
243  suite.err() << "test empty\n";
244  // testing bug #669
245  statistics::ROC roc;
246  roc.p_value();
247  roc.area();
248  suite.err() << "test empty done\n";
249}
250
251
252void test_p_with_weights(test::Suite& suite)
253{
254  suite.out() << "test p with weights\n";
255  test_p_double_weights(suite);
256  test_p_exact_weighted(suite);
257  test_p_approx_weighted(suite);
258  test_p_with_weights_and_ties(suite);
259}
260
261
262void test_p_with_weights_and_ties(test::Suite& suite)
263{
264  suite.out() << "test p with weights and ties\n";
265  statistics::ROC roc;
266  roc.add(10, true, 1.0);
267  roc.add(10, false, 3.0);
268  roc.add(20, true, 2.0);
269  roc.add(30, true, 1.0);
270  if (!suite.equal(roc.area(), 0.875)) {
271    suite.add(false);
272    suite.out() << "roc area: " << roc.area() << "\n";
273  }
274  double p = roc.p_right();
275  if (!suite.equal(p, 8.0/24.0)) {
276    suite.add(false);
277    suite.out() << "p_right() failed\n";
278  }
279  p = roc.p_value();
280  if (!suite.equal(p, (8.0+6.0)/24.0)) {
281    suite.add(false);
282    suite.out() << "p_value() failed\n";
283  }
284}
285
286
287void test_p_double_weights(test::Suite& suite)
288{
289  suite.out() << "test p with double weights\n";
290  std::vector<double> w(5,1.0);
291  w[0]=0.1;
292  w[4]=10;
293  std::vector<double> x(5);
294  for (size_t i=0; i<x.size(); ++i)
295    x[i] = i;
296  statistics::ROC roc;
297  statistics::ROC roc2;
298  for (size_t i=0; i<x.size(); ++i) {
299    roc.add(x[i], i<2, w[i]);
300    roc2.add(x[i], i<2, 2*w[i]);
301  }
302  if (!suite.equal(roc.area(), roc2.area())) {
303    suite.add(false);
304    suite.err() << "area failed\n";
305  }
306  if (!suite.equal(roc.p_right(), roc2.p_right())) {
307    suite.add(false);
308    suite.err() << "exact p failed\n";
309  }
310  roc.minimum_size() = 0;
311  roc2.minimum_size() = 0;
312  if (!suite.equal(roc.p_right(), roc2.p_right())) {
313    suite.add(false);
314    suite.err() << "approximative p failed\n";
315  }
316}
317
318
319void test_p_exact_weighted(test::Suite& suite)
320{
321  suite.out() << "test p exact weighted\n";
322  std::vector<double> x(10);
323  std::vector<std::pair<bool, double> > w(10);
324  for (size_t i=0; i<x.size(); ++i) {
325    x[i] = i;
326    w[i].first = false;
327    w[i].second = 1.0;
328  }
329  w[3].first = true;
330  w[7].first = true;
331  w[8].first = true;
332  w[1].second = 10;
333  w[7].second = 10;
334  w[9].second = 0.1;
335
336  statistics::ROC roc;
337  for (size_t i=0; i<x.size(); ++i)
338    roc.add(x[i], w[i].first, w[i].second);
339  roc.minimum_size() = 100;
340
341  std::sort(w.begin(), w.end());
342  unsigned long perm = 0;
343  unsigned long k = 0;
344  unsigned long k2 = 0;
345  while (true) {
346    ++perm;
347    statistics::ROC roc2;
348    for (size_t i=0; i<x.size(); ++i)
349      roc2.add(x[i], w[i].first, w[i].second);
350    if (roc2.area() >= roc.area())
351      ++k;
352    if (roc2.area() <= 1-roc.area()+1e-10)
353      ++k2;
354
355    if (!next_permutation(w.begin(), w.end()))
356      break;
357  }
358  double p_value = roc.p_right();
359  if (!suite.add(suite.equal(p_value, static_cast<double>(k)/perm))) {
360    suite.out() << "area: " << roc.area() << "\n"
361                << perm << " permutations of which\n"
362                << k << " with larger (or equal) area "
363                << "corresponding to P=" << static_cast<double>(k)/perm << "\n"
364                << "p_right() returned: " << p_value
365                << "\n";
366  }
367  p_value = roc.p_value();
368  if (!suite.add(suite.equal(p_value, static_cast<double>(k+k2)/perm))) {
369    suite.out() << "area: " << roc.area() << "\n"
370                << perm << " permutations of which\n"
371                << k << " with larger (or equal) area and\n"
372                << k2 << " with smaller (or equal) area\n"
373                << "corresponding to P="
374                << static_cast<double>(k+k2)/perm << "\n"
375                << "p_value() returned: " << p_value
376                << "\n";
377  }
378}
379
380
381void test_p_approx_weighted(test::Suite& suite)
382{
383  suite.out() << "test p approx weighted\n";
384  std::vector<double> x(200);
385  std::vector<double> w(200, 1.0);
386  std::deque<bool> label(200);
387
388  for (size_t i=0; i<x.size(); ++i) {
389    x[i] = i;
390    label[i] = i>30 && i<70;
391    if (i<100)
392      w[i] = 100.0 / (100+i);
393    else
394      w[i] = 0.0001;
395  }
396
397  statistics::ROC roc;
398  for (size_t i=0; i<x.size(); ++i)
399    roc.add(x[i], label[i], w[i]);
400  roc.minimum_size() = 0;
401  double p = roc.p_right();
402
403  std::set<size_t> checkpoints;
404  size_t perm = 100000;
405  checkpoints.insert(10);
406  checkpoints.insert(100);
407  checkpoints.insert(1000);
408  checkpoints.insert(10000);
409  checkpoints.insert(perm);
410  statistics::Averager averager;
411  for (size_t i=1; i<=perm; ++i) {
412    theplu::yat::random::random_shuffle(x.begin(), x.end());
413    statistics::ROC roc2;
414    for (size_t j=0; j<x.size(); ++j)
415      roc2.add(x[j], label[j], w[j]);
416    if (roc2.area()>=roc.area())
417      averager.add(1.0);
418    else
419      averager.add(0.0);
420    if (checkpoints.find(i)!=checkpoints.end()) {
421      if (gsl_cdf_binomial_P(averager.sum_x(), p, averager.n())<1e-10 ||
422          gsl_cdf_binomial_Q(averager.sum_x(), p, averager.n())<1e-10) {
423        suite.err() << "error: approx p value and permutation p-value "
424                    << "deviate more than expected\n"
425                    << "area: " << roc.area() << "\n"
426                    << "approx p: " << p << "\n"
427                    << "permutations: " << averager.n() << "\n"
428                    << "successful: " << averager.sum_x() << "\n"
429                    << "corresponds to P=" << averager.mean() << "\n";
430        suite.add(false);
431        return;
432      }
433    }
434  }
435}
436
437void test_remove(test::Suite& suite)
438{
439  using statistics::ROC;
440  ROC roc;
441  roc.add(1, true);
442  roc.add(2, false);
443  ROC roc2(roc);
444  if (!suite.add(suite.equal(roc.area(), roc2.area())))
445    suite.out() << "test_remove failed: copy failed\n";
446  roc.add(2.3, true, 1.2);
447  try {
448    roc.remove(2.3, true, 1.2);
449  }
450  catch (std::runtime_error& e) {
451    suite.add(false);
452    suite.out() << "exception what(): " << e.what() << "\n";
453  }
454  if (!suite.add(suite.equal(roc.area(), roc2.area())))
455    suite.out() << "test remove failed\n";
456  try {
457    roc.remove(2, true);
458    suite.out() << "no exception thrown\n";
459    suite.add(false);
460  }
461  catch (std::runtime_error& e) {
462    suite.add(true);
463  }
464}
Note: See TracBrowser for help on using the repository browser.