source: trunk/yat/classifier/MatrixLookupWeighted.cc @ 1091

Last change on this file since 1091 was 1091, checked in by Peter, 16 years ago

fixes #267 iterator for MatrixLookupWeighted?

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.4 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "MatrixLookupWeighted.h"
26#include "MatrixLookup.h"
27#include "yat/utility/matrix.h"
28
29#include <algorithm>
30#include <cassert>
31#include <fstream>
32
33namespace theplu {
34namespace yat {
35namespace classifier {
36
37  MatrixLookupWeighted::MatrixLookupWeighted(const utility::matrix& data, 
38                                             const utility::matrix& weights,
39                                             const bool own)
40    : DataLookup2D(own), data_(&data), weights_(&weights), 
41      ref_count_weights_(NULL)
42  {
43    assert(data.rows()==weights.rows());
44    assert(data.columns()==weights.columns());
45    for(size_t i=0;i<(*data_).rows();i++)
46      row_index_.push_back(i);
47    for(size_t i=0;i<(*data_).columns();i++)
48      column_index_.push_back(i);
49  }
50
51
52  MatrixLookupWeighted::MatrixLookupWeighted(const utility::matrix& data)
53    : DataLookup2D(), data_(&data)
54  {
55    utility::matrix weights;
56    utility::nan(*data_,weights);
57    weights_= new utility::matrix(weights);
58    ref_count_weights_=new u_int(1);
59    for(size_t i=0;i<(*data_).rows();i++)
60      row_index_.push_back(i);
61    for(size_t i=0;i<(*data_).columns();i++)
62      column_index_.push_back(i);
63  }
64
65
66  MatrixLookupWeighted::MatrixLookupWeighted(const MatrixLookup& ml)
67    : DataLookup2D(ml), data_(ml.data_)
68  {
69    weights_= new utility::matrix(data_->rows(), data_->columns(), 1.0);
70    ref_count_weights_=new u_int(1);
71    ref_count_=ml.ref_count_;
72    if (ref_count_)
73      ++(*ref_count_);
74
75  }
76 
77
78  MatrixLookupWeighted::MatrixLookupWeighted(const utility::matrix& data, 
79                                             const utility::matrix& weights, 
80                                             const std::vector<size_t>& row, 
81                                             const std::vector<size_t>& col)
82    : DataLookup2D(row,col), data_(&data), weights_(&weights),
83      ref_count_weights_(NULL)
84  {
85    // Checking that each row index is less than data.rows()
86    assert(row.empty() || 
87           *(std::max_element(row.begin(),row.end()))<data.rows());
88    // Checking that each column index is less than data.column()
89    assert(col.empty() || 
90           *(std::max_element(col.begin(),col.end()))<data.columns());
91    // Checking that each row index is less than weights.rows()
92    assert(row.empty() || 
93           *(std::max_element(row.begin(),row.end()))<weights.rows());
94    // Checking that each column index is less than weights.column()
95    assert(col.empty() || 
96           *(std::max_element(col.begin(),col.end()))<weights.columns());
97  }
98 
99
100
101  MatrixLookupWeighted::MatrixLookupWeighted(const utility::matrix& data, 
102                                             const utility::matrix& weights, 
103                                             const std::vector<size_t>& index, 
104                                             const bool row)
105    : DataLookup2D(), data_(&data), weights_(&weights),
106      ref_count_weights_(NULL)
107  {
108    if (row){
109      // Checking that each row index is less than data.rows()
110      assert(index.empty() || 
111             *(std::max_element(index.begin(),index.end()))<data.rows());
112      // Checking that each row index is less than weights.rows()
113      assert(index.empty() || 
114             *(std::max_element(index.begin(),index.end()))<weights.rows());
115      row_index_=index;
116      assert(column_index_.empty());
117      column_index_.reserve(data.columns());
118      for (size_t i=0; i<data.columns(); i++)
119        column_index_.push_back(i);
120    }
121    else{
122      // Checking that each column index is less than data.column()
123      assert(index.empty() || 
124             *(std::max_element(index.begin(),index.end()))<data.columns());
125      // Checking that each column index is less than weights.column()
126      assert(index.empty() || 
127             *(std::max_element(index.begin(),index.end()))<weights.columns());
128      column_index_=index;
129      assert(row_index_.empty());
130      column_index_.reserve(data.rows());
131      for (size_t i=0; i<data.rows(); i++)
132        row_index_.push_back(i);
133    }
134  }
135 
136
137  /*
138  MatrixLookupWeighted::MatrixLookupWeighted(const MatrixLookup& dv,
139                                             const MatrixLookup& wv)
140    : DataLookup2D(dv), data_(dv.data_), weights_(dv.data_)
141  {
142  }
143  */
144
145
146  MatrixLookupWeighted::MatrixLookupWeighted(const MatrixLookupWeighted& other)
147    : DataLookup2D(other), data_(other.data_), weights_(other.weights_)
148  {
149    ref_count_ = other.ref_count_;
150    if (ref_count_)
151      ++(*ref_count_);
152    ref_count_weights_ = other.ref_count_weights_;
153    if (ref_count_weights_)
154      ++(*ref_count_weights_);
155
156  }
157
158
159
160  MatrixLookupWeighted::MatrixLookupWeighted(const MatrixLookupWeighted& other,
161                                             const std::vector<size_t>& row, 
162                                             const std::vector<size_t>& col)
163    : DataLookup2D(other,row,col), data_(other.data_), weights_(other.weights_)
164  {
165    ref_count_ = other.ref_count_;
166    if (ref_count_)
167      ++(*ref_count_);
168    ref_count_weights_ = other.ref_count_weights_;
169    if (ref_count_weights_)
170      ++(*ref_count_weights_);
171  }
172 
173
174
175  MatrixLookupWeighted::MatrixLookupWeighted(const MatrixLookupWeighted& other, 
176                                             const std::vector<size_t>& index, 
177                                             bool row)
178    : DataLookup2D(other,index,row), data_(other.data_), 
179      weights_(other.weights_)
180  {
181    ref_count_ = other.ref_count_;
182    if (ref_count_)
183      ++(*ref_count_);
184    ref_count_weights_ = other.ref_count_weights_;
185    if (ref_count_weights_)
186      ++(*ref_count_weights_);
187
188    // Checking that no index is out of range
189    assert(row_index_.empty() || 
190           *(max_element(row_index_.begin(), row_index_.end()))<data_->rows());
191    assert(column_index_.empty() || 
192           *(max_element(column_index_.begin(), column_index_.end()))<
193           data_->columns());
194    // Checking that no index is out of range
195    assert(row_index_.empty() || 
196           *(max_element(row_index_.begin(), row_index_.end()))<
197           weights_->rows());
198    assert(column_index_.empty() || 
199           *(max_element(column_index_.begin(), column_index_.end()))<
200           weights_->columns());
201  }
202 
203
204
205  MatrixLookupWeighted::MatrixLookupWeighted(const size_t rows, 
206                                             const size_t columns, 
207                                             const double value,
208                                             const double weight)
209    : DataLookup2D(rows,columns)
210  {
211    data_ = new utility::matrix(1,1,value);
212    ref_count_=new u_int(1);
213    weights_ = new utility::matrix(1,1,weight);
214    ref_count_weights_=new u_int(1);
215  }
216
217 
218  MatrixLookupWeighted::MatrixLookupWeighted(std::istream& is, char sep)
219    : DataLookup2D()
220  {
221    data_ = new utility::matrix(is,sep);
222    ref_count_=new u_int(1);
223    for(size_t i=0;i<(*data_).rows();i++)
224      row_index_.push_back(i);
225    for(size_t i=0;i<(*data_).columns();i++)
226      column_index_.push_back(i);
227    utility::matrix weights;
228    utility::nan(*data_,weights);
229    weights_= new utility::matrix(weights);
230    ref_count_weights_=new u_int(1);
231  }
232 
233
234  MatrixLookupWeighted::~MatrixLookupWeighted(void)
235  {
236    if (ref_count_)
237      if (!--(*ref_count_))
238        delete data_;
239    if (ref_count_weights_)
240      if (!--(*ref_count_weights_))
241        delete weights_;
242  }
243
244
245
246  MatrixLookupWeighted::const_iterator MatrixLookupWeighted::begin(void) const
247  {
248    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
249  }
250
251
252  MatrixLookupWeighted::const_iterator
253  MatrixLookupWeighted::begin_column(size_t i) const
254  {
255    return const_iterator(const_iterator::iterator_type(*this, 0, i),columns());
256  }
257
258
259  MatrixLookupWeighted::const_iterator
260  MatrixLookupWeighted::begin_row(size_t i) const
261  {
262    return const_iterator(const_iterator::iterator_type(*this, i, 0), 1);
263  }
264
265
266  double MatrixLookupWeighted::data(size_t row, size_t column) const
267  {
268    return (*data_)(row_index_[row], column_index_[column]);
269  }
270
271
272
273  MatrixLookupWeighted::const_iterator MatrixLookupWeighted::end(void) const
274  {
275    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
276  }
277
278
279  MatrixLookupWeighted::const_iterator
280  MatrixLookupWeighted::end_column(size_t i) const
281  {
282    return const_iterator(const_iterator::iterator_type(*this, rows(), i), 
283                          columns());
284  }
285
286
287  MatrixLookupWeighted::const_iterator
288  MatrixLookupWeighted::end_row(size_t i) const
289  {
290    return const_iterator(const_iterator::iterator_type(*this, i+1, 0), 1);
291  }
292
293
294  const MatrixLookupWeighted* 
295  MatrixLookupWeighted::selected(const std::vector<size_t>& i) const
296  { 
297    return new MatrixLookupWeighted(*this,i, true); 
298  }
299
300
301
302  const MatrixLookupWeighted* 
303  MatrixLookupWeighted::training_data(const std::vector<size_t>& i) const
304  { 
305    return new MatrixLookupWeighted(*this,i, false); 
306  }
307
308
309
310  const MatrixLookupWeighted* 
311  MatrixLookupWeighted::validation_data(const std::vector<size_t>& train,
312                                        const std::vector<size_t>& val) const
313  { 
314    return new MatrixLookupWeighted(*this,val, false); 
315  }
316
317
318
319  double MatrixLookupWeighted::weight(size_t row, size_t column) const
320  {
321    return (*weights_)(row_index_[row], column_index_[column]);
322  }
323
324
325
326  bool MatrixLookupWeighted::weighted(void) const 
327  {
328    return true;
329  }
330
331
332
333  double MatrixLookupWeighted::operator()(const size_t row,
334                                          const size_t column) const
335  { 
336    return (weight(row,column) ? data(row,column)*weight(row,column) : 0);
337  }
338
339
340
341  const MatrixLookupWeighted& MatrixLookupWeighted::operator=
342  (const MatrixLookupWeighted& other)
343  {
344    if (this!=&other){
345      if (ref_count_ && !--(*ref_count_))
346        delete data_;
347      if (ref_count_weights_ && !--(*ref_count_weights_))
348        delete weights_;
349      DataLookup2D::operator=(other);
350      data_ = other.data_;
351      ref_count_=other.ref_count_;
352      if (ref_count_)
353        ++(*ref_count_);
354      weights_ = other.weights_;
355      ref_count_weights_ = other.ref_count_weights_;
356      if (ref_count_weights_)
357        ++(*ref_count_weights_);
358    }
359    return *this;
360  }
361
362
363  std::ostream& operator<<(std::ostream& s, const MatrixLookupWeighted& m)
364  {
365    s.setf(std::ios::dec);
366    s.precision(12);
367    for(size_t i=0, j=0; i<m.rows(); i++)
368      for (j=0; j<m.columns(); j++) {
369        if (m.weight(i,j))
370          s << m.data(i,j);
371        if (j<m.columns()-1)
372          s << s.fill();
373        else if (i<m.rows()-1)
374          s << "\n";
375      }
376    return s;
377  }
378
379
380
381}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.