source: trunk/yat/utility/matrix.cc @ 1028

Last change on this file since 1028 was 1028, checked in by Peter, 14 years ago

documentation for VectorConstView? and changing name of view functions in matrix

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.1 KB
Line 
1// $Id: matrix.cc 1028 2008-02-03 01:53:29Z peter $
2
3/*
4  Copyright (C) 2003 Daniel Dalevi, Peter Johansson
5  Copyright (C) 2004 Jari Häkkinen, Peter Johansson
6  Copyright (C) 2005, 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
7  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "matrix.h"
28#include "vector.h"
29#include "VectorBase.h"
30#include "VectorConstView.h"
31#include "VectorView.h"
32#include "utility.h"
33
34#include <cassert>
35#include <cmath>
36#include <sstream>
37#include <vector>
38
39#include <gsl/gsl_blas.h>
40
41namespace theplu {
42namespace yat {
43namespace utility {
44
45
46  matrix::matrix(void)
47    : blas_result_(NULL), m_(NULL), view_(NULL), view_const_(NULL),
48      proxy_m_(NULL)
49  {
50  }
51
52
53  matrix::matrix(const size_t& r, const size_t& c, double init_value)
54    : blas_result_(NULL), m_(gsl_matrix_alloc(r,c)), view_(NULL),
55      view_const_(NULL), proxy_m_(m_)
56  {
57    if (!m_)
58      throw utility::GSL_error("matrix::matrix failed to allocate memory");
59    all(init_value);
60  }
61
62
63  matrix::matrix(const matrix& o)
64    : blas_result_(NULL), m_(o.create_gsl_matrix_copy()), view_(NULL),
65      view_const_(NULL), proxy_m_(m_)
66  {
67  }
68
69
70  matrix::matrix(matrix& m, size_t offset_row, size_t offset_column,
71                 size_t n_row, size_t n_column)
72    : blas_result_(NULL), view_const_(NULL)
73  {
74    view_ = new gsl_matrix_view(gsl_matrix_submatrix(m.m_,
75                                                     offset_row,offset_column,
76                                                     n_row,n_column));
77    if (!view_)
78      throw utility::GSL_error("matrix::matrix failed to setup view");
79    proxy_m_ = m_ = &(view_->matrix);
80  }
81
82
83  // Constructor that gets data from istream
84  matrix::matrix(std::istream& is, char sep) 
85    throw (utility::IO_error,std::exception)
86    : blas_result_(NULL), view_(NULL), view_const_(NULL)
87  {
88    // read the data file and store in stl vectors (dynamically
89    // expandable)
90    std::vector<std::vector<double> > data_matrix;
91    u_int nof_columns=0;
92    u_int nof_rows = 0;
93    std::string line;
94    while(getline(is, line, '\n')){
95      // Ignoring empty lines
96      if (!line.size()) {
97        continue;
98      }
99      nof_rows++;
100      std::vector<double> v;
101      std::string element;
102      std::stringstream ss(line);
103     
104      bool ok=true;
105      while(ok) {
106        if(sep=='\0')
107          ok=(ss>>element);
108        else
109          ok=getline(ss, element, sep);
110        if(!ok)
111          break;
112       
113        if(utility::is_double(element)) {
114          v.push_back(atof(element.c_str()));
115        }
116        else if (!element.size() || utility::is_nan(element)) {
117          v.push_back(std::numeric_limits<double>::quiet_NaN());
118        }
119        else {
120          std::stringstream ss("Warning: '");
121          ss << element << "' is not accepted as a matrix element.";
122          throw IO_error(ss.str());
123        }
124      }           
125      if(sep!='\0' && line[line.size()-1]==sep) // add NaN for final separator
126          v.push_back(std::numeric_limits<double>::quiet_NaN());
127      if (!nof_columns)
128        nof_columns=v.size();
129      else if (v.size()!=nof_columns) {
130        std::ostringstream s;
131        s << "matrix::matrix(std::istream&, char) data file error: "
132          << "line " << nof_rows << " has " << v.size()
133          << " columns; expected " << nof_columns << " columns.";
134        throw utility::IO_error(s.str());
135      }
136      data_matrix.push_back(v);
137    }
138
139    // manipulate the state of the stream to be good
140    is.clear(std::ios::goodbit);
141    // convert the data to a gsl matrix
142    proxy_m_ = m_ = gsl_matrix_alloc ( nof_rows, nof_columns );
143    if (!m_)
144      throw utility::GSL_error("matrix::matrix failed to allocate memory");
145
146    // if gsl error handler disabled, out of bounds index will not
147    // abort the program.
148    for(u_int i=0;i<nof_rows;i++)
149      for(u_int j=0;j<nof_columns;j++)
150        gsl_matrix_set( m_, i, j, data_matrix[i][j] );
151  }
152
153
154  matrix::~matrix(void)
155  {
156    delete_allocated_memory();
157    if (blas_result_)
158      gsl_matrix_free(blas_result_);
159  }
160
161
162  const matrix& matrix::clone(const matrix& other)
163  {
164    if (this!=&other) {
165
166      delete_allocated_memory();
167
168      if (other.view_) {
169        view_ = new gsl_matrix_view(*other.view_);
170        proxy_m_ = m_ = &(view_->matrix);
171      }
172      else if (other.view_const_) {
173        view_const_ = new gsl_matrix_const_view(*other.view_const_);
174        proxy_m_ = &(view_const_->matrix);
175      }
176      else if (other.m_)
177        proxy_m_ = m_ = other.create_gsl_matrix_copy();
178
179      // no need to delete blas_result_ if the number of rows fit, it
180      // may be useful later.
181      if (blas_result_ && (blas_result_->size1!=rows())) {
182        gsl_matrix_free(blas_result_);
183        blas_result_=NULL;
184      }
185    }
186    return *this;
187  } 
188
189
190  size_t matrix::columns(void) const
191  {
192    if (!proxy_m_)
193      return 0;
194    return proxy_m_->size2;
195  }
196
197
198  gsl_matrix* matrix::create_gsl_matrix_copy(void) const
199  {
200    gsl_matrix* m = gsl_matrix_alloc(rows(),columns());
201    if (!m)
202      throw utility::GSL_error("matrix::create_gsl_matrix_copy failed to allocate memory");
203    if (gsl_matrix_memcpy(m,proxy_m_))
204      throw utility::GSL_error("matrix::create_gsl_matrix_copy dimension mis-match");
205    return m;
206  }
207
208
209  void matrix::delete_allocated_memory(void)
210  {
211    if (view_)
212      delete view_;
213    else if (view_const_)
214      delete view_const_;
215    else if (m_)
216      gsl_matrix_free(m_);
217    blas_result_=NULL;
218    proxy_m_=m_=NULL;
219  }
220
221
222  void matrix::div(const matrix& other)
223  {
224    assert(m_);
225    int status=gsl_matrix_div_elements(m_, other.gsl_matrix_p());
226    if (status)
227      throw utility::GSL_error(std::string("matrix::div_elements",status));
228  }
229
230
231  bool matrix::equal(const matrix& other, const double d) const
232  {
233    if (this==&other)
234      return true;
235    if (columns()!=other.columns() || rows()!=other.rows())
236      return false;
237    for (size_t i=0; i<rows(); i++)
238      for (size_t j=0; j<columns(); j++)
239        // The two last condition checks are needed for NaN detection
240        if (fabs( (*this)(i,j)-other(i,j) ) > d ||
241            (*this)(i,j)!=(*this)(i,j) || other(i,j)!=other(i,j))
242          return false;
243    return true;
244  }
245
246
247  const gsl_matrix* matrix::gsl_matrix_p(void) const
248  {
249    return proxy_m_;
250  }
251
252
253  gsl_matrix* matrix::gsl_matrix_p(void)
254  {
255    return m_;
256  }
257
258
259  bool matrix::isview(void) const
260  {
261    return view_ || view_const_;
262  }
263
264
265  void matrix::mul(const matrix& other)
266  {
267    assert(m_);
268    int status=gsl_matrix_mul_elements(m_, other.gsl_matrix_p());
269    if (status)
270      throw utility::GSL_error(std::string("matrix::mul_elements",status));
271  }
272
273
274  void matrix::resize(size_t r, size_t c, double init_value)
275  {
276    delete_allocated_memory();
277
278    proxy_m_ = m_ = gsl_matrix_alloc(r,c);
279    if (!m_)
280      throw utility::GSL_error("matrix::matrix failed to allocate memory");
281    all(init_value);
282
283    // no need to delete blas_result_ if the number of rows fit, it
284    // may be useful later.
285    if (blas_result_ && (blas_result_->size1!=rows())) {
286      gsl_matrix_free(blas_result_);
287      blas_result_=NULL;
288    }
289  }
290
291
292  size_t matrix::rows(void) const
293  {
294    if (!proxy_m_)
295      return 0;
296    return proxy_m_->size1;
297  }
298
299
300  void matrix::all(const double value)
301  {
302    assert(m_);
303    gsl_matrix_set_all(m_, value);
304  }
305
306
307  VectorView matrix::column_view(size_t col)
308  {
309    VectorView res(*this, col, false);
310    return res;
311  }
312
313
314  const VectorConstView matrix::column_const_view(size_t col) const
315  {
316    return VectorConstView(*this, col, false);
317  }
318
319
320  const VectorConstView matrix::row_const_view(size_t col) const
321  {
322    return VectorConstView(*this, col, true);
323  }
324
325
326  VectorView matrix::row_view(size_t row)
327  {
328    VectorView res(*this, row, true);
329    return res;
330  }
331
332
333  void matrix::swap_columns(const size_t i, const size_t j)
334  {
335    assert(m_);
336    int status=gsl_matrix_swap_columns(m_, i, j);
337    if (status)
338      throw utility::GSL_error(std::string("matrix::swap_columns",status));
339  }
340
341
342  void matrix::swap_rowcol(const size_t i, const size_t j)
343  {
344    assert(m_);
345    int status=gsl_matrix_swap_rowcol(m_, i, j);
346    if (status)
347      throw utility::GSL_error(std::string("matrix::swap_rowcol",status));
348  }
349
350
351  void matrix::swap_rows(const size_t i, const size_t j)
352  {
353    assert(m_);
354    int status=gsl_matrix_swap_rows(m_, i, j);
355    if (status)
356      throw utility::GSL_error(std::string("matrix::swap_rows",status));
357  }
358
359
360  void matrix::transpose(void)
361  {
362    assert(m_);
363    if (columns()==rows())
364      gsl_matrix_transpose(m_); // this never fails
365    else {
366      gsl_matrix* transposed = gsl_matrix_alloc(columns(),rows());
367      if (!transposed)
368        throw utility::GSL_error("matrix::transpose failed to allocate memory");
369      // next line never fails if allocation above succeeded.
370      gsl_matrix_transpose_memcpy(transposed,m_);
371      gsl_matrix_free(m_);
372      proxy_m_ = m_ = transposed;
373      if (blas_result_) {
374        gsl_matrix_free(blas_result_);
375        blas_result_=NULL;
376      }
377    }
378  }
379
380
381  double& matrix::operator()(size_t row, size_t column)
382  {
383    assert(m_);
384    assert(row<rows());
385    assert(column<columns());
386    double* d=gsl_matrix_ptr(m_, row, column);
387    if (!d)
388      throw utility::GSL_error("matrix::operator()",GSL_EINVAL);
389    return *d;
390  }
391
392
393  const double& matrix::operator()(size_t row, size_t column) const
394  {
395    assert(row<rows());
396    assert(column<columns());
397    const double* d=gsl_matrix_const_ptr(proxy_m_, row, column);
398    if (!d)
399      throw utility::GSL_error("matrix::operator()",GSL_EINVAL);
400    return *d;
401  }
402
403
404  bool matrix::operator==(const matrix& other) const
405  {
406    return equal(other);
407  }
408
409
410  bool matrix::operator!=(const matrix& other) const
411  {
412    return !equal(other);
413  }
414
415
416  const matrix& matrix::operator=( const matrix& other )
417  {
418    assert(m_);
419    if (this!=&other)
420      if (gsl_matrix_memcpy(m_, other.gsl_matrix_p()))
421        throw utility::GSL_error("matrix::create_gsl_matrix_copy dimension mis-match");
422    return *this;
423  }
424
425
426  const matrix& matrix::operator+=(const matrix& other)
427  {
428    assert(m_);
429    int status=gsl_matrix_add(m_, other.proxy_m_);
430    if (status)
431      throw utility::GSL_error(std::string("matrix::operator+=", status));
432    return *this;
433  }
434
435
436  const matrix& matrix::operator+=(const double d)
437  {
438    assert(m_);
439    gsl_matrix_add_constant(m_, d);
440    return *this;
441  }
442
443
444  const matrix& matrix::operator-=(const matrix& other)
445  {
446    assert(m_);
447    int status=gsl_matrix_sub(m_, other.proxy_m_);
448    if (status)
449      throw utility::GSL_error(std::string("matrix::operator-=", status));
450    return *this;
451  }
452
453
454  const matrix& matrix::operator-=(const double d)
455  {
456    assert(m_);
457    gsl_matrix_add_constant(m_, -d);
458    return *this;
459  }
460
461
462  const matrix& matrix::operator*=(const matrix& other)
463  {
464    assert(m_);
465    if ( blas_result_ && ((blas_result_->size1!=rows()) ||
466                          (blas_result_->size2!=other.columns())) ) {
467      gsl_matrix_free(blas_result_);
468      blas_result_=NULL;
469    }
470    if (!blas_result_) {
471      blas_result_ = gsl_matrix_alloc(rows(),other.columns());
472      if (!blas_result_)
473        throw utility::GSL_error("matrix::operator*= failed to allocate memory");
474    }
475    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, m_, other.proxy_m_, 0.0,
476                   blas_result_);
477    gsl_matrix* tmp=m_;
478    proxy_m_ = m_ = blas_result_;
479    blas_result_=tmp;
480    return *this;
481  }
482
483
484  const matrix& matrix::operator*=(const double d)
485  {
486    assert(m_);
487    gsl_matrix_scale(m_, d);
488    return *this;
489  }
490
491
492  bool isnull(const matrix& other)
493  {
494    return gsl_matrix_isnull(other.gsl_matrix_p());
495  }
496
497
498  double max(const matrix& other)
499  {
500    return gsl_matrix_max(other.gsl_matrix_p());
501  }
502
503
504  double min(const matrix& other)
505  {
506    return gsl_matrix_min(other.gsl_matrix_p());
507  }
508
509
510  void minmax_index(const matrix& other,
511                    std::pair<size_t,size_t>& min, std::pair<size_t,size_t>& max)
512  {
513    gsl_matrix_minmax_index(other.gsl_matrix_p(), &min.first, &min.second,
514                            &max.first, &max.second);
515  }
516
517
518  bool nan(const matrix& templat, matrix& flag)
519  {
520    size_t rows=templat.rows();
521    size_t columns=templat.columns();
522    if (rows!=flag.rows() && columns!=flag.columns())
523      flag.clone(matrix(rows,columns,1.0));
524    else
525      flag.all(1.0);
526    bool nan=false;
527    for (size_t i=0; i<rows; i++)
528      for (size_t j=0; j<columns; j++) 
529        if (std::isnan(templat(i,j))) {
530          flag(i,j)=0;
531          nan=true;
532        }
533    return nan;
534  }
535
536
537  void swap(matrix& a, matrix& b)
538  {
539    assert(a.gsl_matrix_p()); assert(b.gsl_matrix_p());
540    int status=gsl_matrix_swap(a.gsl_matrix_p(), b.gsl_matrix_p());
541    if (status)
542      throw utility::GSL_error(std::string("swap(matrix&,matrix&)",status));
543  }
544
545
546  std::ostream& operator<<(std::ostream& s, const matrix& m)
547  {
548    s.setf(std::ios::dec);
549    s.precision(12);
550    for(size_t i=0, j=0; i<m.rows(); i++)
551      for (j=0; j<m.columns(); j++) {
552        s << m(i,j);
553        if (j<m.columns()-1)
554          s << s.fill();
555        else if (i<m.rows()-1)
556          s << "\n";
557      }
558    return s;
559  }
560
561
562  vector operator*(const matrix& m, const VectorBase& v)
563  {
564    utility::vector res(m.rows());
565    for (size_t i=0; i<res.size(); ++i)
566      res(i) = VectorConstView(m,i) * v;
567    return res;
568  }
569
570
571  vector operator*(const VectorBase& v, const matrix& m)
572  {
573    utility::vector res(m.columns());
574    for (size_t i=0; i<res.size(); ++i)
575      res(i) = v * VectorConstView(m,i,false);
576    return res;
577  }
578
579}}} // of namespace utility, yat and thep
Note: See TracBrowser for help on using the repository browser.