Changeset 2718
- Timestamp:
- Apr 12, 2012, 2:55:53 AM (11 years ago)
- Location:
- trunk
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/test/roc.cc
r2712 r2718 53 53 void test_p_exact_with_ties(test::Suite& suite); 54 54 void test_p_with_weights(test::Suite& suite); 55 void test_p_with_weights_and_ties(test::Suite& suite); 55 56 void test_ties(test::Suite& suite); 56 57 … … 246 247 test_p_exact_weighted(suite); 247 248 test_p_approx_weighted(suite); 249 test_p_with_weights_and_ties(suite); 250 } 251 252 253 void test_p_with_weights_and_ties(test::Suite& suite) 254 { 255 suite.out() << "test p with weights and ties\n"; 256 statistics::ROC roc; 257 roc.add(10, true, 1.0); 258 roc.add(10, false, 3.0); 259 roc.add(20, true, 2.0); 260 roc.add(30, true, 1.0); 261 if (!suite.equal(roc.area(), 0.875)) { 262 suite.add(false); 263 suite.out() << "roc area: " << roc.area() << "\n"; 264 } 265 double p = roc.p_value_one_sided(); 266 if (!suite.equal(p, 8.0/24.0)) { 267 suite.add(false); 268 suite.out() << "p_value_one_sided() failed\n"; 269 } 270 p = roc.p_value(); 271 if (!suite.equal(p, (8.0+6.0)/24.0)) { 272 suite.add(false); 273 suite.out() << "p_value() failed\n"; 274 } 248 275 } 249 276 … … 306 333 unsigned long perm = 0; 307 334 unsigned long k = 0; 335 unsigned long k2 = 0; 308 336 while (true) { 309 337 ++perm; … … 313 341 if (roc2.area() >= roc.area()) 314 342 ++k; 343 if (roc2.area() <= 1-roc.area()+1e-10) 344 ++k2; 315 345 316 346 if (!next_permutation(w.begin(), w.end())) 317 347 break; 318 348 } 319 if (!suite.xadd(suite.equal(roc.p_value_one_sided(), 320 static_cast<double>(k)/perm))) { 349 double p_value = roc.p_value_one_sided(); 350 roc.p_value_one_sided(); 351 if (!suite.add(suite.equal(p_value, static_cast<double>(k)/perm))) { 321 352 suite.out() << "area: " << roc.area() << "\n" 322 353 << perm << " permutations of which\n" 323 354 << k << " with larger (or equal) area " 324 355 << "corresponding to P=" << static_cast<double>(k)/perm << "\n" 325 << "p_value_one_sided() returned: " << roc.p_value_one_sided() 356 << "p_value_one_sided() returned: " << p_value 326 357 << "\n"; 327 358 } 328 359 p_value = 
roc.p_value(); 360 if (!suite.add(suite.equal(p_value, static_cast<double>(k+k2)/perm))) { 361 suite.out() << "area: " << roc.area() << "\n" 362 << perm << " permutations of which\n" 363 << k << " with larger (or equal) area and\n" 364 << k2 << " with smaller (or equal) area\n" 365 << "corresponding to P=" 366 << static_cast<double>(k+k2)/perm << "\n" 367 << "p_value() returned: " << p_value 368 << "\n"; 369 } 329 370 } 330 371 -
trunk/yat/statistics/ROC.cc
r2710 r2718 110 110 111 111 112 bool ROC::is_weighted(void) const 113 { 114 return pos_weights_.variance() || neg_weights_.variance() 115 || pos_weights_.mean() != neg_weights_.mean(); 116 } 117 112 118 unsigned int& ROC::minimum_size(void) 113 119 { … … 146 152 147 153 148 double ROC::p_exact(double area) const 149 { 154 double ROC::p_exact_left(double area) const 155 { 156 if (is_weighted()) 157 return p_left_weighted(area); 158 return p_exact_with_ties(multimap_.rbegin(), multimap_.rend(), 159 (1-area)*pos_weights_.n()*neg_weights_.n(), 160 pos_weights_.n(), neg_weights_.n()); 161 } 162 163 164 double ROC::p_exact_right(double area) const 165 { 166 if (is_weighted()) 167 return p_right_weighted(area); 150 168 return p_exact_with_ties(multimap_.begin(), multimap_.end(), 151 169 area*pos_weights_.n()*neg_weights_.n(), 152 170 pos_weights_.n(), neg_weights_.n()); 171 } 172 173 174 double ROC::p_left_weighted(double area) const 175 { 176 return count(utility::pair_first_iterator(multimap_.begin()), 177 utility::pair_first_iterator(multimap_.end()), 1-area); 178 } 179 180 181 double ROC::p_right_weighted(double area) const 182 { 183 return count(utility::pair_first_iterator(multimap_.rbegin()), 184 utility::pair_first_iterator(multimap_.rend()), area); 153 185 } 154 186 … … 166 198 double p = 0; 167 199 double abs_area = std::max(area, 1-area); 168 p = p_exact (abs_area); 200 p = p_exact_right(abs_area); 169 201 if (has_ties_) { 170 p += p_exact_with_ties(multimap_.rbegin(), multimap_.rend(), 171 abs_area*pos_weights_.n()*neg_weights_.n(), 172 pos_weights_.n(), neg_weights_.n()); 202 p += p_exact_left(1.0 - abs_area); 173 203 } 174 204 else … … 191 221 return std::numeric_limits<double>::quiet_NaN(); 192 222 if (use_exact_method()) 193 return p_exact (area); 223 return p_exact_right(area); 194 224 return get_p_approx(area); 195 225 } … … 211 241 } 212 242 243 244 ROC::Weights::Weights(void) 245 : small_pos(0), small_neg(0), tied_pos(0), tied_neg(0) 246 {} 247 213 248 
}}} // of namespace statistics, yat, and theplu -
trunk/yat/statistics/ROC.h
r2710 r2718 7 7 Copyright (C) 2004 Peter Johansson 8 8 Copyright (C) 2005, 2006, 2007, 2008 Jari Häkkinen, Peter Johansson 9 Copyright (C) 2011 Peter Johansson 9 Copyright (C) 2011, 2012 Peter Johansson 10 10 11 11 This file is part of the yat library, http://dev.thep.lu.se/yat … … 26 26 27 27 #include "Averager.h" 28 #include "yat/utility/stl_utility.h" 28 29 #include "yat/utility/yat_assert.h" 29 30 … … 133 134 \b Exact \b method: In the exact method the function goes 134 135 through all permutations and counts what fraction for which the 135 area is greater (or equal) than area in original permutation. 136 area is greater (or equal) than area in original 137 permutation. In case all non-zero weights are not equal, 138 iterating through all permutations is not sufficient so 139 algorithm goes through all combinations instead which quickly 140 becomes a large number (N!). 136 141 137 142 \b Large-sample \b Approximation: When many data points are … … 194 199 typedef std::multimap<double, std::pair<bool, double> > Map; 195 200 201 // struct used in count functions 202 struct Weights 203 { 204 Weights(void); 205 double small_pos; 206 double small_neg; 207 double tied_pos; 208 double tied_neg; 209 }; 210 196 211 /// Implemented as in MatLab 13.1 197 212 double get_p_approx(double) const; 198 213 199 214 /** 215 return false if all non-zero weights are equal 216 */ 217 bool is_weighted(void) const; 218 219 /** 200 220 return (sum x)^2 / sum x^2 201 221 */ … … 203 223 204 224 /* 225 Calculate probability to get an area equal (smaller) than \a 226 area given the distribution of weights and ties in multimap_ 227 */ 228 double p_left_weighted(double area) const; 229 230 /* 231 Calculate probability to get an area equal (greater) than \a 232 area given the distribution of weights and ties in multimap_ 233 */ 234 double p_right_weighted(double area) const; 235 236 /* 237 Count number of combinations (of N!) that gives weight sum equal 238 or larger than \a threshold. 
239 240 Range [first, last) is used to check for ties. If, e.g., *first 241 and *(first+1) are equal implies that the two largest values are 242 equal. 243 */ 244 template <typename Iterator> 245 double count(Iterator first, Iterator last, double threshold) const; 246 247 /* 248 Loop over all elements in \a weights and call count(7) 249 */ 250 template <typename Iterator> 251 double count(Map& weights, Iterator iter, Iterator last, 252 double threshold, double sum, const Weights& weight) const; 253 254 /* 255 Count number of combinations in which sum>=threshold given 256 classes and weights in \a weight. Range [iter, last) is used to 257 handle ties. 258 */ 259 template <typename Iterator> 260 double count(Map& weights, Iterator iter, Iterator last, 261 double threshold, double sum, Weights weight, 262 const std::pair<bool, double>& entry) const; 263 264 /* 265 Calculates probability to get \a block number of pairs correctly 266 sorted when having \a pos positive samples and \a neg negative 267 samples given the distribution of ties as in [first, last). 205 268 */ 206 269 template<typename ForwardIterator> … … 210 273 211 274 /** 212 \return probability to get auc >= \a area. 
If area<0.5 213 probability to auc <= area is returned 214 215 \note assumes all non-zero weights are equal (typically unity 216 but not necessarily 217 */ 218 double p_exact(double area) const; 275 \return P(auc >= area) 276 */ 277 double p_exact_right(double area) const; 278 279 /** 280 \return P(auc <= area) 281 */ 282 double p_exact_left(double area) const; 219 283 220 284 bool use_exact_method(void) const; … … 273 337 } 274 338 339 340 template <typename Iterator> 341 double ROC::count(Iterator first, Iterator last, double threshold) const 342 { 343 Map map(multimap_); 344 ROC::Weights w; 345 w.small_pos = pos_weights_.sum_x(); 346 w.small_neg = neg_weights_.sum_x(); 347 return count(map, first, last, threshold*w.small_pos*w.small_neg, 0, w); 348 } 349 350 351 352 template <typename Iterator> 353 double ROC::count(Map& weights, Iterator iter, Iterator last, 354 double threshold, double sum, const Weights& w) const 355 { 356 double result = 0.0; 357 // loop over all elements 358 for (Map::iterator i=weights.begin(); i!=weights.end(); ++i) { 359 Map::value_type save = *i; 360 Map::iterator hint = i; 361 ++hint; 362 weights.erase(i); 363 result += count(weights, iter, last, threshold, sum, w, save.second); 364 i = weights.insert(hint, save); 365 } 366 YAT_ASSERT(weights.size()); 367 return result/weights.size(); 368 } 369 370 template <typename Iterator> 371 double ROC::count(Map& weights, Iterator iter, Iterator last, 372 double threshold, double sum, Weights w, 373 const std::pair<bool, double>& entry) const 374 { 375 double tiny = 10e-10; 376 377 Iterator next(iter); 378 ++next; 379 380 // update weights 381 if (entry.first) { 382 w.tied_pos += entry.second; 383 w.small_pos -= entry.second; 384 } 385 else { 386 w.tied_neg += entry.second; 387 w.small_neg -= entry.second; 388 } 389 390 // last entry in equal range 391 if (next==last || *next!=*iter) { 392 sum += 0.5*w.tied_pos*w.tied_neg + w.tied_pos * w.small_neg; 393 w.tied_pos=0; 394 w.tied_neg=0; 395 } 396 
397 // max sum happens if all pos values belong to current equal range 398 // and none of the remaining neg values 399 double max_sum = sum + 0.5*(w.tied_pos+w.small_pos)*w.tied_neg + 400 (w.tied_pos+w.small_pos)*w.small_neg; 401 402 if (max_sum<threshold-tiny) 403 return 0.0; 404 if (sum >= threshold-tiny) 405 return 1.0; 406 407 if (next!=last) 408 return count(weights, next, last, threshold, sum, w); 409 return 0.0; 410 } 411 275 412 }}} // of namespace statistics, yat, and theplu 276 413 #endif
Note: See TracChangeset
for help on using the changeset viewer.