source: trunk/yat/classifier/KernelLookup.cc @ 2223

Last change on this file since 2223 was 2223, checked in by Peter, 13 years ago

fixes #543

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.5 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007, 2008 Jari Häkkinen, Peter Johansson
5
6  This file is part of the yat library, http://dev.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 3 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with yat. If not, see <http://www.gnu.org/licenses/>.
20*/
21
22#include "KernelLookup.h"
23#include "MatrixLookup.h"
24#include "MatrixLookupWeighted.h"
25#include "yat/utility/Matrix.h"
26#include "yat/utility/MatrixWeighted.h"
27
28#include <cassert>
29
30namespace theplu {
31namespace yat {
32namespace classifier {
33
34  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
35    : kernel_(utility::SmartPtr<const Kernel>(&kernel, own))
36  {
37    column_index_ = utility::Index(kernel.size());
38    row_index_=column_index_;
39  }
40
41
42  KernelLookup::KernelLookup(const Kernel& kernel,
43                             const utility::Index& row, 
44                             const utility::Index& column,
45                             const bool owner)
46    : column_index_(column), 
47      kernel_(utility::SmartPtr<const Kernel>(&kernel, owner)),
48      row_index_(row)
49  {
50    // Checking that each row index is less than kernel.rows()
51    assert(validate(row_index_));
52    // Checking that each column index is less than kernel.column()
53    assert(validate(column_index_));
54  }
55
56
57  KernelLookup::KernelLookup(const KernelLookup& other, 
58                             const utility::Index& row, 
59                             const utility::Index& column)
60    : kernel_(other.kernel_)
61  {
62    row_index_ = utility::Index(other.row_index_, row);
63    column_index_ = utility::Index(other.column_index_, column);
64  }
65 
66
67  KernelLookup::KernelLookup(const KernelLookup& other)
68    : column_index_(other.column_index_), kernel_(other.kernel_), 
69      row_index_(other.row_index_)
70  {
71    // Checking that each row index is less than kernel.rows()
72    assert(validate(row_index_));
73    // Checking that each column index is less than kernel.column()
74    assert(validate(column_index_));
75  }
76 
77
78  KernelLookup::KernelLookup(const KernelLookup& other, 
79                             const utility::Index& index, 
80                             const bool row)
81    : kernel_(other.kernel_)
82  {
83    if (row){
84      row_index_ = utility::Index(other.row_index_, index);
85      column_index_= other.column_index_;
86    }
87    else{
88      column_index_ = utility::Index(other.column_index_, index);
89      row_index_= other.row_index_;
90    }
91    assert(kernel_->size());
92
93    // Checking that each row index is less than kernel.rows()
94    assert(validate(row_index_));
95    // Checking that each column index is less than kernel.column()
96    assert(validate(column_index_));
97  }
98 
99
100  KernelLookup::~KernelLookup(void)
101  {
102  }
103
104
105  KernelLookup::const_iterator KernelLookup::begin(void) const
106  {
107    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
108  }
109
110
111  KernelLookup::const_column_iterator
112  KernelLookup::begin_column(size_t i) const
113  {
114    return const_column_iterator(const_column_iterator::iterator_type(*this,
115                                                                      0,i), 
116                                 columns());
117  }
118
119
120  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
121  {
122    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
123  }
124
125
126  size_t KernelLookup::columns(void) const
127  {
128    return column_index_.size();
129  }
130
131
132  MatrixLookup KernelLookup::data(void) const
133  {
134    assert(!weighted());
135    return MatrixLookup(kernel_->data(), column_index_, false);
136  }
137
138
139  MatrixLookupWeighted KernelLookup::data_weighted(void) const
140  {
141    assert(weighted());
142    using utility::Index; // just to avoid long line
143    return MatrixLookupWeighted(kernel_->data_weighted(), 
144                                Index(kernel_->data_weighted().rows()), 
145                                column_index_);
146  }
147
148
149  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
150  {
151    return kernel_->element(vec, row_index_[i]);
152  }
153
154
155  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
156  {
157    return kernel_->element(vec, row_index_[i]);
158  }
159
160
161  KernelLookup::const_iterator KernelLookup::end(void) const
162  {
163    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
164  }
165
166
167  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
168  {
169    return const_column_iterator(const_column_iterator::iterator_type(*this, 
170                                                                      rows(),i),
171                                 columns());
172  }
173
174
175  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
176  {
177    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
178  }
179
180
181  size_t KernelLookup::rows(void) const
182  {
183    return row_index_.size();
184  }
185
186
187  KernelLookup KernelLookup::selected(const utility::Index& inputs) const
188  {
189    const Kernel* kernel;
190    if (kernel_->weighted()){
191      const MatrixLookupWeighted* ms = 
192        new MatrixLookupWeighted(data_weighted(),inputs, 
193                                 utility::Index(data_weighted().columns()));
194      kernel = kernel_->make_kernel(*ms, true);
195    }
196    else {
197      // matrix with selected features
198      const MatrixLookup* ms = new MatrixLookup(data(),inputs,true);
199      kernel = kernel_->make_kernel(*ms,true);
200    }
201    return KernelLookup(*kernel, true);
202  }
203
204
205  KernelLookup KernelLookup::test_kernel(const MatrixLookup& data) const
206  {
207    if (!weighted()){
208      assert(data.rows()==kernel_->data().rows());
209      utility::Matrix* data_all = 
210        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
211
212      for (size_t i=0; i<data_all->rows(); ++i) {
213
214        // first some columns from data in kernel_
215        for (size_t j=0; j<row_index_.size(); ++j){
216          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
217        }
218       
219        // last columns are equal to new data
220        for (size_t j=0;j<data.columns(); ++j){
221          (*data_all)(i,j+row_index_.size()) = data(i,j);
222        }
223      }
224      std::vector<size_t> column_index;
225      column_index.reserve(data.columns());
226      for (size_t i=0;i<data.columns(); ++i)
227        column_index.push_back(i+row_index_.size());
228
229      std::vector<size_t> row_index;
230      row_index.reserve(row_index_.size());
231      for (size_t i=0;i<row_index_.size(); ++i)
232        row_index.push_back(i);
233
234      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
235
236      const Kernel* kernel = 
237        kernel_->make_kernel(*tmp, true);
238
239      return KernelLookup(*kernel, utility::Index(row_index), 
240                          utility::Index(column_index), true);
241    }
242
243    assert(data.rows()==kernel_->data_weighted().rows());
244    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
245    // should hold a MatrixLookupweighted.
246    utility::MatrixWeighted* x_all = 
247      new utility::MatrixWeighted(data.rows(), rows()+data.columns());
248    const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
249
250    for (size_t i=0; i<data.rows(); ++i){
251
252      // first some columns from data in kernel_
253      for (size_t j=0; j<row_index_.size(); ++j){
254        (*x_all)(i,j) = kernel_data(i,row_index_[j]); 
255      }
256
257      // last columns are equal to new data
258      for (size_t j=0;j<data.columns(); ++j){
259        (*x_all)(i,j+row_index_.size()).data() = data(i,j);
260      }
261    }
262    std::vector<size_t> column_index;
263    column_index.reserve(data.columns());
264    for (size_t i=0;i<data.columns(); ++i)
265      column_index.push_back(i+row_index_.size());
266
267    std::vector<size_t> row_index;
268    row_index.reserve(row_index_.size());
269    for (size_t i=0;i<row_index_.size(); ++i)
270      row_index.push_back(i);
271
272    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*x_all, true);
273    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
274
275
276    return KernelLookup(*kernel, row_index_, 
277                        utility::Index(column_index), true);
278  }
279
280
281
282  KernelLookup KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
283  {
284    utility::MatrixWeighted* x_all = 
285      new utility::MatrixWeighted(data.rows(), rows()+data.columns());
286
287    if (weighted()){
288      const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
289   
290      for (size_t i=0; i<data.rows(); ++i){
291        // first columns are equal to data in kernel_
292        for (size_t j=0; j<row_index_.size(); ++j){
293          (*x_all)(i,j) = kernel_data(i,row_index_[j]);
294        }
295      }
296    }
297    else {
298
299      for (size_t i=0; i<data.rows(); ++i){
300        // first columns are equal to data in kernel_
301        for (size_t j=0; j<row_index_.size(); ++j)
302          (*x_all)(i,j).data() = kernel_->data()(i,row_index_[j]);
303      }
304    }
305
306    // last columns are equal to new data
307    for (size_t i=0; i<data.rows(); ++i){
308      for (size_t j=0;j<data.columns(); ++j){
309        (*x_all)(i,j+row_index_.size()) = data(i,j);
310      }
311    }
312   
313    std::vector<size_t> column_index;
314    column_index.reserve(data.columns());
315    for (size_t i=0;i<data.columns(); ++i)
316      column_index.push_back(i+row_index_.size());
317    const Kernel* kernel = 
318      kernel_->make_kernel(MatrixLookupWeighted(*x_all, true));
319    return KernelLookup(*kernel, row_index_, 
320                        utility::Index(column_index), true);
321  }
322
323
324  /*
325  const KernelLookup*
326  KernelLookup::training_data(const utility::Index& train) const
327  {
328    return new KernelLookup(*this,train,train);
329  }
330  */
331
332
333  bool KernelLookup::validate(const utility::Index& index) const
334  {
335    for (size_t i=0; i<index.size(); ++i)
336      if (index[i]>=kernel_->size())
337        return false;
338    return true;
339  }
340
341
342  bool KernelLookup::weighted(void) const
343  {
344    return kernel_->weighted();
345  }
346
347
348  KernelLookup::const_reference
349  KernelLookup::operator()(size_t row, size_t column) const
350  {
351    return (*kernel_)(row_index_[row],column_index_[column]);
352  }
353
354}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.