source: trunk/yat/utility/matrix.cc @ 808

Last change on this file since 808 was 808, checked in by Peter, 15 years ago

previous argument was invalid, but here is an implementation. Fixes #205

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.1 KB
Line 
1// $Id: matrix.cc 808 2007-03-15 20:07:01Z peter $
2
3/*
4  Copyright (C) 2003 Daniel Dalevi, Peter Johansson
5  Copyright (C) 2004 Jari Häkkinen, Peter Johansson
6  Copyright (C) 2005, 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
7  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
8
9  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "matrix.h"
28#include "vector.h"
29#include "utility.h"
30
31#include <cassert>
32#include <cmath>
33#include <sstream>
34#include <vector>
35
36#include <gsl/gsl_blas.h>
37
38namespace theplu {
39namespace yat {
40namespace utility {
41
42
43  matrix::matrix(void)
44    : blas_result_(NULL), m_(NULL), view_(NULL), view_const_(NULL),
45      proxy_m_(NULL)
46  {
47  }
48
49
50  matrix::matrix(const size_t& r, const size_t& c, double init_value)
51    : blas_result_(NULL), m_(gsl_matrix_alloc(r,c)), view_(NULL),
52      view_const_(NULL), proxy_m_(m_)
53  {
54    if (!m_)
55      throw utility::GSL_error("matrix::matrix failed to allocate memory");
56    set_all(init_value);
57  }
58
59
60  matrix::matrix(const matrix& o)
61    : blas_result_(NULL), m_(o.create_gsl_matrix_copy()), view_(NULL),
62      view_const_(NULL), proxy_m_(m_)
63  {
64  }
65
66
67  matrix::matrix(matrix& m, size_t offset_row, size_t offset_column,
68                 size_t n_row, size_t n_column)
69    : blas_result_(NULL), view_const_(NULL)
70  {
71    view_ = new gsl_matrix_view(gsl_matrix_submatrix(m.m_,
72                                                     offset_row,offset_column,
73                                                     n_row,n_column));
74    if (!view_)
75      throw utility::GSL_error("matrix::matrix failed to setup view");
76    proxy_m_ = m_ = &(view_->matrix);
77  }
78
79
80  // Constructor that gets data from istream
81  matrix::matrix(std::istream& is, char sep) 
82    throw (utility::IO_error,std::exception)
83    : blas_result_(NULL), view_(NULL), view_const_(NULL)
84  {
85    // read the data file and store in stl vectors (dynamically
86    // expandable)
87    std::vector<std::vector<double> > data_matrix;
88    u_int nof_columns=0;
89    u_int nof_rows = 0;
90    std::string line;
91    while(getline(is, line, '\n')){
92      // Ignoring empty lines
93      if (!line.size()) {
94        continue;
95      }
96      nof_rows++;
97      std::vector<double> v;
98      std::string element;
99      std::stringstream ss(line);
100     
101      bool ok=true;
102      while(ok) {
103        if(sep=='\0')
104          ok=(ss>>element);
105        else
106          ok=getline(ss, element, sep);
107        if(!ok)
108          break;
109       
110        if(utility::is_double(element)) {
111          v.push_back(atof(element.c_str()));
112        }
113        else if (!element.size() || utility::is_nan(element)) {
114          v.push_back(std::numeric_limits<double>::quiet_NaN());
115        }
116        else {
117          std::stringstream ss("Warning: '");
118          ss << element << "' is not accepted as a matrix element.";
119          throw IO_error(ss.str());
120        }
121      }           
122      if(sep!='\0' && line[line.size()-1]==sep) // add NaN for final separator
123          v.push_back(std::numeric_limits<double>::quiet_NaN());
124      if (!nof_columns)
125        nof_columns=v.size();
126      else if (v.size()!=nof_columns) {
127        std::ostringstream s;
128        s << "matrix::matrix(std::istream&, char) data file error: "
129          << "line " << nof_rows << " has " << v.size()
130          << " columns; expected " << nof_columns << " columns.";
131        throw utility::IO_error(s.str());
132      }
133      data_matrix.push_back(v);
134    }
135
136    // manipulate the state of the stream to be good
137    is.clear(std::ios::goodbit);
138    // convert the data to a gsl matrix
139    proxy_m_ = m_ = gsl_matrix_alloc ( nof_rows, nof_columns );
140    if (!m_)
141      throw utility::GSL_error("matrix::matrix failed to allocate memory");
142
143    // if gsl error handler disabled, out of bounds index will not
144    // abort the program.
145    for(u_int i=0;i<nof_rows;i++)
146      for(u_int j=0;j<nof_columns;j++)
147        gsl_matrix_set( m_, i, j, data_matrix[i][j] );
148  }
149
150
151  matrix::~matrix(void)
152  {
153    if (view_)
154      delete view_;
155    else if (view_const_)
156      delete view_const_;
157    else if (m_)
158      gsl_matrix_free(m_);
159    if (blas_result_)
160      gsl_matrix_free(blas_result_);
161    blas_result_=NULL;
162    m_=NULL;
163  }
164
165
166  const matrix& matrix::clone(const matrix& other)
167  {
168    if (this!=&other) {
169
170      if (view_)
171        delete view_;
172      else if (view_const_)
173        delete view_const_;
174      else if (m_)
175        gsl_matrix_free(m_);
176      proxy_m_=m_=NULL;
177
178      if (other.view_) {
179        view_ = new gsl_matrix_view(*other.view_);
180        proxy_m_ = m_ = &(view_->matrix);
181      }
182      else if (other.view_const_) {
183        view_const_ = new gsl_matrix_const_view(*other.view_const_);
184        proxy_m_ = &(view_const_->matrix);
185      }
186      else if (other.m_)
187        proxy_m_ = m_ = other.create_gsl_matrix_copy();
188
189      // no need to delete blas_result_ if the number of rows fit, it
190      // may be useful later.
191      if (blas_result_ && (blas_result_->size1!=rows())) {
192        gsl_matrix_free(blas_result_);
193        blas_result_=NULL;
194      }
195    }
196    return *this;
197  } 
198
199
200  size_t matrix::columns(void) const
201  {
202    if (!proxy_m_)
203      return 0;
204    return proxy_m_->size2;
205  }
206
207
208  gsl_matrix* matrix::create_gsl_matrix_copy(void) const
209  {
210    gsl_matrix* m = gsl_matrix_alloc(rows(),columns());
211    if (!m)
212      throw utility::GSL_error("matrix::create_gsl_matrix_copy failed to allocate memory");
213    if (gsl_matrix_memcpy(m,proxy_m_))
214      throw utility::GSL_error("matrix::create_gsl_matrix_copy dimension mis-match");
215    return m;
216  }
217
218
219  void matrix::div(const matrix& other)
220  {
221    assert(m_);
222    int status=gsl_matrix_div_elements(m_, other.gsl_matrix_p());
223    if (status)
224      throw utility::GSL_error(std::string("matrix::div_elements",status));
225  }
226
227
228  bool matrix::equal(const matrix& other, const double d) const
229  {
230    if (this==&other)
231      return true;
232    if (columns()!=other.columns() || rows()!=other.rows())
233      return false;
234    for (size_t i=0; i<rows(); i++)
235      for (size_t j=0; j<columns(); j++)
236        // The two last condition checks are needed for NaN detection
237        if (fabs( (*this)(i,j)-other(i,j) ) > d ||
238            (*this)(i,j)!=(*this)(i,j) || other(i,j)!=other(i,j))
239          return false;
240    return true;
241  }
242
243
244  const gsl_matrix* matrix::gsl_matrix_p(void) const
245  {
246    return proxy_m_;
247  }
248
249
250  gsl_matrix* matrix::gsl_matrix_p(void)
251  {
252    return m_;
253  }
254
255
256  bool matrix::isview(void) const
257  {
258    return view_ || view_const_;
259  }
260
261
262  void matrix::mul(const matrix& other)
263  {
264    assert(m_);
265    int status=gsl_matrix_mul_elements(m_, other.gsl_matrix_p());
266    if (status)
267      throw utility::GSL_error(std::string("matrix::mul_elements",status));
268  }
269
270
271  void matrix::resize(size_t r, size_t c, double init_value)
272  {
273      if (view_)
274        delete view_;
275      else if (view_const_)
276        delete view_const_;
277      else if (m_)
278        gsl_matrix_free(m_);
279      proxy_m_ = m_ = gsl_matrix_alloc(r,c);
280      if (!m_)
281        throw utility::GSL_error("matrix::matrix failed to allocate memory");
282      set_all(init_value);
283
284      // no need to delete blas_result_ if the number of rows fit, it
285      // may be useful later.
286      if (blas_result_ && (blas_result_->size1!=rows())) {
287        gsl_matrix_free(blas_result_);
288        blas_result_=NULL;
289      }
290  } 
291
292
293  size_t matrix::rows(void) const
294  {
295    if (!proxy_m_)
296      return 0;
297    return proxy_m_->size1;
298  }
299
300
301  void matrix::set(const matrix& other)
302  {
303    assert(m_);
304    if (gsl_matrix_memcpy(m_, other.gsl_matrix_p()))
305      throw utility::GSL_error("matrix::create_gsl_matrix_copy dimension mis-match");
306  }
307
308
309  void matrix::set_all(const double value)
310  {
311    assert(m_);
312    gsl_matrix_set_all(m_, value);
313  }
314
315
316  void matrix::set_column(const size_t column, const vector& vec)
317  {
318    assert(m_);
319    int status=gsl_matrix_set_col(m_, column, vec.gsl_vector_p());
320    if (status)
321      throw utility::GSL_error(std::string("matrix::set_column",status));
322  }
323
324
325  void matrix::set_row(const size_t row, const vector& vec)
326  {
327    assert(m_);
328    int status=gsl_matrix_set_row(m_, row, vec.gsl_vector_p());
329    if (status)
330      throw utility::GSL_error(std::string("matrix::set_row",status));
331  }
332
333
334  void matrix::swap_columns(const size_t i, const size_t j)
335  {
336    assert(m_);
337    int status=gsl_matrix_swap_columns(m_, i, j);
338    if (status)
339      throw utility::GSL_error(std::string("matrix::swap_columns",status));
340  }
341
342
343  void matrix::swap_rowcol(const size_t i, const size_t j)
344  {
345    assert(m_);
346    int status=gsl_matrix_swap_rowcol(m_, i, j);
347    if (status)
348      throw utility::GSL_error(std::string("matrix::swap_rowcol",status));
349  }
350
351
352  void matrix::swap_rows(const size_t i, const size_t j)
353  {
354    assert(m_);
355    int status=gsl_matrix_swap_rows(m_, i, j);
356    if (status)
357      throw utility::GSL_error(std::string("matrix::swap_rows",status));
358  }
359
360
361  void matrix::transpose(void)
362  {
363    assert(m_);
364    if (columns()==rows())
365      gsl_matrix_transpose(m_); // this never fails
366    else {
367      gsl_matrix* transposed = gsl_matrix_alloc(columns(),rows());
368      if (!transposed)
369        throw utility::GSL_error("matrix::transpose failed to allocate memory");
370      // next line never fails if allocation above succeeded.
371      gsl_matrix_transpose_memcpy(transposed,m_);
372      gsl_matrix_free(m_);
373      proxy_m_ = m_ = transposed;
374      if (blas_result_) {
375        gsl_matrix_free(blas_result_);
376        blas_result_=NULL;
377      }
378    }
379  }
380
381
382  double& matrix::operator()(size_t row, size_t column)
383  {
384    assert(m_);
385    double* d=gsl_matrix_ptr(m_, row, column);
386    if (!d)
387      throw utility::GSL_error("matrix::operator()",GSL_EINVAL);
388    return *d;
389  }
390
391
392  const double& matrix::operator()(size_t row, size_t column) const
393  {
394    const double* d=gsl_matrix_const_ptr(proxy_m_, row, column);
395    if (!d)
396      throw utility::GSL_error("matrix::operator()",GSL_EINVAL);
397    return *d;
398  }
399
400
401  bool matrix::operator==(const matrix& other) const
402  {
403    return equal(other);
404  }
405
406
407  bool matrix::operator!=(const matrix& other) const
408  {
409    return !equal(other);
410  }
411
412
413  const matrix& matrix::operator=( const matrix& other )
414  {
415    set(other);
416    return *this;
417  }
418
419
420  const matrix& matrix::operator+=(const matrix& other)
421  {
422    assert(m_);
423    int status=gsl_matrix_add(m_, other.proxy_m_);
424    if (status)
425      throw utility::GSL_error(std::string("matrix::operator+=", status));
426    return *this;
427  }
428
429
430  const matrix& matrix::operator+=(const double d)
431  {
432    assert(m_);
433    gsl_matrix_add_constant(m_, d);
434    return *this;
435  }
436
437
438  const matrix& matrix::operator-=(const matrix& other)
439  {
440    assert(m_);
441    int status=gsl_matrix_sub(m_, other.proxy_m_);
442    if (status)
443      throw utility::GSL_error(std::string("matrix::operator-=", status));
444    return *this;
445  }
446
447
448  const matrix& matrix::operator-=(const double d)
449  {
450    assert(m_);
451    gsl_matrix_add_constant(m_, -d);
452    return *this;
453  }
454
455
456  const matrix& matrix::operator*=(const matrix& other)
457  {
458    assert(m_);
459    if ( blas_result_ && ((blas_result_->size1!=rows()) ||
460                          (blas_result_->size2!=other.columns())) ) {
461      gsl_matrix_free(blas_result_);
462      blas_result_=NULL;
463    }
464    if (!blas_result_) {
465      blas_result_ = gsl_matrix_alloc(rows(),other.columns());
466      if (!blas_result_)
467        throw utility::GSL_error("matrix::operator*= failed to allocate memory");
468    }
469    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, m_, other.proxy_m_, 0.0,
470                   blas_result_);
471    gsl_matrix* tmp=m_;
472    proxy_m_ = m_ = blas_result_;
473    blas_result_=tmp;
474    return *this;
475  }
476
477
478  const matrix& matrix::operator*=(const double d)
479  {
480    assert(m_);
481    gsl_matrix_scale(m_, d);
482    return *this;
483  }
484
485
486  bool isnull(const matrix& other)
487  {
488    return gsl_matrix_isnull(other.gsl_matrix_p());
489  }
490
491
492  double max(const matrix& other)
493  {
494    return gsl_matrix_max(other.gsl_matrix_p());
495  }
496
497
498  double min(const matrix& other)
499  {
500    return gsl_matrix_min(other.gsl_matrix_p());
501  }
502
503
504  void minmax_index(const matrix& other,
505                    std::pair<size_t,size_t>& min, std::pair<size_t,size_t>& max)
506  {
507    gsl_matrix_minmax_index(other.gsl_matrix_p(), &min.first, &min.second,
508                            &max.first, &max.second);
509  }
510
511
512  bool nan(const matrix& templat, matrix& flag)
513  {
514    size_t rows=templat.rows();
515    size_t columns=templat.columns();
516    if (rows!=flag.rows() && columns!=flag.columns())
517      flag.clone(matrix(rows,columns,1.0));
518    else
519      flag.set_all(1.0);
520    bool nan=false;
521    for (size_t i=0; i<rows; i++)
522      for (size_t j=0; j<columns; j++) 
523        if (std::isnan(templat(i,j))) {
524          flag(i,j)=0;
525          nan=true;
526        }
527    return nan;
528  }
529
530
531  void swap(matrix& a, matrix& b)
532  {
533    assert(a.gsl_matrix_p()); assert(b.gsl_matrix_p());
534    int status=gsl_matrix_swap(a.gsl_matrix_p(), b.gsl_matrix_p());
535    if (status)
536      throw utility::GSL_error(std::string("swap(matrix&,matrix&)",status));
537  }
538
539
540  std::ostream& operator<<(std::ostream& s, const matrix& m)
541  {
542    s.setf(std::ios::dec);
543    s.precision(12);
544    for(size_t i=0, j=0; i<m.rows(); i++)
545      for (j=0; j<m.columns(); j++) {
546        s << m(i,j);
547        if (j<m.columns()-1)
548          s << s.fill();
549        else if (i<m.rows()-1)
550          s << "\n";
551      }
552    return s;
553  }
554
555
556  vector operator*(const matrix& m, const vector& v)
557  {
558    utility::vector res(m.rows());
559    for (size_t i=0; i<res.size(); ++i)
560      res(i) = vector(m,i) * v;
561    return res;
562  }
563
564
565  vector operator*(const vector& v, const matrix& m)
566  {
567    utility::vector res(m.columns());
568    for (size_t i=0; i<res.size(); ++i)
569      res(i) = v * vector(m,i,false);
570    return res;
571  }
572
573}}} // of namespace utility, yat and thep
Note: See TracBrowser for help on using the repository browser.