source: trunk/yat/utility/matrix.cc @ 1017

Last change on this file since 1017 was 1017, checked in by Peter, 15 years ago

passing VectorBase? in matrix - refs #256

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.1 KB
Line 
1// $Id: matrix.cc 1017 2008-02-01 16:53:04Z peter $
2
3/*
4  Copyright (C) 2003 Daniel Dalevi, Peter Johansson
5  Copyright (C) 2004 Jari Häkkinen, Peter Johansson
6  Copyright (C) 2005, 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
7  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "matrix.h"
28#include "vector.h"
29#include "VectorBase.h"
30#include "VectorView.h"
31#include "utility.h"
32
33#include <cassert>
34#include <cmath>
35#include <sstream>
36#include <vector>
37
38#include <gsl/gsl_blas.h>
39
40namespace theplu {
41namespace yat {
42namespace utility {
43
44
45  matrix::matrix(void)
46    : blas_result_(NULL), m_(NULL), view_(NULL), view_const_(NULL),
47      proxy_m_(NULL)
48  {
49  }
50
51
52  matrix::matrix(const size_t& r, const size_t& c, double init_value)
53    : blas_result_(NULL), m_(gsl_matrix_alloc(r,c)), view_(NULL),
54      view_const_(NULL), proxy_m_(m_)
55  {
56    if (!m_)
57      throw utility::GSL_error("matrix::matrix failed to allocate memory");
58    all(init_value);
59  }
60
61
62  matrix::matrix(const matrix& o)
63    : blas_result_(NULL), m_(o.create_gsl_matrix_copy()), view_(NULL),
64      view_const_(NULL), proxy_m_(m_)
65  {
66  }
67
68
69  matrix::matrix(matrix& m, size_t offset_row, size_t offset_column,
70                 size_t n_row, size_t n_column)
71    : blas_result_(NULL), view_const_(NULL)
72  {
73    view_ = new gsl_matrix_view(gsl_matrix_submatrix(m.m_,
74                                                     offset_row,offset_column,
75                                                     n_row,n_column));
76    if (!view_)
77      throw utility::GSL_error("matrix::matrix failed to setup view");
78    proxy_m_ = m_ = &(view_->matrix);
79  }
80
81
82  // Constructor that gets data from istream
83  matrix::matrix(std::istream& is, char sep) 
84    throw (utility::IO_error,std::exception)
85    : blas_result_(NULL), view_(NULL), view_const_(NULL)
86  {
87    // read the data file and store in stl vectors (dynamically
88    // expandable)
89    std::vector<std::vector<double> > data_matrix;
90    u_int nof_columns=0;
91    u_int nof_rows = 0;
92    std::string line;
93    while(getline(is, line, '\n')){
94      // Ignoring empty lines
95      if (!line.size()) {
96        continue;
97      }
98      nof_rows++;
99      std::vector<double> v;
100      std::string element;
101      std::stringstream ss(line);
102     
103      bool ok=true;
104      while(ok) {
105        if(sep=='\0')
106          ok=(ss>>element);
107        else
108          ok=getline(ss, element, sep);
109        if(!ok)
110          break;
111       
112        if(utility::is_double(element)) {
113          v.push_back(atof(element.c_str()));
114        }
115        else if (!element.size() || utility::is_nan(element)) {
116          v.push_back(std::numeric_limits<double>::quiet_NaN());
117        }
118        else {
119          std::stringstream ss("Warning: '");
120          ss << element << "' is not accepted as a matrix element.";
121          throw IO_error(ss.str());
122        }
123      }           
124      if(sep!='\0' && line[line.size()-1]==sep) // add NaN for final separator
125          v.push_back(std::numeric_limits<double>::quiet_NaN());
126      if (!nof_columns)
127        nof_columns=v.size();
128      else if (v.size()!=nof_columns) {
129        std::ostringstream s;
130        s << "matrix::matrix(std::istream&, char) data file error: "
131          << "line " << nof_rows << " has " << v.size()
132          << " columns; expected " << nof_columns << " columns.";
133        throw utility::IO_error(s.str());
134      }
135      data_matrix.push_back(v);
136    }
137
138    // manipulate the state of the stream to be good
139    is.clear(std::ios::goodbit);
140    // convert the data to a gsl matrix
141    proxy_m_ = m_ = gsl_matrix_alloc ( nof_rows, nof_columns );
142    if (!m_)
143      throw utility::GSL_error("matrix::matrix failed to allocate memory");
144
145    // if gsl error handler disabled, out of bounds index will not
146    // abort the program.
147    for(u_int i=0;i<nof_rows;i++)
148      for(u_int j=0;j<nof_columns;j++)
149        gsl_matrix_set( m_, i, j, data_matrix[i][j] );
150  }
151
152
153  matrix::~matrix(void)
154  {
155    delete_allocated_memory();
156    if (blas_result_)
157      gsl_matrix_free(blas_result_);
158  }
159
160
161  const matrix& matrix::clone(const matrix& other)
162  {
163    if (this!=&other) {
164
165      delete_allocated_memory();
166
167      if (other.view_) {
168        view_ = new gsl_matrix_view(*other.view_);
169        proxy_m_ = m_ = &(view_->matrix);
170      }
171      else if (other.view_const_) {
172        view_const_ = new gsl_matrix_const_view(*other.view_const_);
173        proxy_m_ = &(view_const_->matrix);
174      }
175      else if (other.m_)
176        proxy_m_ = m_ = other.create_gsl_matrix_copy();
177
178      // no need to delete blas_result_ if the number of rows fit, it
179      // may be useful later.
180      if (blas_result_ && (blas_result_->size1!=rows())) {
181        gsl_matrix_free(blas_result_);
182        blas_result_=NULL;
183      }
184    }
185    return *this;
186  } 
187
188
189  size_t matrix::columns(void) const
190  {
191    if (!proxy_m_)
192      return 0;
193    return proxy_m_->size2;
194  }
195
196
197  gsl_matrix* matrix::create_gsl_matrix_copy(void) const
198  {
199    gsl_matrix* m = gsl_matrix_alloc(rows(),columns());
200    if (!m)
201      throw utility::GSL_error("matrix::create_gsl_matrix_copy failed to allocate memory");
202    if (gsl_matrix_memcpy(m,proxy_m_))
203      throw utility::GSL_error("matrix::create_gsl_matrix_copy dimension mis-match");
204    return m;
205  }
206
207
208  void matrix::delete_allocated_memory(void)
209  {
210    if (view_)
211      delete view_;
212    else if (view_const_)
213      delete view_const_;
214    else if (m_)
215      gsl_matrix_free(m_);
216    blas_result_=NULL;
217    proxy_m_=m_=NULL;
218  }
219
220
221  void matrix::div(const matrix& other)
222  {
223    assert(m_);
224    int status=gsl_matrix_div_elements(m_, other.gsl_matrix_p());
225    if (status)
226      throw utility::GSL_error(std::string("matrix::div_elements",status));
227  }
228
229
230  bool matrix::equal(const matrix& other, const double d) const
231  {
232    if (this==&other)
233      return true;
234    if (columns()!=other.columns() || rows()!=other.rows())
235      return false;
236    for (size_t i=0; i<rows(); i++)
237      for (size_t j=0; j<columns(); j++)
238        // The two last condition checks are needed for NaN detection
239        if (fabs( (*this)(i,j)-other(i,j) ) > d ||
240            (*this)(i,j)!=(*this)(i,j) || other(i,j)!=other(i,j))
241          return false;
242    return true;
243  }
244
245
246  const gsl_matrix* matrix::gsl_matrix_p(void) const
247  {
248    return proxy_m_;
249  }
250
251
252  gsl_matrix* matrix::gsl_matrix_p(void)
253  {
254    return m_;
255  }
256
257
258  bool matrix::isview(void) const
259  {
260    return view_ || view_const_;
261  }
262
263
264  void matrix::mul(const matrix& other)
265  {
266    assert(m_);
267    int status=gsl_matrix_mul_elements(m_, other.gsl_matrix_p());
268    if (status)
269      throw utility::GSL_error(std::string("matrix::mul_elements",status));
270  }
271
272
273  void matrix::resize(size_t r, size_t c, double init_value)
274  {
275    delete_allocated_memory();
276
277    proxy_m_ = m_ = gsl_matrix_alloc(r,c);
278    if (!m_)
279      throw utility::GSL_error("matrix::matrix failed to allocate memory");
280    all(init_value);
281
282    // no need to delete blas_result_ if the number of rows fit, it
283    // may be useful later.
284    if (blas_result_ && (blas_result_->size1!=rows())) {
285      gsl_matrix_free(blas_result_);
286      blas_result_=NULL;
287    }
288  }
289
290
291  size_t matrix::rows(void) const
292  {
293    if (!proxy_m_)
294      return 0;
295    return proxy_m_->size1;
296  }
297
298
299  void matrix::all(const double value)
300  {
301    assert(m_);
302    gsl_matrix_set_all(m_, value);
303  }
304
305
306  VectorView matrix::column_vec(size_t col)
307  {
308    VectorView res(*this, col, false);
309    return res;
310  }
311
312
313  const VectorView matrix::column_vec(size_t col) const
314  {
315    return VectorView(*this, col, false);
316  }
317
318
319  const VectorView matrix::row_vec(size_t col) const
320  {
321    return VectorView(*this, col, true);
322  }
323
324
325  VectorView matrix::row_vec(size_t row)
326  {
327    VectorView res(*this, row, true);
328    return res;
329  }
330
331
332  void matrix::swap_columns(const size_t i, const size_t j)
333  {
334    assert(m_);
335    int status=gsl_matrix_swap_columns(m_, i, j);
336    if (status)
337      throw utility::GSL_error(std::string("matrix::swap_columns",status));
338  }
339
340
341  void matrix::swap_rowcol(const size_t i, const size_t j)
342  {
343    assert(m_);
344    int status=gsl_matrix_swap_rowcol(m_, i, j);
345    if (status)
346      throw utility::GSL_error(std::string("matrix::swap_rowcol",status));
347  }
348
349
350  void matrix::swap_rows(const size_t i, const size_t j)
351  {
352    assert(m_);
353    int status=gsl_matrix_swap_rows(m_, i, j);
354    if (status)
355      throw utility::GSL_error(std::string("matrix::swap_rows",status));
356  }
357
358
359  void matrix::transpose(void)
360  {
361    assert(m_);
362    if (columns()==rows())
363      gsl_matrix_transpose(m_); // this never fails
364    else {
365      gsl_matrix* transposed = gsl_matrix_alloc(columns(),rows());
366      if (!transposed)
367        throw utility::GSL_error("matrix::transpose failed to allocate memory");
368      // next line never fails if allocation above succeeded.
369      gsl_matrix_transpose_memcpy(transposed,m_);
370      gsl_matrix_free(m_);
371      proxy_m_ = m_ = transposed;
372      if (blas_result_) {
373        gsl_matrix_free(blas_result_);
374        blas_result_=NULL;
375      }
376    }
377  }
378
379
380  double& matrix::operator()(size_t row, size_t column)
381  {
382    assert(m_);
383    assert(row<rows());
384    assert(column<columns());
385    double* d=gsl_matrix_ptr(m_, row, column);
386    if (!d)
387      throw utility::GSL_error("matrix::operator()",GSL_EINVAL);
388    return *d;
389  }
390
391
392  const double& matrix::operator()(size_t row, size_t column) const
393  {
394    assert(row<rows());
395    assert(column<columns());
396    const double* d=gsl_matrix_const_ptr(proxy_m_, row, column);
397    if (!d)
398      throw utility::GSL_error("matrix::operator()",GSL_EINVAL);
399    return *d;
400  }
401
402
403  bool matrix::operator==(const matrix& other) const
404  {
405    return equal(other);
406  }
407
408
409  bool matrix::operator!=(const matrix& other) const
410  {
411    return !equal(other);
412  }
413
414
415  const matrix& matrix::operator=( const matrix& other )
416  {
417    assert(m_);
418    if (this!=&other)
419      if (gsl_matrix_memcpy(m_, other.gsl_matrix_p()))
420        throw utility::GSL_error("matrix::create_gsl_matrix_copy dimension mis-match");
421    return *this;
422  }
423
424
425  const matrix& matrix::operator+=(const matrix& other)
426  {
427    assert(m_);
428    int status=gsl_matrix_add(m_, other.proxy_m_);
429    if (status)
430      throw utility::GSL_error(std::string("matrix::operator+=", status));
431    return *this;
432  }
433
434
435  const matrix& matrix::operator+=(const double d)
436  {
437    assert(m_);
438    gsl_matrix_add_constant(m_, d);
439    return *this;
440  }
441
442
443  const matrix& matrix::operator-=(const matrix& other)
444  {
445    assert(m_);
446    int status=gsl_matrix_sub(m_, other.proxy_m_);
447    if (status)
448      throw utility::GSL_error(std::string("matrix::operator-=", status));
449    return *this;
450  }
451
452
453  const matrix& matrix::operator-=(const double d)
454  {
455    assert(m_);
456    gsl_matrix_add_constant(m_, -d);
457    return *this;
458  }
459
460
461  const matrix& matrix::operator*=(const matrix& other)
462  {
463    assert(m_);
464    if ( blas_result_ && ((blas_result_->size1!=rows()) ||
465                          (blas_result_->size2!=other.columns())) ) {
466      gsl_matrix_free(blas_result_);
467      blas_result_=NULL;
468    }
469    if (!blas_result_) {
470      blas_result_ = gsl_matrix_alloc(rows(),other.columns());
471      if (!blas_result_)
472        throw utility::GSL_error("matrix::operator*= failed to allocate memory");
473    }
474    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, m_, other.proxy_m_, 0.0,
475                   blas_result_);
476    gsl_matrix* tmp=m_;
477    proxy_m_ = m_ = blas_result_;
478    blas_result_=tmp;
479    return *this;
480  }
481
482
483  const matrix& matrix::operator*=(const double d)
484  {
485    assert(m_);
486    gsl_matrix_scale(m_, d);
487    return *this;
488  }
489
490
491  bool isnull(const matrix& other)
492  {
493    return gsl_matrix_isnull(other.gsl_matrix_p());
494  }
495
496
497  double max(const matrix& other)
498  {
499    return gsl_matrix_max(other.gsl_matrix_p());
500  }
501
502
503  double min(const matrix& other)
504  {
505    return gsl_matrix_min(other.gsl_matrix_p());
506  }
507
508
509  void minmax_index(const matrix& other,
510                    std::pair<size_t,size_t>& min, std::pair<size_t,size_t>& max)
511  {
512    gsl_matrix_minmax_index(other.gsl_matrix_p(), &min.first, &min.second,
513                            &max.first, &max.second);
514  }
515
516
517  bool nan(const matrix& templat, matrix& flag)
518  {
519    size_t rows=templat.rows();
520    size_t columns=templat.columns();
521    if (rows!=flag.rows() && columns!=flag.columns())
522      flag.clone(matrix(rows,columns,1.0));
523    else
524      flag.all(1.0);
525    bool nan=false;
526    for (size_t i=0; i<rows; i++)
527      for (size_t j=0; j<columns; j++) 
528        if (std::isnan(templat(i,j))) {
529          flag(i,j)=0;
530          nan=true;
531        }
532    return nan;
533  }
534
535
536  void swap(matrix& a, matrix& b)
537  {
538    assert(a.gsl_matrix_p()); assert(b.gsl_matrix_p());
539    int status=gsl_matrix_swap(a.gsl_matrix_p(), b.gsl_matrix_p());
540    if (status)
541      throw utility::GSL_error(std::string("swap(matrix&,matrix&)",status));
542  }
543
544
545  std::ostream& operator<<(std::ostream& s, const matrix& m)
546  {
547    s.setf(std::ios::dec);
548    s.precision(12);
549    for(size_t i=0, j=0; i<m.rows(); i++)
550      for (j=0; j<m.columns(); j++) {
551        s << m(i,j);
552        if (j<m.columns()-1)
553          s << s.fill();
554        else if (i<m.rows()-1)
555          s << "\n";
556      }
557    return s;
558  }
559
560
561  vector operator*(const matrix& m, const VectorBase& v)
562  {
563    utility::vector res(m.rows());
564    for (size_t i=0; i<res.size(); ++i)
565      res(i) = VectorView(m,i) * v;
566    return res;
567  }
568
569
570  vector operator*(const VectorBase& v, const matrix& m)
571  {
572    utility::vector res(m.columns());
573    for (size_t i=0; i<res.size(); ++i)
574      res(i) = v * VectorView(m,i,false);
575    return res;
576  }
577
578}}} // of namespace utility, yat and thep
Note: See TracBrowser for help on using the repository browser.