source: trunk/yat/classifier/KernelLookup.cc @ 1165

Last change on this file since 1165 was 1165, checked in by Peter, 14 years ago

removed dynamic_casts

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.2 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007 Jari Häkkinen, Peter Johansson
5
6  This file is part of the yat library, http://trac.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "KernelLookup.h"
25#include "DataLookup2D.h"
26#include "MatrixLookup.h"
27#include "MatrixLookupWeighted.h"
28#include "yat/utility/Matrix.h"
29
30#include <cassert>
31#ifndef NDEBUG
32#include <algorithm>
33#endif
34
35namespace theplu {
36namespace yat {
37namespace classifier {
38
39  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
40    : kernel_(utility::SmartPtr<const Kernel>(&kernel, own))
41  {
42    column_index_ = utility::Index(kernel.size());
43    row_index_=column_index_;
44  }
45
46
47  KernelLookup::KernelLookup(const Kernel& kernel,
48                             const utility::Index& row, 
49                             const utility::Index& column,
50                             const bool owner)
51    : column_index_(column), 
52      kernel_(utility::SmartPtr<const Kernel>(&kernel, owner)),
53      row_index_(row)
54  {
55    // Checking that each row index is less than kernel.rows()
56    assert(validate(row_index_));
57    // Checking that each column index is less than kernel.column()
58    assert(validate(column_index_));
59  }
60
61
62  KernelLookup::KernelLookup(const KernelLookup& other, 
63                             const utility::Index& row, 
64                             const utility::Index& column)
65    : kernel_(other.kernel_)
66  {
67    row_index_ = utility::Index(other.row_index_, row);
68    column_index_ = utility::Index(other.column_index_, column);
69  }
70 
71
72  KernelLookup::KernelLookup(const KernelLookup& other)
73    : column_index_(other.column_index_), kernel_(other.kernel_), 
74      row_index_(other.row_index_)
75  {
76    // Checking that each row index is less than kernel.rows()
77    assert(validate(row_index_));
78    // Checking that each column index is less than kernel.column()
79    assert(validate(column_index_));
80  }
81 
82
83  KernelLookup::KernelLookup(const KernelLookup& other, 
84                             const utility::Index& index, 
85                             const bool row)
86    : kernel_(other.kernel_)
87  {
88    if (row){
89      row_index_ = utility::Index(other.row_index_, index);
90      column_index_= other.column_index_;
91    }
92    else{
93      column_index_ = utility::Index(other.column_index_, index);
94      row_index_= other.row_index_;
95    }
96    assert(kernel_->size());
97
98    // Checking that each row index is less than kernel.rows()
99    assert(validate(row_index_));
100    // Checking that each column index is less than kernel.column()
101    assert(validate(column_index_));
102  }
103 
104
105  KernelLookup::~KernelLookup(void)
106  {
107  }
108
109
110  KernelLookup::const_iterator KernelLookup::begin(void) const
111  {
112    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
113  }
114
115
116  KernelLookup::const_column_iterator
117  KernelLookup::begin_column(size_t i) const
118  {
119    return const_column_iterator(const_column_iterator::iterator_type(*this,
120                                                                      0,i), 
121                                 columns());
122  }
123
124
125  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
126  {
127    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
128  }
129
130
131  size_t KernelLookup::columns(void) const
132  {
133    return column_index_.size();
134  }
135
136
137  utility::SmartPtr<const MatrixLookup> KernelLookup::data(void) const
138  {
139    return utility::SmartPtr<const MatrixLookup>
140      (kernel_->data().training_data(column_index_));
141  }
142
143
144  utility::SmartPtr<const MatrixLookupWeighted> 
145  KernelLookup::data_weighted(void) const
146  {
147    return utility::SmartPtr<const MatrixLookupWeighted>
148      (kernel_->data_weighted().training_data(column_index_));
149  }
150
151
152  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
153  {
154    return kernel_->element(vec, row_index_[i]);
155  }
156
157
158  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
159  {
160    return kernel_->element(vec, row_index_[i]);
161  }
162
163
164  KernelLookup::const_iterator KernelLookup::end(void) const
165  {
166    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
167  }
168
169
170  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
171  {
172    return const_column_iterator(const_column_iterator::iterator_type(*this, 
173                                                                      rows(),i),
174                                 columns());
175  }
176
177
178  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
179  {
180    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
181  }
182
183
184  size_t KernelLookup::rows(void) const
185  {
186    return row_index_.size();
187  }
188
189
190  const KernelLookup* 
191  KernelLookup::selected(const utility::Index& inputs) const
192  {
193    const Kernel* kernel;
194    if (kernel_->weighted()){
195      utility::SmartPtr<const MatrixLookupWeighted> ml = data_weighted();
196      const MatrixLookupWeighted* ms = 
197        new MatrixLookupWeighted(*ml,inputs,true);
198      kernel = kernel_->make_kernel(*ms, false);
199    }
200    else {
201      utility::SmartPtr<const MatrixLookup> m = data();
202      // matrix with selected features
203      const MatrixLookup* ms = new MatrixLookup(*m,inputs,true);
204      kernel = kernel_->make_kernel(*ms,true);
205    }
206    return new KernelLookup(*kernel, true);
207  }
208
209
210  const KernelLookup* KernelLookup::test_kernel(const MatrixLookup& data) const
211  {
212
213    if (!weighted()){
214      assert(data.rows()==kernel_->data().rows());
215      utility::Matrix* data_all = 
216        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
217
218      for (size_t i=0; i<data_all->rows(); ++i) {
219
220        // first some columns from data in kernel_
221        for (size_t j=0; j<row_index_.size(); ++j){
222          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
223        }
224       
225        // last columns are equal to new data
226        for (size_t j=0;j<data.columns(); ++j){
227          (*data_all)(i,j+row_index_.size()) = data(i,j);
228        }
229      }
230      std::vector<size_t> column_index;
231      column_index.reserve(data.columns());
232      for (size_t i=0;i<data.columns(); ++i)
233        column_index.push_back(i+row_index_.size());
234
235      std::vector<size_t> row_index;
236      row_index.reserve(row_index_.size());
237      for (size_t i=0;i<row_index_.size(); ++i)
238        row_index.push_back(i);
239
240      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
241
242      const Kernel* kernel = 
243        kernel_->make_kernel(*tmp, true);
244
245      return new KernelLookup(*kernel, utility::Index(row_index), 
246                              utility::Index(column_index), true);
247    }
248
249    assert(data.rows()==kernel_->data_weighted().rows());
250    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
251    // should hold a MatrixLookupweighted.
252    utility::Matrix* data_all = 
253      new utility::Matrix(data.rows(), rows()+data.columns());
254    utility::Matrix* weight_all = 
255      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
256    const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
257
258    for (size_t i=0; i<data.rows(); ++i){
259
260      // first some columns from data in kernel_
261      for (size_t j=0; j<row_index_.size(); ++j){
262        (*data_all)(i,j) = kernel_data.data(i,row_index_[j]); 
263        (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
264      }
265
266      // last columns are equal to new data
267      for (size_t j=0;j<data.columns(); ++j){
268        (*data_all)(i,j+row_index_.size()) = data(i,j);
269      }
270    }
271    std::vector<size_t> column_index;
272    column_index.reserve(data.columns());
273    for (size_t i=0;i<data.columns(); ++i)
274      column_index.push_back(i+row_index_.size());
275
276    std::vector<size_t> row_index;
277    row_index.reserve(row_index_.size());
278    for (size_t i=0;i<row_index_.size(); ++i)
279      row_index.push_back(i);
280
281    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*data_all, 
282                                                         *weight_all, true);
283    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
284    return new KernelLookup(*kernel, row_index_, 
285                            utility::Index(column_index), true);
286  }
287
288
289
290  const KernelLookup* 
291  KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
292  {
293    utility::Matrix* data_all = 
294      new utility::Matrix(data.rows(), rows()+data.columns());
295    utility::Matrix* weight_all = 
296      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
297
298    if (weighted()){
299      const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
300   
301      for (size_t i=0; i<data.rows(); ++i){
302        // first columns are equal to data in kernel_
303        for (size_t j=0; j<row_index_.size(); ++j){
304          (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
305          (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
306        }
307      }
308    }
309    else {
310
311      for (size_t i=0; i<data.rows(); ++i){
312        // first columns are equal to data in kernel_
313        for (size_t j=0; j<row_index_.size(); ++j)
314          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
315      }
316    }
317
318    // last columns are equal to new data
319    for (size_t i=0; i<data.rows(); ++i){
320      for (size_t j=0;j<data.columns(); ++j){
321        (*data_all)(i,j+row_index_.size()) = data.data(i,j);
322        (*weight_all)(i,j+row_index_.size()) = data.weight(i,j);
323      }
324    }
325   
326    std::vector<size_t> column_index;
327    column_index.reserve(data.columns());
328    for (size_t i=0;i<data.columns(); ++i)
329      column_index.push_back(i+row_index_.size());
330    const Kernel* kernel = 
331      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
332    return new KernelLookup(*kernel, row_index_, 
333                            utility::Index(column_index), true);
334  }
335
336
337  const KernelLookup*
338  KernelLookup::training_data(const utility::Index& train) const
339  {
340    return new KernelLookup(*this,train,train);
341  }
342
343
344  bool KernelLookup::validate(const utility::Index& index) const
345  {
346    for (size_t i=0; i<index.size(); ++i)
347      if (index[i]>=kernel_->size())
348        return false;
349    return true;
350  }
351
352
353  const KernelLookup*
354  KernelLookup::validation_data(const utility::Index& train,
355                                const utility::Index& validation) const
356  {
357    return new KernelLookup(*this,train,validation);
358  }
359
360
361  bool KernelLookup::weighted(void) const
362  {
363    return kernel_->weighted();
364  }
365
366
367  double KernelLookup::operator()(size_t row, size_t column) const
368  {
369    return (*kernel_)(row_index_[row],column_index_[column]);
370  }
371
372}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.