# Changeset 3904

Ignore:
Timestamp:
May 7, 2020, 1:41:47 PM (3 years ago)
Message:

Closes #892

Speed up matrix multiplications when they involve a transposed
matrix. Rather than transposing and then performing the
multiplication, just set the appropriate argument when calling
gsl_blas_dgemm.

Location:
trunk
Files:
2 edited

Unmodified
Added
Removed
• ## trunk/test/matrix_expression.cc

 r3654 Matrix B(4, 3, 2); Matrix C(4, 3, 4); // addition operator+ { Matrix m2 = transpose(A * transpose(B) * C); check(m2, m, suite); } Matrix m3 = transpose(transpose(m)); check(m3, m, suite); Matrix m4 = transpose(2*transpose(m)); Matrix m5 = transpose(transpose(2*m)); check(m4, m5, suite); } // test dgemm expression { suite.out() << "testing dgemm\n"; Matrix oldres(Bt); oldres *= 0.5; oldres *= C; Matrix res = 0.5 * transpose(B) * C; check(res, oldres, suite); res = 0.5 * Bt * C; check(res, oldres, suite); res = 0.5 * transpose(B) * 1.0 * transpose(Ct); check(res, oldres, suite); res = 1.0 * Bt * 0.5 * transpose(Ct); check(res, oldres, suite); res = transpose(B) * C; } return suite.return_value(); }
• ## trunk/yat/utility/BLAS_level3.h

 r3654 /* Copyright (C) 2017 Peter Johansson Copyright (C) 2017, 2020 Peter Johansson This file is part of the yat library, http://dev.thep.lu.se/yat #include "MatrixExpression.h" #include #include #include const BasicMatrix& rhs_; OP op_; void calculate_matrix(gsl_matrix*& result, Multiplies) const { YAT_ASSERT(detail::rows(result) == this->rows()); YAT_ASSERT(detail::columns(result) == this->columns()); YAT_ASSERT(lhs_.columns() == rhs_.rows()); gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, lhs_.gsl_matrix_p(), rhs_.gsl_matrix_p(),0.0, result); YAT_ASSERT(result); } template } double get(size_t row, size_t column, Multiplies) const { return gsl_matrix_get(this->gsl_matrix_p(), row, column); } }; const T& base(void) const { return static_cast(A_); } void calculate_matrix(gsl_matrix*& result) const { detail::copy(result, A_.gsl_matrix_p()); gsl_matrix_scale(result, factor_); } double factor(void) const { return factor_; } } const T& base(void) const { return static_cast(A_); } private: const BasicMatrix& A_; }; // Helper class for ScaledMatrixProduct template struct MatrixTrait { CBLAS_TRANSPOSE_t transpose_type(void) const { return CblasNoTrans; } const gsl_matrix* get_gsl_matrix_p(const T& m) const { return m.gsl_matrix_p(); } }; // Specialization for TransposedMatrix template struct MatrixTrait > > > { // returned the base type's transpose_type, flipped. CBLAS_TRANSPOSE_t transpose_type(void) const { if (base_trait_.transpose_type() == CblasTrans) return CblasNoTrans; return CblasTrans; } const gsl_matrix* get_gsl_matrix_p(const BasicMatrix > >& m) const { const TransposedMatrix* me = static_cast*>(&m); return base_trait_.get_gsl_matrix_p(me->base()); } private: MatrixTrait > base_trait_; }; template class ScaledMatrixProduct : public MatrixExpression > { public: ScaledMatrixProduct(double alpha, const BasicMatrix& A, const BasicMatrix& B) : alpha_(alpha), A_(A), B_(B) {} size_t rows(void) const { return A_.rows(); } size_t columns(void) const { return B_.columns(); } double operator()(size_t row, size_t column) const { return gsl_matrix_get(this->gsl_matrix_p(), row, column); } void calculate_matrix(gsl_matrix*& result, double beta=0.0) const { detail::reallocate(result, this->rows(), this->columns()); YAT_ASSERT(detail::rows(result) == this->rows()); YAT_ASSERT(detail::columns(result) == this->columns()); YAT_ASSERT(A_.columns() == B_.rows()); MatrixTrait > traitA; MatrixTrait > traitB; traitA.get_gsl_matrix_p(A_); traitB.get_gsl_matrix_p(B_); gsl_blas_dgemm(traitA.transpose_type(), traitB.transpose_type(), alpha_, traitA.get_gsl_matrix_p(A_), traitB.get_gsl_matrix_p(B_), beta, result); YAT_ASSERT(result); } private: double alpha_; const BasicMatrix& A_; const BasicMatrix& B_; }; } // end namespace expression /// \cond IGNORE_DOXYGEN namespace detail { // Helper class used in MatrixExpression * MatrixExpression // The class factorises a matrix expression into // scalar * matrix expression // For the default case there is nothing to factorise, so scalar // is 1.0 and matrix expression is the passed matrix expression. template struct UnscaleTrait { UnscaleTrait(const T& matrix) : matrix_(matrix) {} typedef typename T::derived_type result_arg; double factor(void) const { return 1.0; } const T& matrix(void) const { return matrix_; } private: const T& matrix_; }; // Specialization for ScaledMatrix for which we factor out the // scalar and the matrix. template struct UnscaleTrait > > > { UnscaleTrait(const BasicMatrix > >& m) : matrix_(m), scaled_matrix_(static_cast&>(m)) {} typedef Base result_arg; double factor(void) const { return scaled_matrix_.factor(); } const Base& matrix(void) const { return scaled_matrix_.base(); } private: const BasicMatrix > >& matrix_; const expression::ScaledMatrix& scaled_matrix_; }; } // end namespace detail /// \endcond /** \brief Matrix multiplication operator */ template expression::MatrixBinary expression::ScaledMatrixProduct< typename detail::UnscaleTrait >::result_arg, typename detail::UnscaleTrait >::result_arg > operator*(const BasicMatrix& lhs, const BasicMatrix& rhs) { YAT_ASSERT(lhs.columns() == rhs.rows()); return expression::MatrixBinary(lhs, rhs); } detail::UnscaleTrait > unscale1(lhs); detail::UnscaleTrait > unscale2(rhs); return expression::ScaledMatrixProduct< typename detail::UnscaleTrait >::result_arg, typename detail::UnscaleTrait >::result_arg >(unscale1.factor() * unscale2.factor(), unscale1.matrix(), unscale2.matrix()); } /** /** Specialization for ScaledMatrix \since New in yat 0.18 */ template expression::ScaledMatrix operator*(const BasicMatrix > >& A, double k) { const expression::ScaledMatrix& sm = static_cast&>(A); return expression::ScaledMatrix(k*sm.factor(), sm.base()); } /** Specialization for ScaledMatrix \since New in yat 0.18 */ template expression::ScaledMatrix operator*(double k, const BasicMatrix > >& A) { const expression::ScaledMatrix& sm = static_cast&>(A); return expression::ScaledMatrix(k*sm.factor(), sm.base()); } /** Specialization for ScaledMatrix \since New in yat 0.18 */ template expression::ScaledMatrix operator-(const BasicMatrix > >& A) { const expression::ScaledMatrix& sm = static_cast&>(A); return expression::ScaledMatrix(-sm.factor(), sm.base()); } /** \brief transpose function
Note: See TracChangeset for help on using the changeset viewer.