source: trunk/yat/classifier/KernelLookup.cc @ 2119

Last change on this file since 2119 was 2119, checked in by Peter, 13 years ago

converted files to utf-8. fixes #577

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.3 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007, 2008 Jari Häkkinen, Peter Johansson
5
6  This file is part of the yat library, http://dev.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 3 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with yat. If not, see <http://www.gnu.org/licenses/>.
20*/
21
22#include "KernelLookup.h"
23#include "MatrixLookup.h"
24#include "MatrixLookupWeighted.h"
25#include "yat/utility/Matrix.h"
26#include "yat/utility/MatrixWeighted.h"
27
28#include <cassert>
29
30namespace theplu {
31namespace yat {
32namespace classifier {
33
34  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
35    : kernel_(utility::SmartPtr<const Kernel>(&kernel, own))
36  {
37    column_index_ = utility::Index(kernel.size());
38    row_index_=column_index_;
39  }
40
41
42  KernelLookup::KernelLookup(const Kernel& kernel,
43                             const utility::Index& row, 
44                             const utility::Index& column,
45                             const bool owner)
46    : column_index_(column), 
47      kernel_(utility::SmartPtr<const Kernel>(&kernel, owner)),
48      row_index_(row)
49  {
50    // Checking that each row index is less than kernel.rows()
51    assert(validate(row_index_));
52    // Checking that each column index is less than kernel.column()
53    assert(validate(column_index_));
54  }
55
56
57  KernelLookup::KernelLookup(const KernelLookup& other, 
58                             const utility::Index& row, 
59                             const utility::Index& column)
60    : kernel_(other.kernel_)
61  {
62    row_index_ = utility::Index(other.row_index_, row);
63    column_index_ = utility::Index(other.column_index_, column);
64  }
65 
66
67  KernelLookup::KernelLookup(const KernelLookup& other)
68    : column_index_(other.column_index_), kernel_(other.kernel_), 
69      row_index_(other.row_index_)
70  {
71    // Checking that each row index is less than kernel.rows()
72    assert(validate(row_index_));
73    // Checking that each column index is less than kernel.column()
74    assert(validate(column_index_));
75  }
76 
77
78  KernelLookup::KernelLookup(const KernelLookup& other, 
79                             const utility::Index& index, 
80                             const bool row)
81    : kernel_(other.kernel_)
82  {
83    if (row){
84      row_index_ = utility::Index(other.row_index_, index);
85      column_index_= other.column_index_;
86    }
87    else{
88      column_index_ = utility::Index(other.column_index_, index);
89      row_index_= other.row_index_;
90    }
91    assert(kernel_->size());
92
93    // Checking that each row index is less than kernel.rows()
94    assert(validate(row_index_));
95    // Checking that each column index is less than kernel.column()
96    assert(validate(column_index_));
97  }
98 
99
100  KernelLookup::~KernelLookup(void)
101  {
102  }
103
104
105  KernelLookup::const_iterator KernelLookup::begin(void) const
106  {
107    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
108  }
109
110
111  KernelLookup::const_column_iterator
112  KernelLookup::begin_column(size_t i) const
113  {
114    return const_column_iterator(const_column_iterator::iterator_type(*this,
115                                                                      0,i), 
116                                 columns());
117  }
118
119
120  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
121  {
122    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
123  }
124
125
126  size_t KernelLookup::columns(void) const
127  {
128    return column_index_.size();
129  }
130
131
132  MatrixLookup KernelLookup::data(void) const
133  {
134    assert(!weighted());
135    return MatrixLookup(kernel_->data(), column_index_, false);
136  }
137
138
139  MatrixLookupWeighted KernelLookup::data_weighted(void) const
140  {
141    assert(weighted());
142    return MatrixLookupWeighted(kernel_->data_weighted(),column_index_,false);
143  }
144
145
146  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
147  {
148    return kernel_->element(vec, row_index_[i]);
149  }
150
151
152  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
153  {
154    return kernel_->element(vec, row_index_[i]);
155  }
156
157
158  KernelLookup::const_iterator KernelLookup::end(void) const
159  {
160    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
161  }
162
163
164  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
165  {
166    return const_column_iterator(const_column_iterator::iterator_type(*this, 
167                                                                      rows(),i),
168                                 columns());
169  }
170
171
172  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
173  {
174    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
175  }
176
177
178  size_t KernelLookup::rows(void) const
179  {
180    return row_index_.size();
181  }
182
183
184  KernelLookup KernelLookup::selected(const utility::Index& inputs) const
185  {
186    const Kernel* kernel;
187    if (kernel_->weighted()){
188      const MatrixLookupWeighted* ms = 
189        new MatrixLookupWeighted(data_weighted(),inputs,true);
190      kernel = kernel_->make_kernel(*ms, true);
191    }
192    else {
193      // matrix with selected features
194      const MatrixLookup* ms = new MatrixLookup(data(),inputs,true);
195      kernel = kernel_->make_kernel(*ms,true);
196    }
197    return KernelLookup(*kernel, true);
198  }
199
200
201  KernelLookup KernelLookup::test_kernel(const MatrixLookup& data) const
202  {
203    if (!weighted()){
204      assert(data.rows()==kernel_->data().rows());
205      utility::Matrix* data_all = 
206        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
207
208      for (size_t i=0; i<data_all->rows(); ++i) {
209
210        // first some columns from data in kernel_
211        for (size_t j=0; j<row_index_.size(); ++j){
212          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
213        }
214       
215        // last columns are equal to new data
216        for (size_t j=0;j<data.columns(); ++j){
217          (*data_all)(i,j+row_index_.size()) = data(i,j);
218        }
219      }
220      std::vector<size_t> column_index;
221      column_index.reserve(data.columns());
222      for (size_t i=0;i<data.columns(); ++i)
223        column_index.push_back(i+row_index_.size());
224
225      std::vector<size_t> row_index;
226      row_index.reserve(row_index_.size());
227      for (size_t i=0;i<row_index_.size(); ++i)
228        row_index.push_back(i);
229
230      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
231
232      const Kernel* kernel = 
233        kernel_->make_kernel(*tmp, true);
234
235      return KernelLookup(*kernel, utility::Index(row_index), 
236                          utility::Index(column_index), true);
237    }
238
239    assert(data.rows()==kernel_->data_weighted().rows());
240    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
241    // should hold a MatrixLookupweighted.
242    utility::MatrixWeighted* x_all = 
243      new utility::MatrixWeighted(data.rows(), rows()+data.columns());
244    const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
245
246    for (size_t i=0; i<data.rows(); ++i){
247
248      // first some columns from data in kernel_
249      for (size_t j=0; j<row_index_.size(); ++j){
250        (*x_all)(i,j) = kernel_data(i,row_index_[j]); 
251      }
252
253      // last columns are equal to new data
254      for (size_t j=0;j<data.columns(); ++j){
255        (*x_all)(i,j+row_index_.size()).data() = data(i,j);
256      }
257    }
258    std::vector<size_t> column_index;
259    column_index.reserve(data.columns());
260    for (size_t i=0;i<data.columns(); ++i)
261      column_index.push_back(i+row_index_.size());
262
263    std::vector<size_t> row_index;
264    row_index.reserve(row_index_.size());
265    for (size_t i=0;i<row_index_.size(); ++i)
266      row_index.push_back(i);
267
268    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*x_all, true);
269    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
270
271
272    return KernelLookup(*kernel, row_index_, 
273                        utility::Index(column_index), true);
274  }
275
276
277
278  KernelLookup KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
279  {
280    utility::MatrixWeighted* x_all = 
281      new utility::MatrixWeighted(data.rows(), rows()+data.columns());
282
283    if (weighted()){
284      const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
285   
286      for (size_t i=0; i<data.rows(); ++i){
287        // first columns are equal to data in kernel_
288        for (size_t j=0; j<row_index_.size(); ++j){
289          (*x_all)(i,j) = kernel_data(i,row_index_[j]);
290        }
291      }
292    }
293    else {
294
295      for (size_t i=0; i<data.rows(); ++i){
296        // first columns are equal to data in kernel_
297        for (size_t j=0; j<row_index_.size(); ++j)
298          (*x_all)(i,j).data() = kernel_->data()(i,row_index_[j]);
299      }
300    }
301
302    // last columns are equal to new data
303    for (size_t i=0; i<data.rows(); ++i){
304      for (size_t j=0;j<data.columns(); ++j){
305        (*x_all)(i,j+row_index_.size()) = data(i,j);
306      }
307    }
308   
309    std::vector<size_t> column_index;
310    column_index.reserve(data.columns());
311    for (size_t i=0;i<data.columns(); ++i)
312      column_index.push_back(i+row_index_.size());
313    const Kernel* kernel = 
314      kernel_->make_kernel(MatrixLookupWeighted(*x_all, true));
315    return KernelLookup(*kernel, row_index_, 
316                        utility::Index(column_index), true);
317  }
318
319
320  /*
321  const KernelLookup*
322  KernelLookup::training_data(const utility::Index& train) const
323  {
324    return new KernelLookup(*this,train,train);
325  }
326  */
327
328
329  bool KernelLookup::validate(const utility::Index& index) const
330  {
331    for (size_t i=0; i<index.size(); ++i)
332      if (index[i]>=kernel_->size())
333        return false;
334    return true;
335  }
336
337
338  bool KernelLookup::weighted(void) const
339  {
340    return kernel_->weighted();
341  }
342
343
344  KernelLookup::const_reference
345  KernelLookup::operator()(size_t row, size_t column) const
346  {
347    return (*kernel_)(row_index_[row],column_index_[column]);
348  }
349
350}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.