Changeset 3904


Ignore:
Timestamp:
May 7, 2020, 1:41:47 PM (3 years ago)
Author:
Peter
Message:

Closes #892

Speed up matrix multiplications when they involve a transposed
matrix. Rather than transposing and then performing the
multiplication, just set the appropriate argument when calling
gsl_blas_dgemm.

Location:
trunk
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/matrix_expression.cc

    r3654 r3904  
    5353  Matrix B(4, 3, 2);
    5454  Matrix C(4, 3, 4);
    55 
    5655  // addition operator+
    5756  {
     
    176175    Matrix m2 = transpose(A * transpose(B) * C);
    177176    check(m2, m, suite);
    178   }
    179 
    180 
     177    Matrix m3 = transpose(transpose(m));
     178    check(m3, m, suite);
     179    Matrix m4 = transpose(2*transpose(m));
     180    Matrix m5 = transpose(transpose(2*m));
     181    check(m4, m5, suite);
     182  }
     183
     184
     185  // test dgemm expression
     186  {
     187    suite.out() << "testing dgemm\n";
     188
     189    Matrix oldres(Bt);
     190    oldres *= 0.5;
     191    oldres *= C;
     192    Matrix res = 0.5 * transpose(B) * C;
     193    check(res, oldres, suite);
     194
     195    res = 0.5 * Bt * C;
     196    check(res, oldres, suite);
     197
     198    res = 0.5 * transpose(B) * 1.0 * transpose(Ct);
     199    check(res, oldres, suite);
     200
     201    res = 1.0 * Bt * 0.5 * transpose(Ct);
     202    check(res, oldres, suite);
     203
     204    res = transpose(B) * C;
     205  }
    181206  return suite.return_value();
    182207}
  • trunk/yat/utility/BLAS_level3.h

    r3654 r3904  
    55
    66/*
    7   Copyright (C) 2017 Peter Johansson
     7  Copyright (C) 2017, 2020 Peter Johansson
    88
    99  This file is part of the yat library, http://dev.thep.lu.se/yat
     
    2727#include "MatrixExpression.h"
    2828
     29#include <gsl/gsl_blas.h>
     30#include <gsl/gsl_cblas.h>
    2931#include <gsl/gsl_matrix.h>
    3032
     
    6466      const BasicMatrix<RHS>& rhs_;
    6567      OP op_;
    66 
    67       void calculate_matrix(gsl_matrix*& result, Multiplies) const
    68       {
    69         YAT_ASSERT(detail::rows(result) == this->rows());
    70         YAT_ASSERT(detail::columns(result) == this->columns());
    71         YAT_ASSERT(lhs_.columns() == rhs_.rows());
    72         gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0,
    73                        lhs_.gsl_matrix_p(), rhs_.gsl_matrix_p(),0.0, result);
    74         YAT_ASSERT(result);
    75       }
    7668
    7769      template<class T>
     
    9789      }
    9890
    99 
    100       double get(size_t row, size_t column, Multiplies) const
    101       {
    102         return gsl_matrix_get(this->gsl_matrix_p(), row, column);
    103       }
    10491    };
    10592
     
    122109
    123110
     111      const T& base(void) const
     112      {
     113        return static_cast<const T&>(A_);
     114      }
     115
     116
    124117      void calculate_matrix(gsl_matrix*& result) const
    125118      {
    126119        detail::copy(result, A_.gsl_matrix_p());
    127120        gsl_matrix_scale(result, factor_);
     121      }
     122
     123
     124      double factor(void) const
     125      {
     126        return factor_;
    128127      }
    129128
     
    155154      }
    156155
     156      const T& base(void) const
     157      {
     158        return static_cast<const T&>(A_);
     159      }
     160
    157161    private:
    158162      const BasicMatrix<T>& A_;
    159163    };
    160164
     165
     166    // Helper class for ScaledMatrixProduct
     167    template<class T>
     168    struct MatrixTrait
     169    {
     170      CBLAS_TRANSPOSE_t transpose_type(void) const { return CblasNoTrans; }
     171      const gsl_matrix* get_gsl_matrix_p(const T& m) const
     172      { return m.gsl_matrix_p(); }
     173    };
     174
     175    // Specialization for TransposedMatrix
     176    template<class Base>
     177    struct MatrixTrait<BasicMatrix<MatrixExpression<TransposedMatrix<Base> > > >
     178    {
     179      // returned the base type's transpose_type, flipped.
     180      CBLAS_TRANSPOSE_t transpose_type(void) const
     181      {
     182        if (base_trait_.transpose_type() == CblasTrans)
     183          return CblasNoTrans;
     184        return CblasTrans;
     185      }
     186
     187
     188      const gsl_matrix*
     189      get_gsl_matrix_p(const BasicMatrix<MatrixExpression<TransposedMatrix<Base> > >& m) const
     190      {
     191        const TransposedMatrix<Base>* me =
     192          static_cast<const TransposedMatrix<Base>*>(&m);
     193        return base_trait_.get_gsl_matrix_p(me->base());
     194      }
     195
     196
     197      private:
     198      MatrixTrait<BasicMatrix<Base> > base_trait_;
     199    };
     200
     201
     202    template<class DerivedA, class DerivedB>
     203    class ScaledMatrixProduct :
     204      public MatrixExpression<ScaledMatrixProduct<DerivedA, DerivedB> >
     205    {
     206    public:
     207      ScaledMatrixProduct(double alpha,
     208                          const BasicMatrix<DerivedA>& A,
     209                          const BasicMatrix<DerivedB>& B)
     210        : alpha_(alpha), A_(A), B_(B) {}
     211
     212      size_t rows(void) const { return A_.rows(); }
     213      size_t columns(void) const { return B_.columns(); }
     214      double operator()(size_t row, size_t column) const
     215      { return gsl_matrix_get(this->gsl_matrix_p(), row, column); }
     216
     217      void calculate_matrix(gsl_matrix*& result, double beta=0.0) const
     218      {
     219        detail::reallocate(result, this->rows(), this->columns());
     220        YAT_ASSERT(detail::rows(result) == this->rows());
     221        YAT_ASSERT(detail::columns(result) == this->columns());
     222        YAT_ASSERT(A_.columns() == B_.rows());
     223        MatrixTrait<BasicMatrix<DerivedA> > traitA;
     224        MatrixTrait<BasicMatrix<DerivedB> > traitB;
     225        traitA.get_gsl_matrix_p(A_);
     226        traitB.get_gsl_matrix_p(B_);
     227        gsl_blas_dgemm(traitA.transpose_type(), traitB.transpose_type(),
     228                       alpha_,
     229                       traitA.get_gsl_matrix_p(A_),
     230                       traitB.get_gsl_matrix_p(B_),
     231                       beta, result);
     232        YAT_ASSERT(result);
     233      }
     234      private:
     235
     236      double alpha_;
     237      const BasicMatrix<DerivedA>& A_;
     238      const BasicMatrix<DerivedB>& B_;
     239    };
    161240
    162241  } // end namespace expression
     
    206285
    207286
     287  /// \cond IGNORE_DOXYGEN
     288
     289  namespace detail {
     290
     291    // Helper class used in MatrixExpression * MatrixExpression
     292    // The class factorises a matrix expression into
     293    // scalar * matrix expression
     294    // For the default case there is nothing to factorise, so scalar
     295    // is 1.0 and matrix expression is the passed matrix expression.
     296    template<class T>
     297    struct UnscaleTrait
     298    {
     299      UnscaleTrait(const T& matrix) : matrix_(matrix) {}
     300      typedef typename T::derived_type result_arg;
     301      double factor(void) const { return 1.0; }
     302      const T& matrix(void) const { return matrix_; }
     303    private:
     304      const T& matrix_;
     305    };
     306
     307    // Specialization for ScaledMatrix for which we factor out the
     308    // scalar and the matrix.
     309    template<class Base>
     310    struct UnscaleTrait<BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > > >
     311    {
     312      UnscaleTrait(const BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > >& m)
     313        : matrix_(m),
     314          scaled_matrix_(static_cast<const expression::ScaledMatrix<Base>&>(m))
     315      {}
     316
     317
     318      typedef Base result_arg;
     319
     320      double factor(void) const
     321      {
     322        return scaled_matrix_.factor();
     323      }
     324
     325
     326      const Base& matrix(void) const
     327      {
     328        return scaled_matrix_.base();
     329      }
     330    private:
     331      const BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > >&
     332      matrix_;
     333      const expression::ScaledMatrix<Base>& scaled_matrix_;
     334    };
     335
     336  } // end namespace detail
     337  /// \endcond
     338
    208339  /**
    209340     \brief Matrix multiplication operator
     
    216347  */
    217348  template<class Derived1, class Derived2>
    218   expression::MatrixBinary<Derived1, Derived2, expression::Multiplies>
     349  expression::ScaledMatrixProduct<
     350    typename detail::UnscaleTrait<BasicMatrix<Derived1> >::result_arg,
     351    typename detail::UnscaleTrait<BasicMatrix<Derived2> >::result_arg
     352    >
    219353  operator*(const BasicMatrix<Derived1>& lhs, const BasicMatrix<Derived2>& rhs)
    220354  {
    221355    YAT_ASSERT(lhs.columns() == rhs.rows());
    222     return expression::MatrixBinary<Derived1, Derived2,
    223                                     expression::Multiplies>(lhs, rhs);
    224   }
    225 
     356    detail::UnscaleTrait<BasicMatrix<Derived1> > unscale1(lhs);
     357    detail::UnscaleTrait<BasicMatrix<Derived2> > unscale2(rhs);
     358
     359    return expression::ScaledMatrixProduct<
     360      typename detail::UnscaleTrait<BasicMatrix<Derived1> >::result_arg,
     361      typename detail::UnscaleTrait<BasicMatrix<Derived2> >::result_arg
     362      >(unscale1.factor() * unscale2.factor(),
     363        unscale1.matrix(), unscale2.matrix());
     364  }
    226365
    227366  /**
     
    277416
    278417  /**
     418     Specialization for ScaledMatrix
     419
     420     \since New in yat 0.18
     421   */
     422  template<class T>
     423  expression::ScaledMatrix<T>
     424  operator*(const BasicMatrix<MatrixExpression<expression::ScaledMatrix<T> > >& A,
     425            double k)
     426  {
     427    const expression::ScaledMatrix<T>& sm =
     428      static_cast<const expression::ScaledMatrix<T>&>(A);
     429    return expression::ScaledMatrix<T>(k*sm.factor(), sm.base());
     430  }
     431
     432
     433  /**
     434     Specialization for ScaledMatrix
     435
     436     \since New in yat 0.18
     437   */
     438  template<class T>
     439  expression::ScaledMatrix<T>
     440  operator*(double k,
     441            const BasicMatrix<MatrixExpression<expression::ScaledMatrix<T> > >& A)
     442  {
     443    const expression::ScaledMatrix<T>& sm =
     444      static_cast<const expression::ScaledMatrix<T>&>(A);
     445    return expression::ScaledMatrix<T>(k*sm.factor(), sm.base());
     446  }
     447
     448
     449  /**
     450     Specialization for ScaledMatrix
     451
     452     \since New in yat 0.18
     453   */
     454  template<class T>
     455  expression::ScaledMatrix<T>
     456  operator-(const BasicMatrix<MatrixExpression<expression::ScaledMatrix<T> > >& A)
     457  {
     458    const expression::ScaledMatrix<T>& sm =
     459      static_cast<const expression::ScaledMatrix<T>&>(A);
     460    return expression::ScaledMatrix<T>(-sm.factor(), sm.base());
     461  }
     462
     463
     464  /**
    279465     \brief transpose function
    280466
Note: See TracChangeset for help on using the changeset viewer.