source: trunk/src/matrix.cc @ 12

Last change on this file since 12 was 12, checked in by daniel, 19 years ago

Matrix and vector APIs for GSL.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 4.7 KB
Line 
1
2#include <iostream>
3#include "matrix.h"
4
5using namespace thep_gsl_api;
6
7
8// Constructors and Destructors
9///////////////////////////////
10
11matrix::matrix() : m_( NULL )
12{
13}
14
15
16matrix::matrix( const size_t& rows, const size_t& cols, 
17    bool init_to_zero )
18{
19  if( init_to_zero )
20    { 
21     m_ = gsl_matrix_calloc( rows, cols );
22    }
23  else     
24    {
25     m_ = gsl_matrix_alloc ( rows, cols );
26    }
27}
28
29
30// Is this the way to do it? No copy ...
31// internal data ...
32matrix::matrix( gsl_matrix* m ) : m_( m )
33{
34}
35
36
37// Copy constructor
38matrix::matrix( const matrix& other )
39{
40  m_ = new_copy( other.get_gsl_matrix() );
41}
42
43
44matrix::~matrix()
45{
46  if( m_ ) 
47    {
48     gsl_matrix_free( m_ );
49     m_ = NULL;
50    }
51}
52
53
54gsl_matrix* matrix::new_copy( const gsl_matrix* p_other )
55{
56  // Get dimenstions
57  size_t rows = p_other->size1;
58  size_t cols = p_other->size2;
59 
60  // Create new empty matrix
61  gsl_matrix* p_res =  gsl_matrix_alloc ( rows, cols );
62
63  // Copy p_others elements into p_res
64  gsl_matrix_memcpy( p_res, p_other );
65
66  return p_res;
67}
68
69
70// Operators on matrices
71////////////////////////
72matrix& matrix::operator=( const matrix& other )
73{
74  if( this != &other ) 
75    {
76     gsl_matrix* v_new = new_copy( other.get_gsl_matrix() );
77     if( m_ ) gsl_matrix_free( m_ );
78     m_ = v_new;
79    }
80  return *this;
81} 
82
83
84matrix matrix::operator+( const matrix &other ) const
85{
86  assert( rows() == other.rows() &&
87    cols() == other.cols() );
88  matrix res( *this );
89  gsl_matrix_add( res.get_gsl_matrix(), 
90      other.get_gsl_matrix() );
91  return res;
92}
93
94
95matrix matrix::operator-( const matrix &other ) const
96{
97  assert( rows() == other.rows() &&
98    cols() == other.cols() );
99  matrix res( *this );
100  gsl_matrix_sub( res.get_gsl_matrix(), 
101      other.get_gsl_matrix() );
102  return res;
103}
104
105
106matrix matrix::operator*( const matrix &other ) const
107{
108  assert( rows() == other.cols() );
109  matrix res( rows(), other.cols() );
110  gsl_linalg_matmult( m_, other.get_gsl_matrix(), 
111          res.get_gsl_matrix() );
112  return res;
113}
114
115
116// Matrix vector multiplication
117vector matrix::operator*( const vector &other ) const
118{
119  assert( cols() == other.size() ); // NxM1 N1x1 ( M1 must equal N1 )
120  vector res( rows(), other.is_column() );
121
122  gsl_blas_dgemv( CblasNoTrans, 1.0, m_, other.get_gsl_vector(), 
123      0.0, res.get_gsl_vector() );
124
125  return res;
126}
127
128
129std::ostream& thep_gsl_api::operator<< ( std::ostream& s_out, 
130           const matrix& a )
131{
132  using namespace std;
133  s_out.setf( ios::fixed ); 
134  for( size_t i = 0, j = 0; i < a.rows(); ++i ) 
135    {
136     for ( j = 0; j < a.cols() - 1; ++j ) 
137       {
138  s_out << a.get( i, j ) << " ";
139       }
140     s_out << a.get( i, j ) << endl;
141    }
142   return s_out;
143}
144
145
146// Public functions (A-Z)
147/////////////////////////
148
149matrix matrix::transpose() const 
150{
151  matrix res( cols(), rows() );
152
153   for ( size_t i = 0; i < rows(); i++ ) 
154     {
155      for ( size_t j = 0; j < cols(); j++ ) 
156  {
157         gsl_matrix_set( res.get_gsl_matrix(), j, i, gsl_matrix_get( m_, i, j ) );
158  }
159     }
160   
161   return( res );
162}
163
164
165void matrix::swap_cols( const size_t& i, const size_t& j )
166{
167  gsl_matrix_swap_columns( m_, i, j );
168}
169
170
171void matrix::swap_rows( const size_t& i, const size_t& j )
172{
173  gsl_matrix_swap_rows( m_, i, j );
174}
175
176
177double matrix::norm( double n ) const 
178{
179  double sum = 0.0; 
180  for ( size_t i = 0; i < rows(); ++i ) 
181    {
182     for ( size_t j = 0; j < cols(); ++j ) 
183       {
184  sum += pow( (double)( get( i, j ) ), n );
185       }
186    }
187 
188  return pow( sum, 1 / n );
189}
190
191
192matrix matrix::row( const size_t& i ) const 
193{
194  matrix rowmatrix( 1, cols() );
195  gsl_vector *tempvector = gsl_vector_calloc( cols() );
196  assert( i >= 0 && i < rows() ); 
197  gsl_matrix_get_row( tempvector, m_, i );
198  gsl_matrix_set_row( rowmatrix.get_gsl_matrix(), 0, tempvector );
199  gsl_vector_free( tempvector );
200  return( rowmatrix );
201}
202
203
204matrix matrix::col( const size_t& i ) const 
205{
206  matrix columnmatrix( rows(), 1 );
207  gsl_vector *tempvector = gsl_vector_calloc( rows() );
208  assert( i >= 0 && i < cols() ); 
209  gsl_matrix_get_col( tempvector, m_, i );
210  gsl_matrix_set_col( columnmatrix.get_gsl_matrix(), 0, tempvector );
211  gsl_vector_free( tempvector );
212  return( columnmatrix );
213}
214
215
216
217double matrix::sum() const 
218{
219  double sum = 0;
220  for ( size_t i = 0; i < rows(); ++i ) 
221    {
222     for ( size_t j = 0; j < cols(); j++ ) 
223       {
224  sum += gsl_matrix_get( m_, i, j );
225       }
226    }
227  return( sum );
228}
229
230
231
232
233vector matrix::row_sum() const
234{
235  vector sum( rows() );
236  for ( size_t i = 0; i < rows(); ++i ) 
237    {
238     sum.set( i, row( i ).sum() );
239    }
240  return( sum );
241}
242
243
244vector matrix::mean_row_sum() const
245{
246  vector v_sum = row_sum();
247  for( size_t i = 0; i < rows(); ++i ) 
248    {
249     v_sum[ i ] = v_sum[ i ] / static_cast<double>( cols() );
250    }
251  return v_sum;
252}
Note: See TracBrowser for help on using the repository browser.