source: trunk/yat/statistics/AUC.cc @ 2551

Last change on this file since 2551 was 2551, checked in by Peter, 10 years ago

refs #144. Fix ROC::area for the tied case

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 3.4 KB
Line 
1// $Id: AUC.cc 2551 2011-08-12 02:14:01Z peter $
2
3/*
4  Copyright (C) 2004, 2005 Peter Johansson
5  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson
6
7  This file is part of the yat library, http://dev.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 3 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with yat. If not, see <http://www.gnu.org/licenses/>.
21*/
22
23#include <iostream>
24
25#include "AUC.h"
26#include "yat/classifier/DataLookupWeighted1D.h"
27#include "yat/classifier/Target.h"
28#include "yat/utility/stl_utility.h"
29#include "yat/utility/VectorBase.h"
30
31#include <cassert>
32#include <cmath>
33#include <utility>
34#include <map>
35
36namespace theplu {
37namespace yat {
38namespace statistics { 
39
40  AUC::AUC(bool absolute)
41    : Score(absolute)
42  {
43  }
44
45  double AUC::score(const classifier::Target& target, 
46                    const utility::VectorBase& value) const
47  {
48    assert(target.size()==value.size());
49    // key data, pair<target, weight>
50    std::multimap<double, std::pair<bool, double> > m;
51    for (unsigned int i=0; i<target.size(); i++)
52      m.insert(std::make_pair(value(i), 
53                              std::make_pair(target.binary(i),1.0)));
54       
55    return score(m);
56  }
57
58
59
60  double AUC::score(const classifier::Target& target, 
61                    const classifier::DataLookupWeighted1D& value) const
62  {
63    assert(target.size()==value.size());
64    // key data, pair<target, weight>
65    std::multimap<double, std::pair<bool, double> > m;
66    for (unsigned int i=0; i<target.size(); i++)
67      if (value.weight(i))
68        m.insert(std::make_pair(value.data(i), 
69                                std::make_pair(target.binary(i), 
70                                               value.weight(i))));
71       
72    return score(m);
73  }
74
75
76  double AUC::score(const classifier::Target& target, 
77                    const utility::VectorBase& value, 
78                    const utility::VectorBase& weight) const
79  {
80    assert(target.size()==value.size());
81    assert(target.size()==weight.size());
82    // key data, pair<target, weight>
83    std::multimap<double, std::pair<bool, double> > m;
84    for (unsigned int i=0; i<target.size(); i++)
85      if (weight(i))
86        m.insert(std::make_pair(value(i), 
87                                std::make_pair(target.binary(i), weight(i))));
88       
89    return score(m);
90  }
91
92
93  double AUC::score(const MultiMap& m) const
94  {
95    double area=0;
96    double cumsum_pos_w=0;
97    double cumsum_neg_w=0;
98    typedef MultiMap::const_iterator iter;
99
100    iter first = m.begin();
101    while (first!=m.end()) {
102      double local_cumsum_pos_w=0;
103      double local_cumsum_neg_w=0;
104      iter last = first;
105      while (last!=m.end() && first->first==last->first) {
106        if (last->second.first)
107          local_cumsum_pos_w += last->second.second;
108        else
109          local_cumsum_neg_w += last->second.second;
110        ++last;
111      } 
112      area += local_cumsum_pos_w * ( cumsum_neg_w + 0.5*local_cumsum_neg_w );
113      cumsum_pos_w += local_cumsum_pos_w;
114      cumsum_neg_w += local_cumsum_neg_w;
115      first = last;
116    }
117    // max area is cumsum_neg_w * cumsum_pos_w
118    area/=(cumsum_neg_w*cumsum_pos_w);
119   
120    if (area<0.5 && absolute_)
121      return 1.0-area;
122    return area;
123}
124
125}}} // of namespace statistics, yat, and theplu
Note: See TracBrowser for help on using the repository browser.