source: trunk/yat/utility/matrix.cc @ 1009

Last change on this file since 1009 was 1009, checked in by Peter, 13 years ago

merging branch peter-dev into trunk delta 1008:994

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.0 KB
Line 
1// $Id: matrix.cc 1009 2008-02-01 04:15:44Z peter $
2
3/*
4  Copyright (C) 2003 Daniel Dalevi, Peter Johansson
5  Copyright (C) 2004 Jari Häkkinen, Peter Johansson
6  Copyright (C) 2005, 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
7  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
8
9  This file is part of the yat library, http://trac.thep.lu.se/yat
10
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 2 of the
14  License, or (at your option) any later version.
15
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20
21  You should have received a copy of the GNU General Public License
22  along with this program; if not, write to the Free Software
23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
24  02111-1307, USA.
25*/
26
27#include "matrix.h"
28#include "vector.h"
29#include "vectorView.h"
30#include "utility.h"
31
32#include <cassert>
33#include <cmath>
34#include <sstream>
35#include <vector>
36
37#include <gsl/gsl_blas.h>
38
39namespace theplu {
40namespace yat {
41namespace utility {
42
43
44  matrix::matrix(void)
45    : blas_result_(NULL), m_(NULL), view_(NULL), view_const_(NULL),
46      proxy_m_(NULL)
47  {
48  }
49
50
51  matrix::matrix(const size_t& r, const size_t& c, double init_value)
52    : blas_result_(NULL), m_(gsl_matrix_alloc(r,c)), view_(NULL),
53      view_const_(NULL), proxy_m_(m_)
54  {
55    if (!m_)
56      throw utility::GSL_error("matrix::matrix failed to allocate memory");
57    all(init_value);
58  }
59
60
61  matrix::matrix(const matrix& o)
62    : blas_result_(NULL), m_(o.create_gsl_matrix_copy()), view_(NULL),
63      view_const_(NULL), proxy_m_(m_)
64  {
65  }
66
67
68  matrix::matrix(matrix& m, size_t offset_row, size_t offset_column,
69                 size_t n_row, size_t n_column)
70    : blas_result_(NULL), view_const_(NULL)
71  {
72    view_ = new gsl_matrix_view(gsl_matrix_submatrix(m.m_,
73                                                     offset_row,offset_column,
74                                                     n_row,n_column));
75    if (!view_)
76      throw utility::GSL_error("matrix::matrix failed to setup view");
77    proxy_m_ = m_ = &(view_->matrix);
78  }
79
80
81  // Constructor that gets data from istream
82  matrix::matrix(std::istream& is, char sep) 
83    throw (utility::IO_error,std::exception)
84    : blas_result_(NULL), view_(NULL), view_const_(NULL)
85  {
86    // read the data file and store in stl vectors (dynamically
87    // expandable)
88    std::vector<std::vector<double> > data_matrix;
89    u_int nof_columns=0;
90    u_int nof_rows = 0;
91    std::string line;
92    while(getline(is, line, '\n')){
93      // Ignoring empty lines
94      if (!line.size()) {
95        continue;
96      }
97      nof_rows++;
98      std::vector<double> v;
99      std::string element;
100      std::stringstream ss(line);
101     
102      bool ok=true;
103      while(ok) {
104        if(sep=='\0')
105          ok=(ss>>element);
106        else
107          ok=getline(ss, element, sep);
108        if(!ok)
109          break;
110       
111        if(utility::is_double(element)) {
112          v.push_back(atof(element.c_str()));
113        }
114        else if (!element.size() || utility::is_nan(element)) {
115          v.push_back(std::numeric_limits<double>::quiet_NaN());
116        }
117        else {
118          std::stringstream ss("Warning: '");
119          ss << element << "' is not accepted as a matrix element.";
120          throw IO_error(ss.str());
121        }
122      }           
123      if(sep!='\0' && line[line.size()-1]==sep) // add NaN for final separator
124          v.push_back(std::numeric_limits<double>::quiet_NaN());
125      if (!nof_columns)
126        nof_columns=v.size();
127      else if (v.size()!=nof_columns) {
128        std::ostringstream s;
129        s << "matrix::matrix(std::istream&, char) data file error: "
130          << "line " << nof_rows << " has " << v.size()
131          << " columns; expected " << nof_columns << " columns.";
132        throw utility::IO_error(s.str());
133      }
134      data_matrix.push_back(v);
135    }
136
137    // manipulate the state of the stream to be good
138    is.clear(std::ios::goodbit);
139    // convert the data to a gsl matrix
140    proxy_m_ = m_ = gsl_matrix_alloc ( nof_rows, nof_columns );
141    if (!m_)
142      throw utility::GSL_error("matrix::matrix failed to allocate memory");
143
144    // if gsl error handler disabled, out of bounds index will not
145    // abort the program.
146    for(u_int i=0;i<nof_rows;i++)
147      for(u_int j=0;j<nof_columns;j++)
148        gsl_matrix_set( m_, i, j, data_matrix[i][j] );
149  }
150
151
152  matrix::~matrix(void)
153  {
154    delete_allocated_memory();
155    if (blas_result_)
156      gsl_matrix_free(blas_result_);
157  }
158
159
160  const matrix& matrix::clone(const matrix& other)
161  {
162    if (this!=&other) {
163
164      delete_allocated_memory();
165
166      if (other.view_) {
167        view_ = new gsl_matrix_view(*other.view_);
168        proxy_m_ = m_ = &(view_->matrix);
169      }
170      else if (other.view_const_) {
171        view_const_ = new gsl_matrix_const_view(*other.view_const_);
172        proxy_m_ = &(view_const_->matrix);
173      }
174      else if (other.m_)
175        proxy_m_ = m_ = other.create_gsl_matrix_copy();
176
177      // no need to delete blas_result_ if the number of rows fit, it
178      // may be useful later.
179      if (blas_result_ && (blas_result_->size1!=rows())) {
180        gsl_matrix_free(blas_result_);
181        blas_result_=NULL;
182      }
183    }
184    return *this;
185  } 
186
187
188  size_t matrix::columns(void) const
189  {
190    if (!proxy_m_)
191      return 0;
192    return proxy_m_->size2;
193  }
194
195
196  gsl_matrix* matrix::create_gsl_matrix_copy(void) const
197  {
198    gsl_matrix* m = gsl_matrix_alloc(rows(),columns());
199    if (!m)
200      throw utility::GSL_error("matrix::create_gsl_matrix_copy failed to allocate memory");
201    if (gsl_matrix_memcpy(m,proxy_m_))
202      throw utility::GSL_error("matrix::create_gsl_matrix_copy dimension mis-match");
203    return m;
204  }
205
206
207  void matrix::delete_allocated_memory(void)
208  {
209    if (view_)
210      delete view_;
211    else if (view_const_)
212      delete view_const_;
213    else if (m_)
214      gsl_matrix_free(m_);
215    blas_result_=NULL;
216    proxy_m_=m_=NULL;
217  }
218
219
220  void matrix::div(const matrix& other)
221  {
222    assert(m_);
223    int status=gsl_matrix_div_elements(m_, other.gsl_matrix_p());
224    if (status)
225      throw utility::GSL_error(std::string("matrix::div_elements",status));
226  }
227
228
229  bool matrix::equal(const matrix& other, const double d) const
230  {
231    if (this==&other)
232      return true;
233    if (columns()!=other.columns() || rows()!=other.rows())
234      return false;
235    for (size_t i=0; i<rows(); i++)
236      for (size_t j=0; j<columns(); j++)
237        // The two last condition checks are needed for NaN detection
238        if (fabs( (*this)(i,j)-other(i,j) ) > d ||
239            (*this)(i,j)!=(*this)(i,j) || other(i,j)!=other(i,j))
240          return false;
241    return true;
242  }
243
244
245  const gsl_matrix* matrix::gsl_matrix_p(void) const
246  {
247    return proxy_m_;
248  }
249
250
251  gsl_matrix* matrix::gsl_matrix_p(void)
252  {
253    return m_;
254  }
255
256
257  bool matrix::isview(void) const
258  {
259    return view_ || view_const_;
260  }
261
262
263  void matrix::mul(const matrix& other)
264  {
265    assert(m_);
266    int status=gsl_matrix_mul_elements(m_, other.gsl_matrix_p());
267    if (status)
268      throw utility::GSL_error(std::string("matrix::mul_elements",status));
269  }
270
271
272  void matrix::resize(size_t r, size_t c, double init_value)
273  {
274    delete_allocated_memory();
275
276    proxy_m_ = m_ = gsl_matrix_alloc(r,c);
277    if (!m_)
278      throw utility::GSL_error("matrix::matrix failed to allocate memory");
279    all(init_value);
280
281    // no need to delete blas_result_ if the number of rows fit, it
282    // may be useful later.
283    if (blas_result_ && (blas_result_->size1!=rows())) {
284      gsl_matrix_free(blas_result_);
285      blas_result_=NULL;
286    }
287  }
288
289
290  size_t matrix::rows(void) const
291  {
292    if (!proxy_m_)
293      return 0;
294    return proxy_m_->size1;
295  }
296
297
298  void matrix::all(const double value)
299  {
300    assert(m_);
301    gsl_matrix_set_all(m_, value);
302  }
303
304
305  vectorView matrix::column_vec(size_t col)
306  {
307    vectorView res(*this, col, false);
308    return res;
309  }
310
311
312  const vectorView matrix::column_vec(size_t col) const
313  {
314    return vectorView(*this, col, false);
315  }
316
317
318  const vectorView matrix::row_vec(size_t col) const
319  {
320    return vectorView(*this, col, true);
321  }
322
323
324  vectorView matrix::row_vec(size_t row)
325  {
326    vectorView res(*this, row, true);
327    return res;
328  }
329
330
331  void matrix::swap_columns(const size_t i, const size_t j)
332  {
333    assert(m_);
334    int status=gsl_matrix_swap_columns(m_, i, j);
335    if (status)
336      throw utility::GSL_error(std::string("matrix::swap_columns",status));
337  }
338
339
340  void matrix::swap_rowcol(const size_t i, const size_t j)
341  {
342    assert(m_);
343    int status=gsl_matrix_swap_rowcol(m_, i, j);
344    if (status)
345      throw utility::GSL_error(std::string("matrix::swap_rowcol",status));
346  }
347
348
349  void matrix::swap_rows(const size_t i, const size_t j)
350  {
351    assert(m_);
352    int status=gsl_matrix_swap_rows(m_, i, j);
353    if (status)
354      throw utility::GSL_error(std::string("matrix::swap_rows",status));
355  }
356
357
358  void matrix::transpose(void)
359  {
360    assert(m_);
361    if (columns()==rows())
362      gsl_matrix_transpose(m_); // this never fails
363    else {
364      gsl_matrix* transposed = gsl_matrix_alloc(columns(),rows());
365      if (!transposed)
366        throw utility::GSL_error("matrix::transpose failed to allocate memory");
367      // next line never fails if allocation above succeeded.
368      gsl_matrix_transpose_memcpy(transposed,m_);
369      gsl_matrix_free(m_);
370      proxy_m_ = m_ = transposed;
371      if (blas_result_) {
372        gsl_matrix_free(blas_result_);
373        blas_result_=NULL;
374      }
375    }
376  }
377
378
379  double& matrix::operator()(size_t row, size_t column)
380  {
381    assert(m_);
382    assert(row<rows());
383    assert(column<columns());
384    double* d=gsl_matrix_ptr(m_, row, column);
385    if (!d)
386      throw utility::GSL_error("matrix::operator()",GSL_EINVAL);
387    return *d;
388  }
389
390
391  const double& matrix::operator()(size_t row, size_t column) const
392  {
393    assert(row<rows());
394    assert(column<columns());
395    const double* d=gsl_matrix_const_ptr(proxy_m_, row, column);
396    if (!d)
397      throw utility::GSL_error("matrix::operator()",GSL_EINVAL);
398    return *d;
399  }
400
401
402  bool matrix::operator==(const matrix& other) const
403  {
404    return equal(other);
405  }
406
407
408  bool matrix::operator!=(const matrix& other) const
409  {
410    return !equal(other);
411  }
412
413
414  const matrix& matrix::operator=( const matrix& other )
415  {
416    assert(m_);
417    if (this!=&other)
418      if (gsl_matrix_memcpy(m_, other.gsl_matrix_p()))
419        throw utility::GSL_error("matrix::create_gsl_matrix_copy dimension mis-match");
420    return *this;
421  }
422
423
424  const matrix& matrix::operator+=(const matrix& other)
425  {
426    assert(m_);
427    int status=gsl_matrix_add(m_, other.proxy_m_);
428    if (status)
429      throw utility::GSL_error(std::string("matrix::operator+=", status));
430    return *this;
431  }
432
433
434  const matrix& matrix::operator+=(const double d)
435  {
436    assert(m_);
437    gsl_matrix_add_constant(m_, d);
438    return *this;
439  }
440
441
442  const matrix& matrix::operator-=(const matrix& other)
443  {
444    assert(m_);
445    int status=gsl_matrix_sub(m_, other.proxy_m_);
446    if (status)
447      throw utility::GSL_error(std::string("matrix::operator-=", status));
448    return *this;
449  }
450
451
452  const matrix& matrix::operator-=(const double d)
453  {
454    assert(m_);
455    gsl_matrix_add_constant(m_, -d);
456    return *this;
457  }
458
459
460  const matrix& matrix::operator*=(const matrix& other)
461  {
462    assert(m_);
463    if ( blas_result_ && ((blas_result_->size1!=rows()) ||
464                          (blas_result_->size2!=other.columns())) ) {
465      gsl_matrix_free(blas_result_);
466      blas_result_=NULL;
467    }
468    if (!blas_result_) {
469      blas_result_ = gsl_matrix_alloc(rows(),other.columns());
470      if (!blas_result_)
471        throw utility::GSL_error("matrix::operator*= failed to allocate memory");
472    }
473    gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, m_, other.proxy_m_, 0.0,
474                   blas_result_);
475    gsl_matrix* tmp=m_;
476    proxy_m_ = m_ = blas_result_;
477    blas_result_=tmp;
478    return *this;
479  }
480
481
482  const matrix& matrix::operator*=(const double d)
483  {
484    assert(m_);
485    gsl_matrix_scale(m_, d);
486    return *this;
487  }
488
489
490  bool isnull(const matrix& other)
491  {
492    return gsl_matrix_isnull(other.gsl_matrix_p());
493  }
494
495
496  double max(const matrix& other)
497  {
498    return gsl_matrix_max(other.gsl_matrix_p());
499  }
500
501
502  double min(const matrix& other)
503  {
504    return gsl_matrix_min(other.gsl_matrix_p());
505  }
506
507
508  void minmax_index(const matrix& other,
509                    std::pair<size_t,size_t>& min, std::pair<size_t,size_t>& max)
510  {
511    gsl_matrix_minmax_index(other.gsl_matrix_p(), &min.first, &min.second,
512                            &max.first, &max.second);
513  }
514
515
516  bool nan(const matrix& templat, matrix& flag)
517  {
518    size_t rows=templat.rows();
519    size_t columns=templat.columns();
520    if (rows!=flag.rows() && columns!=flag.columns())
521      flag.clone(matrix(rows,columns,1.0));
522    else
523      flag.all(1.0);
524    bool nan=false;
525    for (size_t i=0; i<rows; i++)
526      for (size_t j=0; j<columns; j++) 
527        if (std::isnan(templat(i,j))) {
528          flag(i,j)=0;
529          nan=true;
530        }
531    return nan;
532  }
533
534
535  void swap(matrix& a, matrix& b)
536  {
537    assert(a.gsl_matrix_p()); assert(b.gsl_matrix_p());
538    int status=gsl_matrix_swap(a.gsl_matrix_p(), b.gsl_matrix_p());
539    if (status)
540      throw utility::GSL_error(std::string("swap(matrix&,matrix&)",status));
541  }
542
543
544  std::ostream& operator<<(std::ostream& s, const matrix& m)
545  {
546    s.setf(std::ios::dec);
547    s.precision(12);
548    for(size_t i=0, j=0; i<m.rows(); i++)
549      for (j=0; j<m.columns(); j++) {
550        s << m(i,j);
551        if (j<m.columns()-1)
552          s << s.fill();
553        else if (i<m.rows()-1)
554          s << "\n";
555      }
556    return s;
557  }
558
559
560  vector operator*(const matrix& m, const vectorBase& v)
561  {
562    utility::vector res(m.rows());
563    for (size_t i=0; i<res.size(); ++i)
564      res(i) = vectorView(m,i) * v;
565    return res;
566  }
567
568
569  vector operator*(const vector& v, const matrix& m)
570  {
571    utility::vector res(m.columns());
572    for (size_t i=0; i<res.size(); ++i)
573      res(i) = v * vectorView(m,i,false);
574    return res;
575  }
576
577}}} // of namespace utility, yat and thep
Note: See TracBrowser for help on using the repository browser.