Changeset 3904
- Timestamp:
- May 7, 2020, 1:41:47 PM (3 years ago)
- Location:
- trunk
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/test/matrix_expression.cc
r3654 r3904 53 53 Matrix B(4, 3, 2); 54 54 Matrix C(4, 3, 4); 55 56 55 // addition operator+ 57 56 { … … 176 175 Matrix m2 = transpose(A * transpose(B) * C); 177 176 check(m2, m, suite); 178 } 179 180 177 Matrix m3 = transpose(transpose(m)); 178 check(m3, m, suite); 179 Matrix m4 = transpose(2*transpose(m)); 180 Matrix m5 = transpose(transpose(2*m)); 181 check(m4, m5, suite); 182 } 183 184 185 // test dgemm expression 186 { 187 suite.out() << "testing dgemm\n"; 188 189 Matrix oldres(Bt); 190 oldres *= 0.5; 191 oldres *= C; 192 Matrix res = 0.5 * transpose(B) * C; 193 check(res, oldres, suite); 194 195 res = 0.5 * Bt * C; 196 check(res, oldres, suite); 197 198 res = 0.5 * transpose(B) * 1.0 * transpose(Ct); 199 check(res, oldres, suite); 200 201 res = 1.0 * Bt * 0.5 * transpose(Ct); 202 check(res, oldres, suite); 203 204 res = transpose(B) * C; 205 } 181 206 return suite.return_value(); 182 207 } -
trunk/yat/utility/BLAS_level3.h
r3654 r3904 5 5 6 6 /* 7 Copyright (C) 2017 Peter Johansson7 Copyright (C) 2017, 2020 Peter Johansson 8 8 9 9 This file is part of the yat library, http://dev.thep.lu.se/yat … … 27 27 #include "MatrixExpression.h" 28 28 29 #include <gsl/gsl_blas.h> 30 #include <gsl/gsl_cblas.h> 29 31 #include <gsl/gsl_matrix.h> 30 32 … … 64 66 const BasicMatrix<RHS>& rhs_; 65 67 OP op_; 66 67 void calculate_matrix(gsl_matrix*& result, Multiplies) const68 {69 YAT_ASSERT(detail::rows(result) == this->rows());70 YAT_ASSERT(detail::columns(result) == this->columns());71 YAT_ASSERT(lhs_.columns() == rhs_.rows());72 gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0,73 lhs_.gsl_matrix_p(), rhs_.gsl_matrix_p(),0.0, result);74 YAT_ASSERT(result);75 }76 68 77 69 template<class T> … … 97 89 } 98 90 99 100 double get(size_t row, size_t column, Multiplies) const101 {102 return gsl_matrix_get(this->gsl_matrix_p(), row, column);103 }104 91 }; 105 92 … … 122 109 123 110 111 const T& base(void) const 112 { 113 return static_cast<const T&>(A_); 114 } 115 116 124 117 void calculate_matrix(gsl_matrix*& result) const 125 118 { 126 119 detail::copy(result, A_.gsl_matrix_p()); 127 120 gsl_matrix_scale(result, factor_); 121 } 122 123 124 double factor(void) const 125 { 126 return factor_; 128 127 } 129 128 … … 155 154 } 156 155 156 const T& base(void) const 157 { 158 return static_cast<const T&>(A_); 159 } 160 157 161 private: 158 162 const BasicMatrix<T>& A_; 159 163 }; 160 164 165 166 // Helper class for ScaledMatrixProduct 167 template<class T> 168 struct MatrixTrait 169 { 170 CBLAS_TRANSPOSE_t transpose_type(void) const { return CblasNoTrans; } 171 const gsl_matrix* get_gsl_matrix_p(const T& m) const 172 { return m.gsl_matrix_p(); } 173 }; 174 175 // Specialization for TransposedMatrix 176 template<class Base> 177 struct MatrixTrait<BasicMatrix<MatrixExpression<TransposedMatrix<Base> > > > 178 { 179 // returned the base type's transpose_type, flipped. 180 CBLAS_TRANSPOSE_t transpose_type(void) const 181 { 182 if (base_trait_.transpose_type() == CblasTrans) 183 return CblasNoTrans; 184 return CblasTrans; 185 } 186 187 188 const gsl_matrix* 189 get_gsl_matrix_p(const BasicMatrix<MatrixExpression<TransposedMatrix<Base> > >& m) const 190 { 191 const TransposedMatrix<Base>* me = 192 static_cast<const TransposedMatrix<Base>*>(&m); 193 return base_trait_.get_gsl_matrix_p(me->base()); 194 } 195 196 197 private: 198 MatrixTrait<BasicMatrix<Base> > base_trait_; 199 }; 200 201 202 template<class DerivedA, class DerivedB> 203 class ScaledMatrixProduct : 204 public MatrixExpression<ScaledMatrixProduct<DerivedA, DerivedB> > 205 { 206 public: 207 ScaledMatrixProduct(double alpha, 208 const BasicMatrix<DerivedA>& A, 209 const BasicMatrix<DerivedB>& B) 210 : alpha_(alpha), A_(A), B_(B) {} 211 212 size_t rows(void) const { return A_.rows(); } 213 size_t columns(void) const { return B_.columns(); } 214 double operator()(size_t row, size_t column) const 215 { return gsl_matrix_get(this->gsl_matrix_p(), row, column); } 216 217 void calculate_matrix(gsl_matrix*& result, double beta=0.0) const 218 { 219 detail::reallocate(result, this->rows(), this->columns()); 220 YAT_ASSERT(detail::rows(result) == this->rows()); 221 YAT_ASSERT(detail::columns(result) == this->columns()); 222 YAT_ASSERT(A_.columns() == B_.rows()); 223 MatrixTrait<BasicMatrix<DerivedA> > traitA; 224 MatrixTrait<BasicMatrix<DerivedB> > traitB; 225 traitA.get_gsl_matrix_p(A_); 226 traitB.get_gsl_matrix_p(B_); 227 gsl_blas_dgemm(traitA.transpose_type(), traitB.transpose_type(), 228 alpha_, 229 traitA.get_gsl_matrix_p(A_), 230 traitB.get_gsl_matrix_p(B_), 231 beta, result); 232 YAT_ASSERT(result); 233 } 234 private: 235 236 double alpha_; 237 const BasicMatrix<DerivedA>& A_; 238 const BasicMatrix<DerivedB>& B_; 239 }; 161 240 162 241 } // end namespace expression … … 206 285 207 286 287 /// \cond IGNORE_DOXYGEN 288 289 namespace detail { 290 291 // Helper class used in MatrixExpression * MatrixExpression 292 // The class factorises a matrix expression into 293 // scalar * matrix expression 294 // For the default case there is nothing to factorise, so scalar 295 // is 1.0 and matrix expression is the passed matrix expression. 296 template<class T> 297 struct UnscaleTrait 298 { 299 UnscaleTrait(const T& matrix) : matrix_(matrix) {} 300 typedef typename T::derived_type result_arg; 301 double factor(void) const { return 1.0; } 302 const T& matrix(void) const { return matrix_; } 303 private: 304 const T& matrix_; 305 }; 306 307 // Specialization for ScaledMatrix for which we factor out the 308 // scalar and the matrix. 309 template<class Base> 310 struct UnscaleTrait<BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > > > 311 { 312 UnscaleTrait(const BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > >& m) 313 : matrix_(m), 314 scaled_matrix_(static_cast<const expression::ScaledMatrix<Base>&>(m)) 315 {} 316 317 318 typedef Base result_arg; 319 320 double factor(void) const 321 { 322 return scaled_matrix_.factor(); 323 } 324 325 326 const Base& matrix(void) const 327 { 328 return scaled_matrix_.base(); 329 } 330 private: 331 const BasicMatrix<MatrixExpression<expression::ScaledMatrix<Base> > >& 332 matrix_; 333 const expression::ScaledMatrix<Base>& scaled_matrix_; 334 }; 335 336 } // end namespace detail 337 /// \endcond 338 208 339 /** 209 340 \brief Matrix multiplication operator … … 216 347 */ 217 348 template<class Derived1, class Derived2> 218 expression::MatrixBinary<Derived1, Derived2, expression::Multiplies> 349 expression::ScaledMatrixProduct< 350 typename detail::UnscaleTrait<BasicMatrix<Derived1> >::result_arg, 351 typename detail::UnscaleTrait<BasicMatrix<Derived2> >::result_arg 352 > 219 353 operator*(const BasicMatrix<Derived1>& lhs, const BasicMatrix<Derived2>& rhs) 220 354 { 221 355 YAT_ASSERT(lhs.columns() == rhs.rows()); 222 return expression::MatrixBinary<Derived1, Derived2, 223 expression::Multiplies>(lhs, rhs); 224 } 225 356 detail::UnscaleTrait<BasicMatrix<Derived1> > unscale1(lhs); 357 detail::UnscaleTrait<BasicMatrix<Derived2> > unscale2(rhs); 358 359 return expression::ScaledMatrixProduct< 360 typename detail::UnscaleTrait<BasicMatrix<Derived1> >::result_arg, 361 typename detail::UnscaleTrait<BasicMatrix<Derived2> >::result_arg 362 >(unscale1.factor() * unscale2.factor(), 363 unscale1.matrix(), unscale2.matrix()); 364 } 226 365 227 366 /** … … 277 416 278 417 /** 418 Specialization for ScaledMatrix 419 420 \since New in yat 0.18 421 */ 422 template<class T> 423 expression::ScaledMatrix<T> 424 operator*(const BasicMatrix<MatrixExpression<expression::ScaledMatrix<T> > >& A, 425 double k) 426 { 427 const expression::ScaledMatrix<T>& sm = 428 static_cast<const expression::ScaledMatrix<T>&>(A); 429 return expression::ScaledMatrix<T>(k*sm.factor(), sm.base()); 430 } 431 432 433 /** 434 Specialization for ScaledMatrix 435 436 \since New in yat 0.18 437 */ 438 template<class T> 439 expression::ScaledMatrix<T> 440 operator*(double k, 441 const BasicMatrix<MatrixExpression<expression::ScaledMatrix<T> > >& A) 442 { 443 const expression::ScaledMatrix<T>& sm = 444 static_cast<const expression::ScaledMatrix<T>&>(A); 445 return expression::ScaledMatrix<T>(k*sm.factor(), sm.base()); 446 } 447 448 449 /** 450 Specialization for ScaledMatrix 451 452 \since New in yat 0.18 453 */ 454 template<class T> 455 expression::ScaledMatrix<T> 456 operator-(const BasicMatrix<MatrixExpression<expression::ScaledMatrix<T> > >& A) 457 { 458 const expression::ScaledMatrix<T>& sm = 459 static_cast<const expression::ScaledMatrix<T>&>(A); 460 return expression::ScaledMatrix<T>(-sm.factor(), sm.base()); 461 } 462 463 464 /** 279 465 \brief transpose function 280 466
Note: See TracChangeset
for help on using the changeset viewer.