Changeset 3137


Ignore:
Timestamp:
Nov 28, 2013, 5:11:23 AM (7 years ago)
Author:
Peter
Message:

closes #772. new function that calculates mutual information

Location:
trunk
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/statistics.cc

    r3135 r3137  
    3333#include "yat/statistics/tTest.h"
    3434#include "yat/utility/DataWeight.h"
     35#include "yat/utility/Matrix.h"
    3536#include "yat/utility/MatrixWeighted.h"
    3637#include "yat/utility/Vector.h"
     
    4950void test_entropy(test::Suite&);
    5051void test_mad(test::Suite&);
     52void test_mutual_information(test::Suite&);
    5153
    5254void test_median_empty(test::Suite&);
     
    5557
    5658template<typename RandomAccessIterator>
    57 void test_percentiler(test::Suite&, RandomAccessIterator, 
     59void test_percentiler(test::Suite&, RandomAccessIterator,
    5860                      RandomAccessIterator,
    5961                      double p, double correct);
    6062
    6163template<typename RandomAccessIterator1, typename RandomAccessIterator2>
    62 void cmp_percentiler(test::Suite&, 
    63                      RandomAccessIterator1, 
     64void cmp_percentiler(test::Suite&,
     65                     RandomAccessIterator1,
    6466                     RandomAccessIterator1,
    6567                     RandomAccessIterator2,
     
    6769
    6870int main(int argc, char* argv[])
    69 { 
     71{
    7072  test::Suite suite(argc, argv);
    7173
     
    135137  test_entropy(suite);
    136138  test_median_empty(suite);
     139  test_mutual_information(suite);
    137140  return suite.return_value();
    138141}
     
    218221
    219222
     223void test_mutual_information(test::Suite& suite)
     224{
     225  suite.out() << "testing mutual_information\n";
     226  using statistics::mutual_information;
     227  utility::Matrix x(2,2);
     228  x(0,0) = 100;
     229  x(1,1) = 100;
     230  double mi = mutual_information(x);
     231  if (!suite.add(suite.equal(mi,1.0,100))) {
     232    suite.err() << "error: mutual information: " << mi << "\n";
     233  }
     234
     235  // testing a non-square Matrix
     236  x.resize(3,4,0);
     237  x(0,0) = 1;
     238  x(1,1) = 1;
     239  x(2,2) = 1;
     240  x(2,3) = 1;
     241  mi = mutual_information(x);
     242  suite.out() << "mi: " << mi << "\n";
     243}
     244
     245
    220246// test for ticket #660
    221247void test_median_empty(test::Suite& suite)
  • trunk/yat/statistics/utility.h

    r3136 r3137  
    3636#include "yat/utility/deprecate.h"
    3737#include "yat/utility/iterator_traits.h"
     38#include "yat/utility/Vector.h"
    3839#include "yat/utility/VectorBase.h"
    3940#include "yat/utility/yat_assert.h"
    4041
    4142#include <boost/concept_check.hpp>
     43#include <gsl/gsl_math.h>
     44#include <gsl/gsl_statistics_double.h>
    4245
    4346#include <algorithm>
     
    4649#include <stdexcept>
    4750#include <vector>
    48 
    49 #include <gsl/gsl_statistics_double.h>
    5051
    5152namespace theplu {
     
    173174                bool sorted=false);
    174175
     176
     177  /**
     178     \brief Calculates the mutual information of \a A.
     179
     180     The elements in A are unnormalized probabilies of the joint
     181     distribution.
     182
     183     The mutual information is calculated as \f$ \sum \sum p(x,y) \log_2
     184     \frac {p(x,y)} {p(x)p(y)} \f$ where
     185     \f$ p(x,y) = \frac {A_{xy}}{\sum_{x,y} A_{xy}} \f$;
     186     \f$ p(x) = \sum_y A_{xy} / \sum_{x,y} A_{xy} \f$;
     187     \f$ p(y) = \sum_x A_{xy} / \sum_{x,y} A_{xy} \f$
     188
     189     Requirements:
     190     - \c T must be a model of \ref concept_container_2d
     191     - \c T::value_type must be convertible to \c double
     192
     193     \return mutual information in bits; if you want in natural base
     194     multiply with \c M_LN2 (defined in \c gsl/gsl_math.h )
     195
     196     \since New in yat 0.12
     197   */
     198  template<class T>
     199  double mutual_information(const T& A);
    175200
    176201  /**
     
    305330    return median(ad.begin(), ad.end(), true);
    306331  }
    307  
    308 
    309   template <class RandomAccessIterator>
    310   double median(RandomAccessIterator first, RandomAccessIterator last,
    311                 bool sorted)
    312   {
    313     return percentile2(first, last, 50.0, sorted);
    314   }
    315 
    316 
    317   template <class RandomAccessIterator>
    318   double percentile(RandomAccessIterator first, RandomAccessIterator last,
     332
     333
     334  template <class RandomAccessIterator>
     335  double median(RandomAccessIterator first, RandomAccessIterator last,
     336                bool sorted)
     337  {
     338    return percentile2(first, last, 50.0, sorted);
     339  }
     340
     341
     342  template<class T>
     343  double mutual_information(const T& n)
     344  {
     345    BOOST_CONCEPT_ASSERT((utility::Container2D<T>));
     346    using boost::Convertible;
     347    BOOST_CONCEPT_ASSERT((Convertible<typename T::value_type,double>));
     348
     349    // p_x = \sum_y p_xy
     350
     351    // Mutual Information is defined as
     352    // \sum_xy p_xy * log (p_xy / (p_x p_y)) =
     353    // \sum_xy p_xy * [log p_xy - log p_x - log p_y]
     354    // \sum_xy p_xy log p_xy - p_xy log p_x - p_xy log p_y
     355    // \sum_xy p_xy log p_xy - \sum_x p_x log p_x - \sum_y p_y log p_y
     356    // - entropy_xy + entropy_x + entropy_y
     357
     358    utility::Vector rowsum(n.columns(), 0);
     359    for (size_t c = 0; c<n.columns(); ++c)
     360      rowsum(c) = std::accumulate(n.begin_column(c), n.end_column(c), 0);
     361
     362    utility::Vector colsum(n.rows(), 0);
     363    for (size_t r = 0; r<n.rows(); ++r)
     364      colsum(r) = std::accumulate(n.begin_row(r), n.end_row(r), 0);
     365
     366    double mi  = - entropy(n.begin(), n.end());
     367    mi += entropy(rowsum.begin(), rowsum.end());
     368    mi += entropy(colsum.begin(), colsum.end());
     369
     370    return mi/M_LN2;
     371  }
     372
     373
     374  template <class RandomAccessIterator>
     375  double percentile(RandomAccessIterator first, RandomAccessIterator last,
    319376                    double p, bool sorted)
    320377  {
Note: See TracChangeset for help on using the changeset viewer.