source: trunk/yat/classifier/KernelLookup.cc @ 1581

Last change on this file since 1581 was 1581, checked in by Peter, 13 years ago

refs #396 - fixing in KernelLookup?

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.3 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2008 Peter Johansson
6
7  This file is part of the yat library, http://dev.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 3 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with yat. If not, see <http://www.gnu.org/licenses/>.
21*/
22
23#include "KernelLookup.h"
24#include "MatrixLookup.h"
25#include "MatrixLookupWeighted.h"
26#include "yat/utility/Matrix.h"
27#include "yat/utility/MatrixWeighted.h"
28
29#include <cassert>
30
31namespace theplu {
32namespace yat {
33namespace classifier {
34
35  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
36    : kernel_(utility::SmartPtr<const Kernel>(&kernel, own))
37  {
38    column_index_ = utility::Index(kernel.size());
39    row_index_=column_index_;
40  }
41
42
43  KernelLookup::KernelLookup(const Kernel& kernel,
44                             const utility::Index& row, 
45                             const utility::Index& column,
46                             const bool owner)
47    : column_index_(column), 
48      kernel_(utility::SmartPtr<const Kernel>(&kernel, owner)),
49      row_index_(row)
50  {
51    // Checking that each row index is less than kernel.rows()
52    assert(validate(row_index_));
53    // Checking that each column index is less than kernel.column()
54    assert(validate(column_index_));
55  }
56
57
58  KernelLookup::KernelLookup(const KernelLookup& other, 
59                             const utility::Index& row, 
60                             const utility::Index& column)
61    : kernel_(other.kernel_)
62  {
63    row_index_ = utility::Index(other.row_index_, row);
64    column_index_ = utility::Index(other.column_index_, column);
65  }
66 
67
68  KernelLookup::KernelLookup(const KernelLookup& other)
69    : column_index_(other.column_index_), kernel_(other.kernel_), 
70      row_index_(other.row_index_)
71  {
72    // Checking that each row index is less than kernel.rows()
73    assert(validate(row_index_));
74    // Checking that each column index is less than kernel.column()
75    assert(validate(column_index_));
76  }
77 
78
79  KernelLookup::KernelLookup(const KernelLookup& other, 
80                             const utility::Index& index, 
81                             const bool row)
82    : kernel_(other.kernel_)
83  {
84    if (row){
85      row_index_ = utility::Index(other.row_index_, index);
86      column_index_= other.column_index_;
87    }
88    else{
89      column_index_ = utility::Index(other.column_index_, index);
90      row_index_= other.row_index_;
91    }
92    assert(kernel_->size());
93
94    // Checking that each row index is less than kernel.rows()
95    assert(validate(row_index_));
96    // Checking that each column index is less than kernel.column()
97    assert(validate(column_index_));
98  }
99 
100
101  KernelLookup::~KernelLookup(void)
102  {
103  }
104
105
106  KernelLookup::const_iterator KernelLookup::begin(void) const
107  {
108    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
109  }
110
111
112  KernelLookup::const_column_iterator
113  KernelLookup::begin_column(size_t i) const
114  {
115    return const_column_iterator(const_column_iterator::iterator_type(*this,
116                                                                      0,i), 
117                                 columns());
118  }
119
120
121  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
122  {
123    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
124  }
125
126
127  size_t KernelLookup::columns(void) const
128  {
129    return column_index_.size();
130  }
131
132
133  MatrixLookup KernelLookup::data(void) const
134  {
135    assert(!weighted());
136    return MatrixLookup(kernel_->data(), column_index_, false);
137  }
138
139
140  MatrixLookupWeighted KernelLookup::data_weighted(void) const
141  {
142    assert(weighted());
143    return MatrixLookupWeighted(kernel_->data_weighted(),column_index_,false);
144  }
145
146
147  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
148  {
149    return kernel_->element(vec, row_index_[i]);
150  }
151
152
153  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
154  {
155    return kernel_->element(vec, row_index_[i]);
156  }
157
158
159  KernelLookup::const_iterator KernelLookup::end(void) const
160  {
161    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
162  }
163
164
165  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
166  {
167    return const_column_iterator(const_column_iterator::iterator_type(*this, 
168                                                                      rows(),i),
169                                 columns());
170  }
171
172
173  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
174  {
175    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
176  }
177
178
179  size_t KernelLookup::rows(void) const
180  {
181    return row_index_.size();
182  }
183
184
185  KernelLookup KernelLookup::selected(const utility::Index& inputs) const
186  {
187    const Kernel* kernel;
188    if (kernel_->weighted()){
189      const MatrixLookupWeighted* ms = 
190        new MatrixLookupWeighted(data_weighted(),inputs,true);
191      kernel = kernel_->make_kernel(*ms, true);
192    }
193    else {
194      // matrix with selected features
195      const MatrixLookup* ms = new MatrixLookup(data(),inputs,true);
196      kernel = kernel_->make_kernel(*ms,true);
197    }
198    return KernelLookup(*kernel, true);
199  }
200
201
202  KernelLookup KernelLookup::test_kernel(const MatrixLookup& data) const
203  {
204    if (!weighted()){
205      assert(data.rows()==kernel_->data().rows());
206      utility::Matrix* data_all = 
207        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
208
209      for (size_t i=0; i<data_all->rows(); ++i) {
210
211        // first some columns from data in kernel_
212        for (size_t j=0; j<row_index_.size(); ++j){
213          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
214        }
215       
216        // last columns are equal to new data
217        for (size_t j=0;j<data.columns(); ++j){
218          (*data_all)(i,j+row_index_.size()) = data(i,j);
219        }
220      }
221      std::vector<size_t> column_index;
222      column_index.reserve(data.columns());
223      for (size_t i=0;i<data.columns(); ++i)
224        column_index.push_back(i+row_index_.size());
225
226      std::vector<size_t> row_index;
227      row_index.reserve(row_index_.size());
228      for (size_t i=0;i<row_index_.size(); ++i)
229        row_index.push_back(i);
230
231      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
232
233      const Kernel* kernel = 
234        kernel_->make_kernel(*tmp, true);
235
236      return KernelLookup(*kernel, utility::Index(row_index), 
237                          utility::Index(column_index), true);
238    }
239
240    assert(data.rows()==kernel_->data_weighted().rows());
241    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
242    // should hold a MatrixLookupweighted.
243    utility::MatrixWeighted* x_all = 
244      new utility::MatrixWeighted(data.rows(), rows()+data.columns());
245    const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
246
247    for (size_t i=0; i<data.rows(); ++i){
248
249      // first some columns from data in kernel_
250      for (size_t j=0; j<row_index_.size(); ++j){
251        (*x_all)(i,j) = kernel_data(i,row_index_[j]); 
252      }
253
254      // last columns are equal to new data
255      for (size_t j=0;j<data.columns(); ++j){
256        (*x_all)(i,j+row_index_.size()).data() = data(i,j);
257      }
258    }
259    std::vector<size_t> column_index;
260    column_index.reserve(data.columns());
261    for (size_t i=0;i<data.columns(); ++i)
262      column_index.push_back(i+row_index_.size());
263
264    std::vector<size_t> row_index;
265    row_index.reserve(row_index_.size());
266    for (size_t i=0;i<row_index_.size(); ++i)
267      row_index.push_back(i);
268
269    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*x_all, true);
270    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
271
272
273    return KernelLookup(*kernel, row_index_, 
274                        utility::Index(column_index), true);
275  }
276
277
278
279  KernelLookup KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
280  {
281    utility::MatrixWeighted* x_all = 
282      new utility::MatrixWeighted(data.rows(), rows()+data.columns());
283
284    if (weighted()){
285      const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
286   
287      for (size_t i=0; i<data.rows(); ++i){
288        // first columns are equal to data in kernel_
289        for (size_t j=0; j<row_index_.size(); ++j){
290          (*x_all)(i,j) = kernel_data(i,row_index_[j]);
291        }
292      }
293    }
294    else {
295
296      for (size_t i=0; i<data.rows(); ++i){
297        // first columns are equal to data in kernel_
298        for (size_t j=0; j<row_index_.size(); ++j)
299          (*x_all)(i,j).data() = kernel_->data()(i,row_index_[j]);
300      }
301    }
302
303    // last columns are equal to new data
304    for (size_t i=0; i<data.rows(); ++i){
305      for (size_t j=0;j<data.columns(); ++j){
306        (*x_all)(i,j+row_index_.size()) = data(i,j);
307      }
308    }
309   
310    std::vector<size_t> column_index;
311    column_index.reserve(data.columns());
312    for (size_t i=0;i<data.columns(); ++i)
313      column_index.push_back(i+row_index_.size());
314    const Kernel* kernel = 
315      kernel_->make_kernel(MatrixLookupWeighted(*x_all, true));
316    return KernelLookup(*kernel, row_index_, 
317                        utility::Index(column_index), true);
318  }
319
320
321  /*
322  const KernelLookup*
323  KernelLookup::training_data(const utility::Index& train) const
324  {
325    return new KernelLookup(*this,train,train);
326  }
327  */
328
329
330  bool KernelLookup::validate(const utility::Index& index) const
331  {
332    for (size_t i=0; i<index.size(); ++i)
333      if (index[i]>=kernel_->size())
334        return false;
335    return true;
336  }
337
338
339  bool KernelLookup::weighted(void) const
340  {
341    return kernel_->weighted();
342  }
343
344
345  KernelLookup::const_reference
346  KernelLookup::operator()(size_t row, size_t column) const
347  {
348    return (*kernel_)(row_index_[row],column_index_[column]);
349  }
350
351}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.