source: branches/0.13-stable/yat/statistics/ROC.cc @ 3433

Last change on this file since 3433 was 3433, checked in by Peter, 7 years ago

fixes #846

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 6.2 KB
Line 
1// $Id: ROC.cc 3433 2015-10-27 05:40:36Z peter $
2
3/*
4  Copyright (C) 2004, 2005 Peter Johansson
5  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson
6  Copyright (C) 2011, 2012, 2013 Peter Johansson
7
8  This file is part of the yat library, http://dev.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 3 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with yat. If not, see <http://www.gnu.org/licenses/>.
22*/
23
24#include <config.h>
25
26#include "ROC.h"
27#include "AUC.h"
28
29#include "yat/utility/Exception.h"
30
31#include <gsl/gsl_cdf.h>
32
33#include <algorithm>
34#include <cassert>
35#include <cmath>
36#include <limits>
37#include <sstream>
38#include <utility>
39
40namespace theplu {
41namespace yat {
42namespace statistics {
43
44  ROC::ROC(void)
45    :minimum_size_(10)
46  {
47    reset();
48  }
49
50
51  void ROC::add(double x, bool target, double w)
52  {
53    if (!w)
54      return;
55    ROC::Map::value_type element(x, std::make_pair(target, w));
56    ROC::Map::iterator lower = multimap_.lower_bound(x);
57    if (lower!=multimap_.end() && lower->first == x)
58      has_ties_ = true;
59    multimap_.insert(lower, element);
60    if (target)
61      pos_weights_.add(w);
62    else
63      neg_weights_.add(w);
64    area_ = std::numeric_limits<double>::quiet_NaN();
65  }
66
67
68  double ROC::area(void) const
69  {
70    if (std::isnan(area_)){
71      AUC auc(false);
72      area_=auc.score(multimap_);
73    }
74    return area_;
75  }
76
77
78  double ROC::get_p_approx(double x) const
79  {
80    size_t n_pos = nof_points(pos_weights_);
81    size_t n_neg = nof_points(neg_weights_);
82    size_t nof_samples = n_pos + n_neg;
83    // make x standard normal
84    x -= 0.5;
85    // Not integrating from the middle of the bin, but from the inner edge.
86    if (x>0)
87      x -= 0.5/(n_pos*n_neg);
88    else if(x<0)
89      x += 0.5/(n_pos*n_neg);
90    else
91      return 0.5;
92    double var = 1.0+nof_samples;
93    if (has_ties_) {
94      double correction = 0;
95      Map::const_iterator first = multimap_.begin();
96      Map::const_iterator last = multimap_.begin();
97      while (first!=multimap_.end()) {
98        size_t n = 0;
99        while (first->first == last->first) {
100          ++n;
101          ++last;
102        }
103        correction += n * (n-1) * (n+1);
104        first = last;
105      }
106      /*
107        mn(N+1)/12-[mn/(12N(N-1)) * sum(t(t-1)(t+1))] =
108        mn/12 [ N+1 - 1/(N(N-1)) * sum(t(t-1)(t+1)) ]
109      */
110      var -= correction/(nof_samples * (nof_samples-1));
111    }
112    var = var / (12*n_pos*n_neg);
113    return gsl_cdf_gaussian_Q(x, std::sqrt(var));
114  }
115
116
117  bool ROC::is_weighted(void) const
118  {
119    return pos_weights_.variance() || neg_weights_.variance()
120      || pos_weights_.mean() != neg_weights_.mean();
121  }
122
123
124  unsigned int& ROC::minimum_size(void)
125  {
126    return minimum_size_;
127  }
128
129
130  const unsigned int& ROC::minimum_size(void) const
131  {
132    return minimum_size_;
133  }
134
135
136  double ROC::n(void) const
137  {
138    return n_pos()+n_neg();
139  }
140
141
142  double ROC::n_neg(void) const
143  {
144    return neg_weights_.sum_x();
145  }
146
147
148  double ROC::n_pos(void) const
149  {
150    return pos_weights_.sum_x();
151  }
152
153
154  size_t ROC::nof_points(const Averager& a) const
155  {
156    return static_cast<size_t>(a.sum_x()*a.sum_x()/a.sum_xx());
157  }
158
159
160  double ROC::p_exact_left(double area) const
161  {
162    if (is_weighted())
163      return p_left_weighted(area);
164    return p_exact_with_ties(multimap_.rbegin(), multimap_.rend(),
165                             (1-area)*pos_weights_.n()*neg_weights_.n(),
166                             pos_weights_.n(), neg_weights_.n());
167  }
168
169
170  double ROC::p_exact_right(double area) const
171  {
172    if (is_weighted())
173      return p_right_weighted(area);
174    return p_exact_with_ties(multimap_.begin(), multimap_.end(),
175                             area*pos_weights_.n()*neg_weights_.n(),
176                             pos_weights_.n(), neg_weights_.n());
177  }
178
179
180  double ROC::p_left_weighted(double area) const
181  {
182    return count(utility::pair_first_iterator(multimap_.begin()),
183                 utility::pair_first_iterator(multimap_.end()), 1-area);
184  }
185
186
187  double ROC::p_right_weighted(double area) const
188  {
189    return count(utility::pair_first_iterator(multimap_.rbegin()),
190                 utility::pair_first_iterator(multimap_.rend()), area);
191  }
192
193
194  double ROC::p_left() const
195  {
196    if (std::isnan(area()))
197      return std::numeric_limits<double>::quiet_NaN();
198    if (use_exact_method())
199      return p_exact_left(area());
200    return get_p_approx(1-area());
201  }
202
203
204  double ROC::p_right() const
205  {
206    if (std::isnan(area()))
207      return std::numeric_limits<double>::quiet_NaN();
208    if (use_exact_method())
209      return p_exact_right(area());
210    return get_p_approx(area());
211  }
212
213
214  double ROC::p_value() const
215  {
216    if (std::isnan(area()))
217      return std::numeric_limits<double>::quiet_NaN();
218    if (use_exact_method()) {
219      double p = 0;
220      double abs_area = std::max(area(), 1-area());
221      p = p_exact_right(abs_area);
222      if (has_ties_) {
223        p += p_exact_left(1.0 - abs_area);
224      }
225      else
226        p *= 2.0;
227      // avoid double counting when area is 0.5
228      return std::min(p, 1.0);
229    }
230    return 2*get_p_approx(std::max(area(), 1-area()));
231  }
232
233
234  double ROC::p_value_one_sided() const
235  {
236    return p_right();
237  }
238
239
240  void ROC::remove(double value, bool target, double weight)
241  {
242    if (!weight)
243      return;
244    std::pair<Map::iterator, Map::iterator> iter = multimap_.equal_range(value);
245    while (iter.first!=iter.second) {
246      if (iter.first->second.first==target && iter.first->second.first==target){
247        multimap_.erase(iter.first);
248        if (target)
249          pos_weights_.add(weight, -1);
250        else
251          neg_weights_.add(weight, -1);
252        area_ = std::numeric_limits<double>::quiet_NaN();
253        return;
254      }
255      ++iter.first;
256    }
257    std::stringstream ss;
258    ss << "ROC::remove(" << value << "," << target << "," << weight << "): "
259       << "no such element";
260    throw utility::runtime_error(ss.str());
261  }
262
263
264  void ROC::reset(void)
265  {
266    area_ = std::numeric_limits<double>::quiet_NaN();
267    has_ties_ = false;
268    neg_weights_.reset();
269    pos_weights_.reset();
270    multimap_.clear();
271  }
272
273
274  bool ROC::use_exact_method(void) const
275  {
276    return (n_pos() < minimum_size_) || (n_neg() < minimum_size_);
277  }
278
279
280  ROC::Weights::Weights(void)
281    : small_pos(0), small_neg(0), tied_pos(0), tied_neg(0)
282  {}
283
284}}} // of namespace statistics, yat, and theplu
Note: See TracBrowser for help on using the repository browser.