Changeset 4105


Ignore:
Timestamp:
Sep 24, 2021, 6:55:48 AM (2 years ago)
Author:
Peter
Message:

implement +=, -=, and *= operators for DiagonalMatrix?; prefer using arithmatics provided by Vector class rather than own loops.

Location:
trunk
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/diagonal_matrix.cc

    r3660 r4105  
    22
    33/*
    4   Copyright (C) 2017 Peter Johansson
     4  Copyright (C) 2017, 2021 Peter Johansson
    55
    66  This file is part of the yat library, http://dev.thep.lu.se/yat
     
    118118  suite.out() << "DiagonalMatrix * DiagnoalMatrix\n";
    119119  check(D1 * D1, D2 * D2, suite);
    120 
     120  // test different dimension cases square x square; square x portrait etc
     121  for (size_t i=1; i<4; ++i)
     122    for (size_t j=1; j<4; ++j) {
     123      DiagonalMatrix X1(i,j);
     124      Matrix M1(i,j);
     125      for (size_t k1=0; k1<X1.rows() && k1<X1.columns(); ++k1) {
     126        X1(k1) = 10+k1;
     127        M1(k1, k1) = X1(k1, k1);
     128      }
     129      for (size_t k=1; k<4; ++k) {
     130        DiagonalMatrix X2(j,k);
     131        Matrix M2(j,k);
     132        for (size_t k2=0; k2<X2.rows() && k2<X2.columns(); ++k2) {
     133          X2(k2) = 2+k2;
     134          M2(k2, k2) = X2(k2, k2);
     135        }
     136        check(X1*X2, M1*M2, suite);
     137      }
     138    }
    121139  suite.out() << "DiagonalMatrix * Matrix\n";
    122140  check(D1 * B, D2 * B, suite);
  • trunk/yat/utility/DiagonalMatrix.cc

    r3655 r4105  
    22
    33/*
    4   Copyright (C) 2017 Peter Johansson
     4  Copyright (C) 2017, 2021 Peter Johansson
    55
    66  This file is part of the yat library, http://dev.thep.lu.se/yat
     
    2626#include "Matrix.h"
    2727#include "VectorBase.h"
     28#include "VectorConstView.h"
     29#include "VectorView.h"
    2830
    2931#include <algorithm>
     
    9092
    9193
     94  DiagonalMatrix& DiagonalMatrix::operator*=(const DiagonalMatrix& rhs)
     95  {
     96    assert(columns() == rhs.rows());
     97    // length of new diagonal
     98    size_t n = std::min(row_, rhs.col_);
     99    size_t n1 = data_.size();
     100    size_t n2 = rhs.data_.size();
     101    assert(n==n1 || (n>n1 && n>n2) || (n<n1 && n==n2));
     102
     103    if (n == n1) {
     104      assert(n <= n2);
     105      data_.mul(VectorConstView(rhs.data_, 0, n));
     106    }
     107    else if (n>n1) {
     108      assert(n>n2);
     109      Vector tmp(n, 0);
     110      if (n1 <= n2) {
     111        VectorView view(tmp, 0, n1);
     112        view = data_;
     113        view.mul(VectorConstView(rhs.data_, 0, n1));
     114      }
     115      else {
     116        VectorView view(tmp, 0, n2);
     117        view = VectorConstView(data_, 0, n2);
     118        view.mul(rhs.data_);
     119      }
     120      data_ = std::move(tmp);
     121    }
     122    else {
     123      assert(n < n1);
     124      assert(n == n2);
     125      Vector tmp(rhs.data_);
     126      tmp.mul(VectorConstView(data_, 0, n));
     127      data_ = std::move(tmp);
     128    }
     129    col_ = rhs.col_;
     130    assert(data_.size() == n);
     131    return *this;
     132  }
     133
     134
    92135  DiagonalMatrix operator*(const DiagonalMatrix& lhs, const DiagonalMatrix& rhs)
    93136  {
    94137    assert(lhs.columns() == rhs.rows());
    95     DiagonalMatrix res(lhs.rows(), rhs.columns());
    96     size_t n = std::min(res.rows(), res.columns());
    97     for (size_t i=0; i<n; ++i)
    98       res(i) = lhs(i,i) * rhs(i,i);
     138    DiagonalMatrix res(lhs);
     139    res *= rhs;
    99140    return res;
     141  }
     142
     143
     144  DiagonalMatrix& DiagonalMatrix::operator+=(const DiagonalMatrix& rhs)
     145  {
     146    assert(rows() == rhs.rows());
     147    assert(columns() == rhs.columns());
     148    assert(data_.size() == rhs.data_.size());
     149    data_ += rhs.data_;
     150    return *this;
    100151  }
    101152
     
    107158    assert(lhs.columns() == rhs.columns());
    108159    DiagonalMatrix res(lhs);
    109     size_t n = std::min(res.rows(), res.columns());
    110     for (size_t i=0; i<n; ++i)
    111       res(i) += rhs(i, i);
     160    res += rhs;
    112161    return res;
     162  }
     163
     164
     165  DiagonalMatrix& DiagonalMatrix::operator-=(const DiagonalMatrix& rhs)
     166  {
     167    assert(rows() == rhs.rows());
     168    assert(columns() == rhs.columns());
     169    assert(data_.size() == rhs.data_.size());
     170    data_ -= rhs.data_;
     171    return *this;
    113172  }
    114173
     
    120179    assert(lhs.columns() == rhs.columns());
    121180    DiagonalMatrix res(lhs);
    122     size_t n = std::min(res.rows(), res.columns());
    123     for (size_t i=0; i<n; ++i)
    124       res(i) -= rhs(i, i);
     181    res -= rhs;
    125182    return res;
    126183  }
  • trunk/yat/utility/DiagonalMatrix.h

    r3655 r4105  
    55
    66/*
    7   Copyright (C) 2017 Peter Johansson
     7  Copyright (C) 2017, 2021 Peter Johansson
    88
    99  This file is part of the yat library, http://dev.thep.lu.se/yat
     
    9292     */
    9393    double& operator()(size_t i);
     94
     95    /**
     96       \brief Multiplication and assign operator
     97
     98       Same as doing *this = *this * rhs.
     99
     100       \since New in yat 0.20
     101     */
     102    DiagonalMatrix& operator*=(const DiagonalMatrix& rhs);
     103
     104    /**
     105       \brief Add and assign operator
     106
     107       Elementwise addition of the diagnonal elements.
     108
     109       \since New in yat 0.20
     110     */
     111    DiagonalMatrix& operator+=(const DiagonalMatrix& rhs);
     112
     113    /**
     114       \brief Subtract and assign operator
     115
     116       Elementwise addition of the diagnonal elements.
     117
     118       \since New in yat 0.20
     119     */
     120    DiagonalMatrix& operator-=(const DiagonalMatrix& rhs);
    94121  private:
    95122    Vector data_;
Note: See TracChangeset for help on using the changeset viewer.