source: trunk/yat/classifier/KernelLookup.cc @ 1580

Last change on this file since 1580 was 1549, checked in by Peter, 13 years ago

adding some typedefs in Container2Ds - refs #448

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.7 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2008 Peter Johansson
6
7  This file is part of the yat library, http://dev.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 3 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with yat. If not, see <http://www.gnu.org/licenses/>.
21*/
22
23#include "KernelLookup.h"
24#include "MatrixLookup.h"
25#include "MatrixLookupWeighted.h"
26#include "yat/utility/Matrix.h"
27
28#include <cassert>
29
30namespace theplu {
31namespace yat {
32namespace classifier {
33
34  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
35    : kernel_(utility::SmartPtr<const Kernel>(&kernel, own))
36  {
37    column_index_ = utility::Index(kernel.size());
38    row_index_=column_index_;
39  }
40
41
42  KernelLookup::KernelLookup(const Kernel& kernel,
43                             const utility::Index& row, 
44                             const utility::Index& column,
45                             const bool owner)
46    : column_index_(column), 
47      kernel_(utility::SmartPtr<const Kernel>(&kernel, owner)),
48      row_index_(row)
49  {
50    // Checking that each row index is less than kernel.rows()
51    assert(validate(row_index_));
52    // Checking that each column index is less than kernel.column()
53    assert(validate(column_index_));
54  }
55
56
57  KernelLookup::KernelLookup(const KernelLookup& other, 
58                             const utility::Index& row, 
59                             const utility::Index& column)
60    : kernel_(other.kernel_)
61  {
62    row_index_ = utility::Index(other.row_index_, row);
63    column_index_ = utility::Index(other.column_index_, column);
64  }
65 
66
67  KernelLookup::KernelLookup(const KernelLookup& other)
68    : column_index_(other.column_index_), kernel_(other.kernel_), 
69      row_index_(other.row_index_)
70  {
71    // Checking that each row index is less than kernel.rows()
72    assert(validate(row_index_));
73    // Checking that each column index is less than kernel.column()
74    assert(validate(column_index_));
75  }
76 
77
78  KernelLookup::KernelLookup(const KernelLookup& other, 
79                             const utility::Index& index, 
80                             const bool row)
81    : kernel_(other.kernel_)
82  {
83    if (row){
84      row_index_ = utility::Index(other.row_index_, index);
85      column_index_= other.column_index_;
86    }
87    else{
88      column_index_ = utility::Index(other.column_index_, index);
89      row_index_= other.row_index_;
90    }
91    assert(kernel_->size());
92
93    // Checking that each row index is less than kernel.rows()
94    assert(validate(row_index_));
95    // Checking that each column index is less than kernel.column()
96    assert(validate(column_index_));
97  }
98 
99
100  KernelLookup::~KernelLookup(void)
101  {
102  }
103
104
105  KernelLookup::const_iterator KernelLookup::begin(void) const
106  {
107    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
108  }
109
110
111  KernelLookup::const_column_iterator
112  KernelLookup::begin_column(size_t i) const
113  {
114    return const_column_iterator(const_column_iterator::iterator_type(*this,
115                                                                      0,i), 
116                                 columns());
117  }
118
119
120  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
121  {
122    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
123  }
124
125
126  size_t KernelLookup::columns(void) const
127  {
128    return column_index_.size();
129  }
130
131
132  MatrixLookup KernelLookup::data(void) const
133  {
134    assert(!weighted());
135    return MatrixLookup(kernel_->data(), column_index_, false);
136  }
137
138
139  MatrixLookupWeighted KernelLookup::data_weighted(void) const
140  {
141    assert(weighted());
142    return MatrixLookupWeighted(kernel_->data_weighted(),column_index_,false);
143  }
144
145
146  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
147  {
148    return kernel_->element(vec, row_index_[i]);
149  }
150
151
152  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
153  {
154    return kernel_->element(vec, row_index_[i]);
155  }
156
157
158  KernelLookup::const_iterator KernelLookup::end(void) const
159  {
160    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
161  }
162
163
164  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
165  {
166    return const_column_iterator(const_column_iterator::iterator_type(*this, 
167                                                                      rows(),i),
168                                 columns());
169  }
170
171
172  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
173  {
174    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
175  }
176
177
178  size_t KernelLookup::rows(void) const
179  {
180    return row_index_.size();
181  }
182
183
184  KernelLookup KernelLookup::selected(const utility::Index& inputs) const
185  {
186    const Kernel* kernel;
187    if (kernel_->weighted()){
188      const MatrixLookupWeighted* ms = 
189        new MatrixLookupWeighted(data_weighted(),inputs,true);
190      kernel = kernel_->make_kernel(*ms, true);
191    }
192    else {
193      // matrix with selected features
194      const MatrixLookup* ms = new MatrixLookup(data(),inputs,true);
195      kernel = kernel_->make_kernel(*ms,true);
196    }
197    return KernelLookup(*kernel, true);
198  }
199
200
201  KernelLookup KernelLookup::test_kernel(const MatrixLookup& data) const
202  {
203    if (!weighted()){
204      assert(data.rows()==kernel_->data().rows());
205      utility::Matrix* data_all = 
206        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
207
208      for (size_t i=0; i<data_all->rows(); ++i) {
209
210        // first some columns from data in kernel_
211        for (size_t j=0; j<row_index_.size(); ++j){
212          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
213        }
214       
215        // last columns are equal to new data
216        for (size_t j=0;j<data.columns(); ++j){
217          (*data_all)(i,j+row_index_.size()) = data(i,j);
218        }
219      }
220      std::vector<size_t> column_index;
221      column_index.reserve(data.columns());
222      for (size_t i=0;i<data.columns(); ++i)
223        column_index.push_back(i+row_index_.size());
224
225      std::vector<size_t> row_index;
226      row_index.reserve(row_index_.size());
227      for (size_t i=0;i<row_index_.size(); ++i)
228        row_index.push_back(i);
229
230      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
231
232      const Kernel* kernel = 
233        kernel_->make_kernel(*tmp, true);
234
235      return KernelLookup(*kernel, utility::Index(row_index), 
236                          utility::Index(column_index), true);
237    }
238
239    assert(data.rows()==kernel_->data_weighted().rows());
240    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
241    // should hold a MatrixLookupweighted.
242    utility::Matrix* data_all = 
243      new utility::Matrix(data.rows(), rows()+data.columns());
244    utility::Matrix* weight_all = 
245      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
246    const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
247
248    for (size_t i=0; i<data.rows(); ++i){
249
250      // first some columns from data in kernel_
251      for (size_t j=0; j<row_index_.size(); ++j){
252        (*data_all)(i,j) = kernel_data.data(i,row_index_[j]); 
253        (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
254      }
255
256      // last columns are equal to new data
257      for (size_t j=0;j<data.columns(); ++j){
258        (*data_all)(i,j+row_index_.size()) = data(i,j);
259      }
260    }
261    std::vector<size_t> column_index;
262    column_index.reserve(data.columns());
263    for (size_t i=0;i<data.columns(); ++i)
264      column_index.push_back(i+row_index_.size());
265
266    std::vector<size_t> row_index;
267    row_index.reserve(row_index_.size());
268    for (size_t i=0;i<row_index_.size(); ++i)
269      row_index.push_back(i);
270
271    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*data_all, 
272                                                         *weight_all, true);
273    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
274
275
276    return KernelLookup(*kernel, row_index_, 
277                        utility::Index(column_index), true);
278  }
279
280
281
282  KernelLookup KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
283  {
284    utility::Matrix* data_all = 
285      new utility::Matrix(data.rows(), rows()+data.columns());
286    utility::Matrix* weight_all = 
287      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
288
289    if (weighted()){
290      const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
291   
292      for (size_t i=0; i<data.rows(); ++i){
293        // first columns are equal to data in kernel_
294        for (size_t j=0; j<row_index_.size(); ++j){
295          (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
296          (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
297        }
298      }
299    }
300    else {
301
302      for (size_t i=0; i<data.rows(); ++i){
303        // first columns are equal to data in kernel_
304        for (size_t j=0; j<row_index_.size(); ++j)
305          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
306      }
307    }
308
309    // last columns are equal to new data
310    for (size_t i=0; i<data.rows(); ++i){
311      for (size_t j=0;j<data.columns(); ++j){
312        (*data_all)(i,j+row_index_.size()) = data.data(i,j);
313        (*weight_all)(i,j+row_index_.size()) = data.weight(i,j);
314      }
315    }
316   
317    std::vector<size_t> column_index;
318    column_index.reserve(data.columns());
319    for (size_t i=0;i<data.columns(); ++i)
320      column_index.push_back(i+row_index_.size());
321    const Kernel* kernel = 
322      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
323    return KernelLookup(*kernel, row_index_, 
324                        utility::Index(column_index), true);
325  }
326
327
328  /*
329  const KernelLookup*
330  KernelLookup::training_data(const utility::Index& train) const
331  {
332    return new KernelLookup(*this,train,train);
333  }
334  */
335
336
337  bool KernelLookup::validate(const utility::Index& index) const
338  {
339    for (size_t i=0; i<index.size(); ++i)
340      if (index[i]>=kernel_->size())
341        return false;
342    return true;
343  }
344
345
346  bool KernelLookup::weighted(void) const
347  {
348    return kernel_->weighted();
349  }
350
351
352  KernelLookup::const_reference
353  KernelLookup::operator()(size_t row, size_t column) const
354  {
355    return (*kernel_)(row_index_[row],column_index_[column]);
356  }
357
358}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.