source: trunk/yat/statistics/ROC.h @ 2732

Last change on this file since 2732 was 2732, checked in by Peter, 11 years ago

fix comment typo

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 10.9 KB
Line 
1#ifndef _theplu_yat_statistics_roc_
2#define _theplu_yat_statistics_roc_
3
4// $Id: ROC.h 2732 2012-05-08 02:36:38Z peter $
5
6/*
7  Copyright (C) 2004 Peter Johansson
8  Copyright (C) 2005, 2006, 2007, 2008 Jari Häkkinen, Peter Johansson
9  Copyright (C) 2011, 2012 Peter Johansson
10
11  This file is part of the yat library, http://dev.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 3 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with yat. If not, see <http://www.gnu.org/licenses/>.
25*/
26
27#include "Averager.h"
28#include "yat/utility/stl_utility.h"
29#include "yat/utility/yat_assert.h"
30
31#include <gsl/gsl_randist.h>
32
33#include <map>
34#include <utility>
35
36namespace theplu {
37namespace yat {
38namespace statistics {
39
40  ///
41  /// @brief Receiver Operating Characteristic.
42  ///
43  /// As the area under an ROC curve is equivalent to the Mann-Whitney U
44  /// statistic, this class can be used to perform a Mann-Whitney
45  /// U-test (aka Wilcoxon).
46  ///
47  /// \see AUC
48  ///
49  class ROC
50  {
51
52  public:
53    ///
54    /// @brief Default constructor
55    ///
56    ROC(void);
57
58    /**
59       \brief Add a data value.
60
61       \param value data value
62       \param target \c true if value belongs to class positive
63
64       \param weight indicating how important the data point is. A
65       zero weight implies the data point is ignored. A negative
66       weight should be understood as removing a data point and thus
67       typically only makes sense if there is a previously added data
68       point with same \a value and \a target.
69
70    */
71    void add(double value, bool target, double weight=1.0);
72
73    /**
74       \brief Area Under Curve, AUC
75
76       \see AUC for how the area is calculated
77
78       @return Area under curve.
79    */
80    double area(void);
81
82    /**
83       \brief threshold for p_value calculation
84
85       Function can used to change the minimum_size.
86
87       \return reference to threshold minimum size
88     */
89    unsigned int& minimum_size(void);
90
91    /**
92       \brief threshold for p_value calculation
93
94       Threshold deciding whether p-value is computed using exact
95       method or a Gaussian approximation. If both number of positive
96       samples, n_pos(void), and number of negative samples,
97       n_neg(void), are smaller than minimum_size the exact method is
98       used.
99
100       \see p_value
101
102       \return const reference to minimum_size
103    */
104    const unsigned int& minimum_size(void) const;
105
106    ///
107    /// \brief number of samples
108    ///
109    /// @return sum of weights
110    ///
111    double n(void) const;
112
113    ///
114    /// \brief number of negative samples
115    ///
116    /// @return sum of weights with negative target
117    ///
118    double n_neg(void) const;
119
120    ///
121    /// \brief number of positive samples
122    ///
123    /// @return sum of weights with positive target
124    ///
125    double n_pos(void) const;
126
127    /**
128       \brief One-sided P-value
129
130       Calculates the one-sided p-value, i.e., probability to get this
131       area (or greater) given that there is no difference
132       between the two classes.
133
134       \b Exact \b method: In the exact method the function goes
135       through all permutations and counts what fraction for which the
136       area is greater (or equal) than area in original
137       permutation. In case all non-zero weights are not equal,
138       iterating through all permutations is not sufficient so
139       algorithm goes through all combinations instead which quickly
140       becomes a large number (N!).
141
142       \b Large-sample \b Approximation: When many data points are
143       available, see minimum_size(), a Gaussian approximation is used
144       and the p-value is calculated as
145       \f[
146       P = \frac{1}{\sqrt{2\pi}} \int_{-\infty}^z
147       \exp{\left(-\frac{t^2}{2}\right)} dt
148       \f]
149
150       where
151
152       \f[
153       z = \frac{\textrm{area} - 0.5 - 0.5/(n^+ \cdot n^-)}{s}
154       \f]
155
156       and
157
158       \f[
159       s^2 = \frac{n+1+\sum \left(n_x \cdot (n_x^2-1)\right)}
160       {12\cdot n^+\cdot n^-}
161       \f]
162
163       where sum runs over different data values (of ties) and \f$ n_x
164       \f$ is number data points with that value. The sum is a
165       correction term for ties and is zero if there are no ties.
166
167       The number of samples in a group, \f$ n^+ \f$, is calculated as
168       \f$ n = (\sum w)^2 / \sum w^2 \f$
169
170       \return \f$ P(a \ge \textrm{area}) \f$
171     */
172    double p_value_one_sided(void) const;
173
174    /**
175       \brief Two-sided p-value.
176
177       Calculates the probability to get an area, \c a, equal or more
178       extreme than \c area
179       \f[
180       P(a \ge \textrm{max}(\textrm{area},1-\textrm{area})) +
181       P(a \le \textrm{min}(\textrm{area}, 1-\textrm{area})) \f]
182
183       If there are no ties, distribution of \a a is symmetric, so if
184       area is greater than 0.5, this boils down to \f$ P = 2*P(a \ge
185       \textrm{area}) = 2*P_\textrm{one-sided}\f$.
186
187       \return two-sided p-value
188
189       \see p_value_one_sided
190    */
191    double p_value(void) const;
192
193    /**
194       \brief remove a data value
195
196       A data point with identical \a value, \a target, and \a weight
197       must have beed added prior calling this function; else an
198       exception is thrown.
199
200       \since New in yat 0.9
201     */
202    void remove(double value, bool target, double weight=1.0);
203
204    /**
205       @brief Set everything to zero
206    */
207    void reset(void);
208
209  private:
210    typedef std::multimap<double, std::pair<bool, double> > Map;
211
212    // struct used in count functions
213    struct Weights
214    {
215      Weights(void);
216      double small_pos;
217      double small_neg;
218      double tied_pos;
219      double tied_neg;
220    };
221
222    /// Implemented as in MatLab 13.1
223    double get_p_approx(double) const;
224
225    /**
226       return false if all non-zero weights are equal
227     */
228    bool is_weighted(void) const;
229
230    /**
231       return (sum x)^2 / sum x^2
232     */
233    size_t nof_points(const Averager& a) const;
234
235    /*
236      Calculate probability to get an area equal (smaller) than \a
237      area given the distribution of weights and ties in multimap_
238     */
239    double p_left_weighted(double area) const;
240
241    /*
242      Calculate probability to get an area equal (greater) than \a
243      area given the distribution of weights and ties in multimap_
244     */
245    double p_right_weighted(double area) const;
246
247    /*
248      Count number of combinations (of N!) that gives weight sum equal
249      or larger than \a threshold.
250
251      Range [first, last) is used to check for ties. If, e.g., *first
252      and *(first+1) are equal implies that the two largest values are
253      equal.
254     */
255    template <typename Iterator>
256    double count(Iterator first, Iterator last, double threshold) const;
257
258    /*
259      Loop over all elements in \a weights and call count(7)
260     */
261    template <typename Iterator>
262    double count(Map& weights, Iterator iter, Iterator last,
263                 double threshold, double sum, const Weights& weight) const;
264
265    /*
266      Count number of combinations in which sum>=threshold given
267      classes and weights in \a weight. Range [iter, last) is used to
268      handle ties.
269     */
270    template <typename Iterator>
271    double count(Map& weights, Iterator iter, Iterator last,
272                 double threshold, double sum, Weights weight,
273                 const std::pair<bool, double>& entry) const;
274
275    /*
276      Calculates probability to get \a block number of pairs correctly
277      sorted when having \a pos positive samples and \a neg negative
278      samples given the distribution of ties as in [first, last).
279     */
280    template<typename ForwardIterator>
281    double p_exact_with_ties(ForwardIterator first, ForwardIterator last,
282                             double block, unsigned int pos,
283                             unsigned int neg) const;
284
285    /**
286       \return P(auc >= area)
287     */
288    double p_exact_right(double area) const;
289
290    /**
291       \return P(auc <= area)
292     */
293    double p_exact_left(double area) const;
294
295    bool use_exact_method(void) const;
296
297    double area_;
298    bool has_ties_;
299    unsigned int minimum_size_;
300    Averager neg_weights_;
301    Averager pos_weights_;
302    Map multimap_;
303  };
304
305  template<typename ForwardIterator>
306  double
307  ROC::p_exact_with_ties(ForwardIterator begin, ForwardIterator end,
308                         double block, unsigned int pos,unsigned int neg) const
309  {
310    if (block <= 0)
311      return 1.0;
312    if (block > pos*neg)
313      return 0.0;
314
315    ForwardIterator iter(begin);
316    unsigned int n=0;
317    while (iter!=end && iter->first == begin->first) {
318      ++iter;
319      ++n;
320    }
321    double result = 0;
322    /*
323      pos1  neg1  |  n
324      pos2  neg2  |
325      ----  ----   ----
326      pos   neg
327     */
328
329    // ensure pos1 and neg2 are non-negative
330    unsigned int pos1 = n - std::min(n, neg);
331    // ensure pos2 and neg1 are non-negative
332    unsigned int max = std::min(n, pos);
333    YAT_ASSERT(pos1<=max);
334    for ( ; pos1<=max; ++pos1) {
335      unsigned int neg1 = n-pos1;
336      YAT_ASSERT(neg1<=n);
337      unsigned int pos2 = pos-pos1;
338      YAT_ASSERT(pos2<=pos);
339      unsigned int neg2 = neg-neg1;
340      YAT_ASSERT(neg2<=neg);
341      result += gsl_ran_hypergeometric_pdf(pos1, static_cast<unsigned int>(pos),
342                                           static_cast<unsigned int>(neg), n)
343        * p_exact_with_ties(iter, end,
344                            block - pos2*neg1 - 0.5*pos1*neg1,
345                            pos2, neg2);
346    }
347    return result;
348  }
349
350
351  template <typename Iterator>
352  double ROC::count(Iterator first, Iterator last, double threshold) const
353  {
354    Map map(multimap_);
355    ROC::Weights w;
356    w.small_pos = pos_weights_.sum_x();
357    w.small_neg = neg_weights_.sum_x();
358    return count(map, first, last, threshold*w.small_pos*w.small_neg, 0, w);
359  }
360
361
362
363  template <typename Iterator>
364  double ROC::count(Map& weights, Iterator iter, Iterator last,
365                    double threshold, double sum, const Weights& w) const
366  {
367    double result = 0.0;
368    // loop over all elements
369    for (Map::iterator i=weights.begin(); i!=weights.end(); ++i) {
370      Map::value_type save = *i;
371      Map::iterator hint = i;
372      ++hint;
373      weights.erase(i);
374      result += count(weights, iter, last, threshold, sum, w, save.second);
375      i = weights.insert(hint, save);
376    }
377    YAT_ASSERT(weights.size());
378    return result/weights.size();
379  }
380
381  template <typename Iterator>
382  double ROC::count(Map& weights, Iterator iter, Iterator last,
383                    double threshold, double sum, Weights w,
384                    const std::pair<bool, double>& entry) const
385  {
386    double tiny = 10e-10;
387
388    Iterator next(iter);
389    ++next;
390
391    // update weights
392    if (entry.first) {
393      w.tied_pos += entry.second;
394      w.small_pos -= entry.second;
395    }
396    else {
397      w.tied_neg += entry.second;
398      w.small_neg -= entry.second;
399    }
400
401    // last entry in equal range
402    if (next==last || *next!=*iter) {
403      sum += 0.5*w.tied_pos*w.tied_neg + w.tied_pos * w.small_neg;
404      w.tied_pos=0;
405      w.tied_neg=0;
406    }
407
408    // max sum happens if all pos values belong to current equal range
409    // and none of the remaining neg values
410    double max_sum = sum + 0.5*(w.tied_pos+w.small_pos)*w.tied_neg +
411      (w.tied_pos+w.small_pos)*w.small_neg;
412
413    if (max_sum<threshold-tiny)
414      return 0.0;
415    if (sum + 0.5*w.tied_pos*(w.tied_neg+w.small_neg) >= threshold-tiny)
416      return 1.0;
417
418    if (next!=last)
419      return count(weights, next, last, threshold, sum, w);
420    return 0.0;
421  }
422
423}}} // of namespace statistics, yat, and theplu
424#endif
Note: See TracBrowser for help on using the repository browser.