Changeset 2718


Ignore:
Timestamp:
Apr 12, 2012, 2:55:53 AM (11 years ago)
Author:
Peter
Message:

implement ROC p-value handling weights and ties. closes #211

Location:
trunk
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/roc.cc

    r2712 r2718  
    5353void test_p_exact_with_ties(test::Suite& suite);
    5454void test_p_with_weights(test::Suite& suite);
     55void test_p_with_weights_and_ties(test::Suite& suite);
    5556void test_ties(test::Suite& suite);
    5657
     
    246247  test_p_exact_weighted(suite);
    247248  test_p_approx_weighted(suite);
     249  test_p_with_weights_and_ties(suite);
     250}
     251
     252
     253void test_p_with_weights_and_ties(test::Suite& suite)
     254{
     255  suite.out() << "test p with weights and ties\n";
     256  statistics::ROC roc;
     257  roc.add(10, true, 1.0);
     258  roc.add(10, false, 3.0);
     259  roc.add(20, true, 2.0);
     260  roc.add(30, true, 1.0);
     261  if (!suite.equal(roc.area(), 0.875)) {
     262    suite.add(false);
     263    suite.out() << "roc area: " << roc.area() << "\n";
     264  }
     265  double p = roc.p_value_one_sided();
     266  if (!suite.equal(p, 8.0/24.0)) {
     267    suite.add(false);
     268    suite.out() << "p_value_one_sided() failed\n";
     269  }
     270  p = roc.p_value();
     271  if (!suite.equal(p, (8.0+6.0)/24.0)) {
     272    suite.add(false);
     273    suite.out() << "p_value() failed\n";
     274  }
    248275}
    249276
     
    306333  unsigned long perm = 0;
    307334  unsigned long k = 0;
     335  unsigned long k2 = 0;
    308336  while (true) {
    309337    ++perm;
     
    313341    if (roc2.area() >= roc.area())
    314342      ++k;
     343    if (roc2.area() <= 1-roc.area()+1e-10)
     344      ++k2;
    315345
    316346    if (!next_permutation(w.begin(), w.end()))
    317347      break;
    318348  }
    319   if (!suite.xadd(suite.equal(roc.p_value_one_sided(),
    320                               static_cast<double>(k)/perm))) {
     349  double p_value = roc.p_value_one_sided();
     350  roc.p_value_one_sided();
     351  if (!suite.add(suite.equal(p_value, static_cast<double>(k)/perm))) {
    321352    suite.out() << "area: " << roc.area() << "\n"
    322353                << perm << " permutations of which\n"
    323354                << k << " with larger (or equal) area "
    324355                << "corresponding to P=" << static_cast<double>(k)/perm << "\n"
    325                 << "p_value_one_sided() returned: " << roc.p_value_one_sided()
     356                << "p_value_one_sided() returned: " << p_value
    326357                << "\n";
    327358  }
    328 
     359  p_value = roc.p_value();
     360  if (!suite.add(suite.equal(p_value, static_cast<double>(k+k2)/perm))) {
     361    suite.out() << "area: " << roc.area() << "\n"
     362                << perm << " permutations of which\n"
     363                << k << " with larger (or equal) area and\n"
     364                << k2 << " with smaller (or equal) area\n"
     365                << "corresponding to P="
     366                << static_cast<double>(k+k2)/perm << "\n"
     367                << "p_value() returned: " << p_value
     368                << "\n";
     369  }
    329370}
    330371
  • trunk/yat/statistics/ROC.cc

    r2710 r2718  
    110110
    111111
     112  bool ROC::is_weighted(void) const
     113  {
     114    return pos_weights_.variance() || neg_weights_.variance()
     115      || pos_weights_.mean() != neg_weights_.mean();
     116  }
     117
    112118  unsigned int& ROC::minimum_size(void)
    113119  {
     
    146152
    147153
    148   double ROC::p_exact(double area) const
    149   {
     154  double ROC::p_exact_left(double area) const
     155  {
     156    if (is_weighted())
     157      return p_left_weighted(area);
     158    return p_exact_with_ties(multimap_.rbegin(), multimap_.rend(),
     159                             (1-area)*pos_weights_.n()*neg_weights_.n(),
     160                             pos_weights_.n(), neg_weights_.n());
     161  }
     162
     163
     164  double ROC::p_exact_right(double area) const
     165  {
     166    if (is_weighted())
     167      return p_right_weighted(area);
    150168    return p_exact_with_ties(multimap_.begin(), multimap_.end(),
    151169                             area*pos_weights_.n()*neg_weights_.n(),
    152170                             pos_weights_.n(), neg_weights_.n());
     171  }
     172
     173
     174  double ROC::p_left_weighted(double area) const
     175  {
     176    return count(utility::pair_first_iterator(multimap_.begin()),
     177                 utility::pair_first_iterator(multimap_.end()), 1-area);
     178  }
     179
     180
     181  double ROC::p_right_weighted(double area) const
     182  {
     183    return count(utility::pair_first_iterator(multimap_.rbegin()),
     184                 utility::pair_first_iterator(multimap_.rend()), area);
    153185  }
    154186
     
    166198      double p = 0;
    167199      double abs_area = std::max(area, 1-area);
    168       p = p_exact(abs_area);
     200      p = p_exact_right(abs_area);
    169201      if (has_ties_) {
    170         p += p_exact_with_ties(multimap_.rbegin(), multimap_.rend(),
    171                                abs_area*pos_weights_.n()*neg_weights_.n(),
    172                                pos_weights_.n(), neg_weights_.n());
     202        p += p_exact_left(1.0 - abs_area);
    173203      }
    174204      else
     
    191221      return std::numeric_limits<double>::quiet_NaN();
    192222    if (use_exact_method())
    193       return p_exact(area);
     223      return p_exact_right(area);
    194224    return get_p_approx(area);
    195225  }
     
    211241  }
    212242
     243
     244  ROC::Weights::Weights(void)
     245    : small_pos(0), small_neg(0), tied_pos(0), tied_neg(0)
     246  {}
     247
    213248}}} // of namespace statistics, yat, and theplu
  • trunk/yat/statistics/ROC.h

    r2710 r2718  
    77  Copyright (C) 2004 Peter Johansson
    88  Copyright (C) 2005, 2006, 2007, 2008 Jari Häkkinen, Peter Johansson
    9   Copyright (C) 2011 Peter Johansson
     9  Copyright (C) 2011, 2012 Peter Johansson
    1010
    1111  This file is part of the yat library, http://dev.thep.lu.se/yat
     
    2626
    2727#include "Averager.h"
     28#include "yat/utility/stl_utility.h"
    2829#include "yat/utility/yat_assert.h"
    2930
     
    133134       \b Exact \b method: In the exact method the function goes
    134135       through all permutations and counts what fraction for which the
    135        area is greater (or equal) than area in original permutation.
     136       area is greater than (or equal to) the area in the original
     137       permutation. If the non-zero weights are not all equal,
     138       iterating through all permutations is not sufficient, so the
     139       algorithm goes through all combinations instead, whose number
     140       quickly becomes large (N!).
    136141
    137142       \b Large-sample \b Approximation: When many data points are
     
    194199    typedef std::multimap<double, std::pair<bool, double> > Map;
    195200
     201    // struct used in count functions
     202    struct Weights
     203    {
     204      Weights(void);
     205      double small_pos;
     206      double small_neg;
     207      double tied_pos;
     208      double tied_neg;
     209    };
     210
    196211    /// Implemented as in MatLab 13.1
    197212    double get_p_approx(double) const;
    198213
    199214    /**
     215       return false if all non-zero weights are equal
     216     */
     217    bool is_weighted(void) const;
     218
     219    /**
    200220       return (sum x)^2 / sum x^2
    201221     */
     
    203223
    204224    /*
     225      Calculate probability to get an area equal to (or smaller than)
     226      \a area given the distribution of weights and ties in multimap_
     227     */
     228    double p_left_weighted(double area) const;
     229
     230    /*
     231      Calculate probability to get an area equal to (or greater than)
     232      \a area given the distribution of weights and ties in multimap_
     233     */
     234    double p_right_weighted(double area) const;
     235
     236    /*
     237      Count number of combinations (of N!) that give a weight sum
     238      equal to or larger than \a threshold.
     239
     240      Range [first, last) is used to check for ties. If, e.g., *first
     241      and *(first+1) are equal, this implies that the two largest
     242      values are equal.
     243     */
     244    template <typename Iterator>
     245    double count(Iterator first, Iterator last, double threshold) const;
     246
     247    /*
     248      Loop over all elements in \a weights and call the seven-argument overload of count() for each
     249     */
     250    template <typename Iterator>
     251    double count(Map& weights, Iterator iter, Iterator last,
     252                 double threshold, double sum, const Weights& weight) const;
     253
     254    /*
     255      Count number of combinations in which sum>=threshold given
     256      classes and weights in \a weight. Range [iter, last) is used to
     257      handle ties.
     258     */
     259    template <typename Iterator>
     260    double count(Map& weights, Iterator iter, Iterator last,
     261                 double threshold, double sum, Weights weight,
     262                 const std::pair<bool, double>& entry) const;
     263
     264    /*
     265      Calculates probability to get \a block number of pairs correctly
     266      sorted when having \a pos positive samples and \a neg negative
     267      samples given the distribution of ties as in [first, last).
    205268     */
    206269    template<typename ForwardIterator>
     
    210273
    211274    /**
    212        \return probability to get auc >= \a area. If area<0.5
    213        probability to auc <= area is returned
    214 
    215        \note assumes all non-zero weights are equal (typically unity
    216        but not necessarily
    217     */
    218     double p_exact(double area) const;
     275       \return P(auc >= area)
     276     */
     277    double p_exact_right(double area) const;
     278
     279    /**
     280       \return P(auc <= area)
     281     */
     282    double p_exact_left(double area) const;
    219283
    220284    bool use_exact_method(void) const;
     
    273337  }
    274338
     339
     340  template <typename Iterator>
     341  double ROC::count(Iterator first, Iterator last, double threshold) const
     342  {
     343    Map map(multimap_);
     344    ROC::Weights w;
     345    w.small_pos = pos_weights_.sum_x();
     346    w.small_neg = neg_weights_.sum_x();
     347    return count(map, first, last, threshold*w.small_pos*w.small_neg, 0, w);
     348  }
     349
     350
     351
     352  template <typename Iterator>
     353  double ROC::count(Map& weights, Iterator iter, Iterator last,
     354                    double threshold, double sum, const Weights& w) const
     355  {
     356    double result = 0.0;
     357    // loop over all elements
     358    for (Map::iterator i=weights.begin(); i!=weights.end(); ++i) {
     359      Map::value_type save = *i;
     360      Map::iterator hint = i;
     361      ++hint;
     362      weights.erase(i);
     363      result += count(weights, iter, last, threshold, sum, w, save.second);
     364      i = weights.insert(hint, save);
     365    }
     366    YAT_ASSERT(weights.size());
     367    return result/weights.size();
     368  }
     369
     370  template <typename Iterator>
     371  double ROC::count(Map& weights, Iterator iter, Iterator last,
     372                    double threshold, double sum, Weights w,
     373                    const std::pair<bool, double>& entry) const
     374  {
     375    double tiny = 10e-10;
     376
     377    Iterator next(iter);
     378    ++next;
     379
     380    // update weights
     381    if (entry.first) {
     382      w.tied_pos += entry.second;
     383      w.small_pos -= entry.second;
     384    }
     385    else {
     386      w.tied_neg += entry.second;
     387      w.small_neg -= entry.second;
     388    }
     389
     390    // last entry in equal range
     391    if (next==last || *next!=*iter) {
     392      sum += 0.5*w.tied_pos*w.tied_neg + w.tied_pos * w.small_neg;
     393      w.tied_pos=0;
     394      w.tied_neg=0;
     395    }
     396
     397    // max sum happens if all pos values belong to current equal range
     398    // and none of the remaining neg values
     399    double max_sum = sum + 0.5*(w.tied_pos+w.small_pos)*w.tied_neg +
     400      (w.tied_pos+w.small_pos)*w.small_neg;
     401
     402    if (max_sum<threshold-tiny)
     403      return 0.0;
     404    if (sum >= threshold-tiny)
     405      return 1.0;
     406
     407    if (next!=last)
     408      return count(weights, next, last, threshold, sum, w);
     409    return 0.0;
     410  }
     411
    275412}}} // of namespace statistics, yat, and theplu
    276413#endif
Note: See TracChangeset for help on using the changeset viewer.