Changeset 3908


Ignore:
Timestamp:
May 13, 2020, 8:58:38 AM (3 weeks ago)
Author:
Peter
Message:

#closes #881

Lift out expression classes to their own files, so they can be
included separately (and avoiding cyclic inclusion).

Replace the two traits classes for matriox expression with one
class. Use this class in MatrixVector? expression, which makes
transposition a no-op and we can therefore implement Vector * Matrix
as transpose(Matrix) * Vector (since there is no distinction between
row and column vectors in yat).

Location:
trunk/yat/utility
Files:
6 added
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/utility/BLAS_level2.h

    r3816 r3908  
    2525#include "BasicMatrix.h"
    2626#include "BasicVector.h"
     27#include "BLAS_level3.h"
     28#include "VectorExpression.h"
    2729#include "yat_assert.h"
    2830
    29 #include "VectorExpression.h"
     31#include "expression/MatrixTraits.h"
    3032
    3133#include "gsl/gsl_blas.h"
     
    4850    {
    4951    public:
     52
    5053      MatrixVector(const BasicMatrix<MATRIX>& lhs,
    5154                   const BasicVector<VECTOR>& rhs)
     
    5457        YAT_ASSERT(lhs.rows());
    5558        this->allocate_memory(lhs.rows());
    56         gsl_blas_dgemv(CblasNoTrans, 1.0, lhs.gsl_matrix_p(),
     59
     60        MatrixTraits<BasicMatrix<MATRIX> > traits;
     61
     62        gsl_blas_dgemv(traits.transpose_type(), traits.factor(lhs),
     63                       traits.get_gsl_matrix_p(lhs),
    5764                       rhs.gsl_vector_p(), 0.0, this->v_);
    58       }
    59 
    60       MatrixVector(const BasicVector<VECTOR>& lhs,
    61                    const BasicMatrix<MATRIX>& rhs)
    62       {
    63         YAT_ASSERT(lhs.size() == rhs.rows());
    64         YAT_ASSERT(rhs.columns());
    65         this->allocate_memory(rhs.columns());
    66         gsl_blas_dgemv(CblasTrans, 1.0, rhs.gsl_matrix_p(),
    67                        lhs.gsl_vector_p(), 0.0, this->v_);
    6865      }
    6966
     
    8683      void calculate_gsl_vector_p(void) const
    8784      {
     85        // This should never be called since v_ is constructed in
     86        // constructor
    8887        YAT_ASSERT(this->v_);
    8988        YAT_ASSERT(0);
     
    122121   */
    123122  template<class MATRIX, class VECTOR>
    124   expression::MatrixVector<MATRIX, VECTOR>
     123  expression::MatrixVector<MatrixExpression<
     124                             expression::TransposedMatrix<MATRIX> >,
     125                           VECTOR>
    125126  operator*(const BasicVector<VECTOR>& lhs, const BasicMatrix<MATRIX>& rhs)
    126127  {
    127128    YAT_ASSERT(lhs.size() == rhs.rows());
    128     return expression::MatrixVector<MATRIX, VECTOR>(lhs, rhs);
     129    return transpose(rhs) * lhs;
    129130  }
    130131
  • trunk/yat/utility/BLAS_level3.h

    r3904 r3908  
    2525#include "BasicMatrix.h"
    2626#include "BLAS_utility.h"
    27 #include "MatrixExpression.h"
     27
     28#include "expression/MatrixBinary.h"
     29#include "expression/MatrixProduct.h"
     30#include "expression/ScaledMatrix.h"
     31#include "expression/TransposedMatrix.h"
    2832
    2933#include <gsl/gsl_blas.h>
     
    3741  // This file defines operations using both Matrix (but not Vector)
    3842
    39   /// \cond IGNORE_DOXYGEN
    40 
    41   namespace expression {
    42     template<typename LHS, typename RHS, class OP>
    43     class MatrixBinary
    44       : public MatrixExpression<MatrixBinary<LHS, RHS, OP> >
    45     {
    46     public:
    47       MatrixBinary(const BasicMatrix<LHS>& lhs, const BasicMatrix<RHS>& rhs)
    48         : lhs_(lhs), rhs_(rhs)
    49       {
    50       }
    51 
    52       size_t rows(void) const { return lhs_.rows(); }
    53       size_t columns(void) const { return rhs_.columns(); }
    54 
    55       double operator()(size_t row, size_t column) const
    56       { return get(row, column, op_); }
    57 
    58       void calculate_matrix(gsl_matrix*& result) const
    59       {
    60         detail::reallocate(result, this->rows(), this->columns());
    61         calculate_matrix(result, op_);
    62       }
    63 
    64     private:
    65       const BasicMatrix<LHS>& lhs_;
    66       const BasicMatrix<RHS>& rhs_;
    67       OP op_;
    68 
    69       template<class T>
    70       void calculate_matrix(gsl_matrix*& result, T) const
    71       {
    72         YAT_ASSERT(detail::rows(result) == this->rows());
    73         YAT_ASSERT(detail::columns(result) == this->columns());
    74         for (size_t i=0; i<rows(); ++i)
    75           for (size_t j=0; j<columns(); ++j)
    76             gsl_matrix_set(result, i, j, (*this)(i, j));
    77       }
    78 
    79 
    80       double get(size_t row, size_t column, Plus) const
    81       {
    82         return lhs_(row, column) + rhs_(row, column);
    83       }
    84 
    85 
    86       double get(size_t row, size_t column, Minus) const
    87       {
    88         return lhs_(row, column) - rhs_(row, column);
    89       }
    90 
    91     };
    92 
    93 
    94     template<class T>
    95     class ScaledMatrix : public MatrixExpression<ScaledMatrix<T> >
    96     {
    97     public:
    98       ScaledMatrix(double factor, const BasicMatrix<T>& A)
    99         : A_(A), factor_(factor) {}
    100 
    101       size_t rows(void) const { return A_.rows(); }
    102       size_t columns(void) const { return A_.columns(); }
    103 
    104 
    105       double operator()(size_t i, size_t j) const
    106       {
    107         return factor_ * A_(i, j);
    108       }
    109 
    110 
    111       const T& base(void) const
    112       {
    113         return static_cast<const T&>(A_);
    114       }
    115 
    116 
    117       void calculate_matrix(gsl_matrix*& result) const
    118       {
    119         detail::copy(result, A_.gsl_matrix_p());
    120         gsl_matrix_scale(result, factor_);
    121       }
    122 
    123 
    124       double factor(void) const
    125       {
    126         return factor_;
    127       }
    128 
    129     private:
    130       const BasicMatrix<T>& A_;
    131       double factor_;
    132     };
    133 
    134 
    135     template<class T>
    136     class TransposedMatrix : public MatrixExpression<TransposedMatrix<T> >
    137     {
    138     public:
    139       TransposedMatrix(const BasicMatrix<T>& A)
    140         : A_(A) {}
    141 
    142       size_t rows(void) const { return A_.columns(); }
    143       size_t columns(void) const { return A_.rows(); }
    144 
    145       double operator()(size_t i, size_t j) const
    146       {
    147         return A_(j, i);
    148       }
    149 
    150       void calculate_matrix(gsl_matrix*& result) const
    151       {
    152         detail::reallocate(result, rows(), columns());
    153         gsl_matrix_transpose_memcpy(result, A_.gsl_matrix_p());
    154       }
    155 
    156       const T& base(void) const
    157       {
    158         return static_cast<const T&>(A_);
    159       }
    160 
    161     private:
    162       const BasicMatrix<T>& A_;
    163     };
    164 
    165 
    166     // Helper class for ScaledMatrixProduct
    167     template<class T>
    168     struct MatrixTrait
    169     {
    170       CBLAS_TRANSPOSE_t transpose_type(void) const { return CblasNoTrans; }
    171       const gsl_matrix* get_gsl_matrix_p(const T& m) const
    172       { return m.gsl_matrix_p(); }
    173     };
    174 
    175     // Specialization for TransposedMatrix
    176     template<class Base>
    177     struct MatrixTrait<BasicMatrix<MatrixExpression<TransposedMatrix<Base> > > >
    178     {
    179       // returned the base type's transpose_type, flipped.
    180       CBLAS_TRANSPOSE_t transpose_type(void) const
    181       {
    182         if (base_trait_.transpose_type() == CblasTrans)
    183           return CblasNoTrans;
    184         return CblasTrans;
    185       }
    186 
    187 
    188       const gsl_matrix*
    189       get_gsl_matrix_p(const BasicMatrix<MatrixExpression<TransposedMatrix<Base> > >& m) const
    190       {
    191         const TransposedMatrix<Base>* me =
    192           static_cast<const TransposedMatrix<Base>*>(&m);
    193         return base_trait_.get_gsl_matrix_p(me->base());
    194       }
    195 
    196 
    197       private:
    198       MatrixTrait<BasicMatrix<Base> > base_trait_;
    199     };
    200 
    201 
    202     template<class DerivedA, class DerivedB>
    203     class ScaledMatrixProduct :
    204       public MatrixExpression<ScaledMatrixProduct<DerivedA, DerivedB> >
    205     {
    206     public:
    207       ScaledMatrixProduct(double alpha,
    208                           const BasicMatrix<DerivedA>& A,
    209                           const BasicMatrix<DerivedB>& B)
    210         : alpha_(alpha), A_(A), B_(B) {}
    211 
    212       size_t rows(void) const { return A_.rows(); }
    213       size_t columns(void) const { return B_.columns(); }
    214       double operator()(size_t row, size_t column) const
    215       { return gsl_matrix_get(this->gsl_matrix_p(), row, column); }
    216 
    217       void calculate_matrix(gsl_matrix*& result, double beta=0.0) const
    218       {
    219         detail::reallocate(result, this->rows(), this->columns());
    220         YAT_ASSERT(detail::rows(result) == this->rows());
    221         YAT_ASSERT(detail::columns(result) == this->columns());
    222         YAT_ASSERT(A_.columns() == B_.rows());
    223         MatrixTrait<BasicMatrix<DerivedA> > traitA;
    224         MatrixTrait<BasicMatrix<DerivedB> > traitB;
    225         traitA.get_gsl_matrix_p(A_);
    226         traitB.get_gsl_matrix_p(B_);
    227         gsl_blas_dgemm(traitA.transpose_type(), traitB.transpose_type(),
    228                        alpha_,
    229                        traitA.get_gsl_matrix_p(A_),
    230                        traitB.get_gsl_matrix_p(B_),
    231                        beta, result);
    232         YAT_ASSERT(result);
    233       }
    234       private:
    235 
    236       double alpha_;
    237       const BasicMatrix<DerivedA>& A_;
    238       const BasicMatrix<DerivedB>& B_;
    239     };
    240 
    241   } // end namespace expression
    242 
    243   /// \endcond
    244 
    24543  /**
    24644     \brief Matrix addition operator
     
    28583
    28684
    287   /// \cond IGNORE_DOXYGEN
    288 
    289   namespace detail {
    290 
    291     // Helper class used in MatrixExpression * MatrixExpression
    292     // The class factorises a matrix expression into
    293     // scalar * matrix expression
    294     // For the default case there is nothing to factorise, so scalar
    295     // is 1.0 and matrix expression is the passed matrix expression.
    296     template<class T>
    297     struct UnscaleTrait
    298     {
    299       UnscaleTrait(const T& matrix) : matrix_(matrix) {}
    300       typedef typename T::derived_type result_arg;
    301       double factor(void) const { return 1.0; }
    302       const T& matrix(void) const { return matrix_; }
    303     private:
    304       const T& matrix_;
    305     };
    306 
    307     // Specialization for ScaledMatrix for which we factor out the
    308     // scalar and the matrix.
    309     template<class Base>
    310     struct UnscaleTrait<BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > > >
    311     {
    312       UnscaleTrait(const BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > >& m)
    313         : matrix_(m),
    314           scaled_matrix_(static_cast<const expression::ScaledMatrix<Base>&>(m))
    315       {}
    316 
    317 
    318       typedef Base result_arg;
    319 
    320       double factor(void) const
    321       {
    322         return scaled_matrix_.factor();
    323       }
    324 
    325 
    326       const Base& matrix(void) const
    327       {
    328         return scaled_matrix_.base();
    329       }
    330     private:
    331       const BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > >&
    332       matrix_;
    333       const expression::ScaledMatrix<Base>& scaled_matrix_;
    334     };
    335 
    336   } // end namespace detail
    337   /// \endcond
    338 
    33985  /**
    34086     \brief Matrix multiplication operator
     
    34793  */
    34894  template<class Derived1, class Derived2>
    349   expression::ScaledMatrixProduct<
    350     typename detail::UnscaleTrait<BasicMatrix<Derived1> >::result_arg,
    351     typename detail::UnscaleTrait<BasicMatrix<Derived2> >::result_arg
    352     >
     95  expression::MatrixProduct<BasicMatrix<Derived1>, BasicMatrix<Derived2> >
    35396  operator*(const BasicMatrix<Derived1>& lhs, const BasicMatrix<Derived2>& rhs)
    35497  {
    35598    YAT_ASSERT(lhs.columns() == rhs.rows());
    356     detail::UnscaleTrait<BasicMatrix<Derived1> > unscale1(lhs);
    357     detail::UnscaleTrait<BasicMatrix<Derived2> > unscale2(rhs);
    358 
    359     return expression::ScaledMatrixProduct<
    360       typename detail::UnscaleTrait<BasicMatrix<Derived1> >::result_arg,
    361       typename detail::UnscaleTrait<BasicMatrix<Derived2> >::result_arg
    362       >(unscale1.factor() * unscale2.factor(),
    363         unscale1.matrix(), unscale2.matrix());
     99    return expression::MatrixProduct<BasicMatrix<Derived1>,
     100                                     BasicMatrix<Derived2>
     101                                     >(lhs, rhs);
    364102  }
    365103
  • trunk/yat/utility/Makefile.am

    r3902 r3908  
    140140
    141141nobase_nodist_include_HEADERS = yat/utility/config_public.h
     142
     143include yat/utility/expression/Makefile.am
Note: See TracChangeset for help on using the changeset viewer.