Changeset 3908
- Timestamp:
- May 13, 2020, 8:58:38 AM (3 years ago)
- Location:
- trunk/yat/utility
- Files:
-
- 6 added
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/yat/utility/BLAS_level2.h
r3816 r3908 25 25 #include "BasicMatrix.h" 26 26 #include "BasicVector.h" 27 #include "BLAS_level3.h" 28 #include "VectorExpression.h" 27 29 #include "yat_assert.h" 28 30 29 #include " VectorExpression.h"31 #include "expression/MatrixTraits.h" 30 32 31 33 #include "gsl/gsl_blas.h" … … 48 50 { 49 51 public: 52 50 53 MatrixVector(const BasicMatrix<MATRIX>& lhs, 51 54 const BasicVector<VECTOR>& rhs) … … 54 57 YAT_ASSERT(lhs.rows()); 55 58 this->allocate_memory(lhs.rows()); 56 gsl_blas_dgemv(CblasNoTrans, 1.0, lhs.gsl_matrix_p(), 59 60 MatrixTraits<BasicMatrix<MATRIX> > traits; 61 62 gsl_blas_dgemv(traits.transpose_type(), traits.factor(lhs), 63 traits.get_gsl_matrix_p(lhs), 57 64 rhs.gsl_vector_p(), 0.0, this->v_); 58 }59 60 MatrixVector(const BasicVector<VECTOR>& lhs,61 const BasicMatrix<MATRIX>& rhs)62 {63 YAT_ASSERT(lhs.size() == rhs.rows());64 YAT_ASSERT(rhs.columns());65 this->allocate_memory(rhs.columns());66 gsl_blas_dgemv(CblasTrans, 1.0, rhs.gsl_matrix_p(),67 lhs.gsl_vector_p(), 0.0, this->v_);68 65 } 69 66 … … 86 83 void calculate_gsl_vector_p(void) const 87 84 { 85 // This should never be called since v_ is constructed in 86 // constructor 88 87 YAT_ASSERT(this->v_); 89 88 YAT_ASSERT(0); … … 122 121 */ 123 122 template<class MATRIX, class VECTOR> 124 expression::MatrixVector<MATRIX, VECTOR> 123 expression::MatrixVector<MatrixExpression< 124 expression::TransposedMatrix<MATRIX> >, 125 VECTOR> 125 126 operator*(const BasicVector<VECTOR>& lhs, const BasicMatrix<MATRIX>& rhs) 126 127 { 127 128 YAT_ASSERT(lhs.size() == rhs.rows()); 128 return expression::MatrixVector<MATRIX, VECTOR>(lhs, rhs);129 return transpose(rhs) * lhs; 129 130 } 130 131 -
trunk/yat/utility/BLAS_level3.h
r3904 r3908 25 25 #include "BasicMatrix.h" 26 26 #include "BLAS_utility.h" 27 #include "MatrixExpression.h" 27 28 #include "expression/MatrixBinary.h" 29 #include "expression/MatrixProduct.h" 30 #include "expression/ScaledMatrix.h" 31 #include "expression/TransposedMatrix.h" 28 32 29 33 #include <gsl/gsl_blas.h> … … 37 41 // This file defines operations using both Matrix (but not Vector) 38 42 39 /// \cond IGNORE_DOXYGEN40 41 namespace expression {42 template<typename LHS, typename RHS, class OP>43 class MatrixBinary44 : public MatrixExpression<MatrixBinary<LHS, RHS, OP> >45 {46 public:47 MatrixBinary(const BasicMatrix<LHS>& lhs, const BasicMatrix<RHS>& rhs)48 : lhs_(lhs), rhs_(rhs)49 {50 }51 52 size_t rows(void) const { return lhs_.rows(); }53 size_t columns(void) const { return rhs_.columns(); }54 55 double operator()(size_t row, size_t column) const56 { return get(row, column, op_); }57 58 void calculate_matrix(gsl_matrix*& result) const59 {60 detail::reallocate(result, this->rows(), this->columns());61 calculate_matrix(result, op_);62 }63 64 private:65 const BasicMatrix<LHS>& lhs_;66 const BasicMatrix<RHS>& rhs_;67 OP op_;68 69 template<class T>70 void calculate_matrix(gsl_matrix*& result, T) const71 {72 YAT_ASSERT(detail::rows(result) == this->rows());73 YAT_ASSERT(detail::columns(result) == this->columns());74 for (size_t i=0; i<rows(); ++i)75 for (size_t j=0; j<columns(); ++j)76 gsl_matrix_set(result, i, j, (*this)(i, j));77 }78 79 80 double get(size_t row, size_t column, Plus) const81 {82 return lhs_(row, column) + rhs_(row, column);83 }84 85 86 double get(size_t row, size_t column, Minus) const87 {88 return lhs_(row, column) - rhs_(row, column);89 }90 91 };92 93 94 template<class T>95 class ScaledMatrix : public MatrixExpression<ScaledMatrix<T> >96 {97 public:98 ScaledMatrix(double factor, const BasicMatrix<T>& A)99 : A_(A), factor_(factor) {}100 101 size_t rows(void) const { return A_.rows(); }102 size_t columns(void) const { return A_.columns(); }103 104 105 double operator()(size_t i, size_t j) const106 {107 return factor_ * A_(i, j);108 }109 110 111 const T& base(void) const112 {113 return static_cast<const T&>(A_);114 }115 116 117 void calculate_matrix(gsl_matrix*& result) const118 {119 detail::copy(result, A_.gsl_matrix_p());120 gsl_matrix_scale(result, factor_);121 }122 123 124 double factor(void) const125 {126 return factor_;127 }128 129 private:130 const BasicMatrix<T>& A_;131 double factor_;132 };133 134 135 template<class T>136 class TransposedMatrix : public MatrixExpression<TransposedMatrix<T> >137 {138 public:139 TransposedMatrix(const BasicMatrix<T>& A)140 : A_(A) {}141 142 size_t rows(void) const { return A_.columns(); }143 size_t columns(void) const { return A_.rows(); }144 145 double operator()(size_t i, size_t j) const146 {147 return A_(j, i);148 }149 150 void calculate_matrix(gsl_matrix*& result) const151 {152 detail::reallocate(result, rows(), columns());153 gsl_matrix_transpose_memcpy(result, A_.gsl_matrix_p());154 }155 156 const T& base(void) const157 {158 return static_cast<const T&>(A_);159 }160 161 private:162 const BasicMatrix<T>& A_;163 };164 165 166 // Helper class for ScaledMatrixProduct167 template<class T>168 struct MatrixTrait169 {170 CBLAS_TRANSPOSE_t transpose_type(void) const { return CblasNoTrans; }171 const gsl_matrix* get_gsl_matrix_p(const T& m) const172 { return m.gsl_matrix_p(); }173 };174 175 // Specialization for TransposedMatrix176 template<class Base>177 struct MatrixTrait<BasicMatrix<MatrixExpression<TransposedMatrix<Base> > > >178 {179 // returned the base type's transpose_type, flipped.180 CBLAS_TRANSPOSE_t transpose_type(void) const181 {182 if (base_trait_.transpose_type() == CblasTrans)183 return CblasNoTrans;184 return CblasTrans;185 }186 187 188 const gsl_matrix*189 get_gsl_matrix_p(const BasicMatrix<MatrixExpression<TransposedMatrix<Base> > >& m) const190 {191 const TransposedMatrix<Base>* me =192 static_cast<const TransposedMatrix<Base>*>(&m);193 return base_trait_.get_gsl_matrix_p(me->base());194 }195 196 197 private:198 MatrixTrait<BasicMatrix<Base> > base_trait_;199 };200 201 202 template<class DerivedA, class DerivedB>203 class ScaledMatrixProduct :204 public MatrixExpression<ScaledMatrixProduct<DerivedA, DerivedB> >205 {206 public:207 ScaledMatrixProduct(double alpha,208 const BasicMatrix<DerivedA>& A,209 const BasicMatrix<DerivedB>& B)210 : alpha_(alpha), A_(A), B_(B) {}211 212 size_t rows(void) const { return A_.rows(); }213 size_t columns(void) const { return B_.columns(); }214 double operator()(size_t row, size_t column) const215 { return gsl_matrix_get(this->gsl_matrix_p(), row, column); }216 217 void calculate_matrix(gsl_matrix*& result, double beta=0.0) const218 {219 detail::reallocate(result, this->rows(), this->columns());220 YAT_ASSERT(detail::rows(result) == this->rows());221 YAT_ASSERT(detail::columns(result) == this->columns());222 YAT_ASSERT(A_.columns() == B_.rows());223 MatrixTrait<BasicMatrix<DerivedA> > traitA;224 MatrixTrait<BasicMatrix<DerivedB> > traitB;225 traitA.get_gsl_matrix_p(A_);226 traitB.get_gsl_matrix_p(B_);227 gsl_blas_dgemm(traitA.transpose_type(), traitB.transpose_type(),228 alpha_,229 traitA.get_gsl_matrix_p(A_),230 traitB.get_gsl_matrix_p(B_),231 beta, result);232 YAT_ASSERT(result);233 }234 private:235 236 double alpha_;237 const BasicMatrix<DerivedA>& A_;238 const BasicMatrix<DerivedB>& B_;239 };240 241 } // end namespace expression242 243 /// \endcond244 245 43 /** 246 44 \brief Matrix addition operator … … 285 83 286 84 287 /// \cond IGNORE_DOXYGEN288 289 namespace detail {290 291 // Helper class used in MatrixExpression * MatrixExpression292 // The class factorises a matrix expression into293 // scalar * matrix expression294 // For the default case there is nothing to factorise, so scalar295 // is 1.0 and matrix expression is the passed matrix expression.296 template<class T>297 struct UnscaleTrait298 {299 UnscaleTrait(const T& matrix) : matrix_(matrix) {}300 typedef typename T::derived_type result_arg;301 double factor(void) const { return 1.0; }302 const T& matrix(void) const { return matrix_; }303 private:304 const T& matrix_;305 };306 307 // Specialization for ScaledMatrix for which we factor out the308 // scalar and the matrix.309 template<class Base>310 struct UnscaleTrait<BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > > >311 {312 UnscaleTrait(const BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > >& m)313 : matrix_(m),314 scaled_matrix_(static_cast<const expression::ScaledMatrix<Base>&>(m))315 {}316 317 318 typedef Base result_arg;319 320 double factor(void) const321 {322 return scaled_matrix_.factor();323 }324 325 326 const Base& matrix(void) const327 {328 return scaled_matrix_.base();329 }330 private:331 const BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > >&332 matrix_;333 const expression::ScaledMatrix<Base>& scaled_matrix_;334 };335 336 } // end namespace detail337 /// \endcond338 339 85 /** 340 86 \brief Matrix multiplication operator … … 347 93 */ 348 94 template<class Derived1, class Derived2> 349 expression::ScaledMatrixProduct< 350 typename detail::UnscaleTrait<BasicMatrix<Derived1> >::result_arg, 351 typename detail::UnscaleTrait<BasicMatrix<Derived2> >::result_arg 352 > 95 expression::MatrixProduct<BasicMatrix<Derived1>, BasicMatrix<Derived2> > 353 96 operator*(const BasicMatrix<Derived1>& lhs, const BasicMatrix<Derived2>& rhs) 354 97 { 355 98 YAT_ASSERT(lhs.columns() == rhs.rows()); 356 detail::UnscaleTrait<BasicMatrix<Derived1> > unscale1(lhs); 357 detail::UnscaleTrait<BasicMatrix<Derived2> > unscale2(rhs); 358 359 return expression::ScaledMatrixProduct< 360 typename detail::UnscaleTrait<BasicMatrix<Derived1> >::result_arg, 361 typename detail::UnscaleTrait<BasicMatrix<Derived2> >::result_arg 362 >(unscale1.factor() * unscale2.factor(), 363 unscale1.matrix(), unscale2.matrix()); 99 return expression::MatrixProduct<BasicMatrix<Derived1>, 100 BasicMatrix<Derived2> 101 >(lhs, rhs); 364 102 } 365 103 -
trunk/yat/utility/Makefile.am
r3902 r3908 140 140 141 141 nobase_nodist_include_HEADERS = yat/utility/config_public.h 142 143 include yat/utility/expression/Makefile.am
Note: See TracChangeset
for help on using the changeset viewer.