Ignore:
Timestamp:
Nov 28, 2013, 5:11:23 AM (8 years ago)
Author:
Peter
Message:

closes #772. new function that calculates mutual information

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/statistics/utility.h

    r3136 r3137  
    3636#include "yat/utility/deprecate.h"
    3737#include "yat/utility/iterator_traits.h"
     38#include "yat/utility/Vector.h"
    3839#include "yat/utility/VectorBase.h"
    3940#include "yat/utility/yat_assert.h"
    4041
    4142#include <boost/concept_check.hpp>
     43#include <gsl/gsl_math.h>
     44#include <gsl/gsl_statistics_double.h>
    4245
    4346#include <algorithm>
     
    4649#include <stdexcept>
    4750#include <vector>
    48 
    49 #include <gsl/gsl_statistics_double.h>
    5051
    5152namespace theplu {
     
    173174                bool sorted=false);
    174175
     176
     177  /**
     178     \brief Calculates the mutual information of \a A.
     179
     180     The elements in A are unnormalized probabilities of the joint
     181     distribution.
     182
     183     The mutual information is calculated as \f$ \sum \sum p(x,y) \log_2
     184     \frac {p(x,y)} {p(x)p(y)} \f$ where
     185     \f$ p(x,y) = \frac {A_{xy}}{\sum_{x,y} A_{xy}} \f$;
     186     \f$ p(x) = \sum_y A_{xy} / \sum_{x,y} A_{xy} \f$;
     187     \f$ p(y) = \sum_x A_{xy} / \sum_{x,y} A_{xy} \f$
     188
     189     Requirements:
     190     - \c T must be a model of \ref concept_container_2d
     191     - \c T::value_type must be convertible to \c double
     192
     193     \return mutual information in bits; if you want in natural base
     194     multiply with \c M_LN2 (defined in \c gsl/gsl_math.h )
     195
     196     \since New in yat 0.12
     197   */
     198  template<class T>
     199  double mutual_information(const T& A);
    175200
    176201  /**
     
    305330    return median(ad.begin(), ad.end(), true);
    306331  }
    307  
    308 
    309   template <class RandomAccessIterator>
    310   double median(RandomAccessIterator first, RandomAccessIterator last,
    311                 bool sorted)
    312   {
    313     return percentile2(first, last, 50.0, sorted);
    314   }
    315 
    316 
    317   template <class RandomAccessIterator>
    318   double percentile(RandomAccessIterator first, RandomAccessIterator last,
     332
     333
     334  template <class RandomAccessIterator>
     335  double median(RandomAccessIterator first, RandomAccessIterator last,
     336                bool sorted)
     337  {
     338    return percentile2(first, last, 50.0, sorted);
     339  }
     340
     341
     342  template<class T>
     343  double mutual_information(const T& n)
     344  {
     345    BOOST_CONCEPT_ASSERT((utility::Container2D<T>));
     346    using boost::Convertible;
     347    BOOST_CONCEPT_ASSERT((Convertible<typename T::value_type,double>));
     348
     349    // p_x = \sum_y p_xy
     350
     351    // Mutual Information is defined as
     352    // \sum_xy p_xy * log (p_xy / (p_x p_y)) =
     353    // \sum_xy p_xy * [log p_xy - log p_x - log p_y]
     354    // \sum_xy p_xy log p_xy - p_xy log p_x - p_xy log p_y
     355    // \sum_xy p_xy log p_xy - \sum_x p_x log p_x - \sum_y p_y log p_y
     356    // - entropy_xy + entropy_x + entropy_y
     357
     358    utility::Vector rowsum(n.columns(), 0);
     359    for (size_t c = 0; c<n.columns(); ++c)
     360      rowsum(c) = std::accumulate(n.begin_column(c), n.end_column(c), 0);
     361
     362    utility::Vector colsum(n.rows(), 0);
     363    for (size_t r = 0; r<n.rows(); ++r)
     364      colsum(r) = std::accumulate(n.begin_row(r), n.end_row(r), 0);
     365
     366    double mi  = - entropy(n.begin(), n.end());
     367    mi += entropy(rowsum.begin(), rowsum.end());
     368    mi += entropy(colsum.begin(), colsum.end());
     369
     370    return mi/M_LN2;
     371  }
     372
     373
     374  template <class RandomAccessIterator>
     375  double percentile(RandomAccessIterator first, RandomAccessIterator last,
    319376                    double p, bool sorted)
    320377  {
Note: See TracChangeset for help on using the changeset viewer.