Ignore:
Timestamp:
Mar 26, 2021, 4:29:42 AM (8 months ago)
Author:
Peter
Message:

add some test for correlation(Matrix); speed up, particuarly when #rows is much greater than #columns. Avoid copying and modifying the whole matrix, instead calculate simple sums and squared sums using BLAS functionality as much as possible.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/statistics/utility.cc

    r3792 r4053  
    5050
    5151
    52   utility::Matrix correlation(utility::Matrix x)
     52  utility::Matrix correlation(const utility::Matrix& X)
    5353  {
    54     using utility::Matrix;
    55     size_t n = x.columns();
    56     // centralise
    57     normalizer::ColumnNormalizer<normalizer::Centralizer<> > normalizer;
    58     normalizer(x, x);
    59 
    60     Matrix cov(n, n);
    61     for (size_t i=0; i<n; ++i) {
    62       for (size_t j=i; j<n; ++j) {
    63         cov(i, j) = x.column_const_view(i) * x.column_const_view(j);
    64         cov(j, i) = cov(i, j);
    65       }
     54    size_t cols = X.columns();
     55    size_t rows = X.rows();
     56    utility::Vector m(cols);
     57    utility::Vector x2(cols);
     58    utility::Vector stddev(cols);
     59    for (size_t i=0; i<cols; ++i) {
     60      utility::VectorConstView vec = X.column_const_view(i);
     61      m(i) = sum(vec);
     62      x2(i) = vec * vec;
     63      // scaled standard deviation
     64      stddev(i) = std::sqrt(x2(i) - m(i) * m(i) / rows);
    6665    }
    6766
    68     utility::Vector stddev(n);
    69     for (size_t i=0; i<n; ++i)
    70       stddev(i) = std::sqrt(cov(i, i));
     67    utility::Matrix corr(cols, cols, 1.0);
     68    for (size_t i=0; i<cols; ++i)
     69      for (size_t j=i+1; j<cols; ++j) {
     70        corr(i,j) =
     71          X.column_const_view(i) * X.column_const_view(j) - m(i)*m(j)/rows;
     72        corr(i, j) /= stddev(i) * stddev(j);
     73        // symmetry
     74        corr(j, i) = corr(i,j);
     75      }
    7176
    72     Matrix corr(cov);
    73     for (size_t i=0; i<n; ++i) {
    74       corr(i, i) = 1.0;
    75       for (size_t j=0; j<i; ++j) {
    76         corr(i, j) = cov(i, j) / (stddev(i) * stddev(j));
    77         corr(j, i) = corr(i, j);
    78       }
    79     }
    8077    return corr;
    8178  }
Note: See TracChangeset for help on using the changeset viewer.