source: trunk/yat/classifier/KernelLookup.cc @ 1168

Last change on this file since 1168 was 1168, checked in by Peter, 13 years ago

fixes #342

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.0 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007 Jari Häkkinen, Peter Johansson
5
6  This file is part of the yat library, http://trac.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "KernelLookup.h"
25#include "MatrixLookup.h"
26#include "MatrixLookupWeighted.h"
27#include "yat/utility/Matrix.h"
28
29#include <cassert>
30#ifndef NDEBUG
31#include <algorithm>
32#endif
33
34namespace theplu {
35namespace yat {
36namespace classifier {
37
38  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
39    : kernel_(utility::SmartPtr<const Kernel>(&kernel, own))
40  {
41    column_index_ = utility::Index(kernel.size());
42    row_index_=column_index_;
43  }
44
45
46  KernelLookup::KernelLookup(const Kernel& kernel,
47                             const utility::Index& row, 
48                             const utility::Index& column,
49                             const bool owner)
50    : column_index_(column), 
51      kernel_(utility::SmartPtr<const Kernel>(&kernel, owner)),
52      row_index_(row)
53  {
54    // Checking that each row index is less than kernel.rows()
55    assert(validate(row_index_));
56    // Checking that each column index is less than kernel.column()
57    assert(validate(column_index_));
58  }
59
60
61  KernelLookup::KernelLookup(const KernelLookup& other, 
62                             const utility::Index& row, 
63                             const utility::Index& column)
64    : kernel_(other.kernel_)
65  {
66    row_index_ = utility::Index(other.row_index_, row);
67    column_index_ = utility::Index(other.column_index_, column);
68  }
69 
70
71  KernelLookup::KernelLookup(const KernelLookup& other)
72    : column_index_(other.column_index_), kernel_(other.kernel_), 
73      row_index_(other.row_index_)
74  {
75    // Checking that each row index is less than kernel.rows()
76    assert(validate(row_index_));
77    // Checking that each column index is less than kernel.column()
78    assert(validate(column_index_));
79  }
80 
81
82  KernelLookup::KernelLookup(const KernelLookup& other, 
83                             const utility::Index& index, 
84                             const bool row)
85    : kernel_(other.kernel_)
86  {
87    if (row){
88      row_index_ = utility::Index(other.row_index_, index);
89      column_index_= other.column_index_;
90    }
91    else{
92      column_index_ = utility::Index(other.column_index_, index);
93      row_index_= other.row_index_;
94    }
95    assert(kernel_->size());
96
97    // Checking that each row index is less than kernel.rows()
98    assert(validate(row_index_));
99    // Checking that each column index is less than kernel.column()
100    assert(validate(column_index_));
101  }
102 
103
104  KernelLookup::~KernelLookup(void)
105  {
106  }
107
108
109  KernelLookup::const_iterator KernelLookup::begin(void) const
110  {
111    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
112  }
113
114
115  KernelLookup::const_column_iterator
116  KernelLookup::begin_column(size_t i) const
117  {
118    return const_column_iterator(const_column_iterator::iterator_type(*this,
119                                                                      0,i), 
120                                 columns());
121  }
122
123
124  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
125  {
126    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
127  }
128
129
130  size_t KernelLookup::columns(void) const
131  {
132    return column_index_.size();
133  }
134
135
136  utility::SmartPtr<const MatrixLookup> KernelLookup::data(void) const
137  {
138    return utility::SmartPtr<const MatrixLookup>
139      (new MatrixLookup(kernel_->data(), column_index_, false));
140  }
141
142
143  utility::SmartPtr<const MatrixLookupWeighted> 
144  KernelLookup::data_weighted(void) const
145  {
146    return utility::SmartPtr<const MatrixLookupWeighted>
147      (new MatrixLookupWeighted(kernel_->data_weighted(),column_index_,false));
148  }
149
150
151  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
152  {
153    return kernel_->element(vec, row_index_[i]);
154  }
155
156
157  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
158  {
159    return kernel_->element(vec, row_index_[i]);
160  }
161
162
163  KernelLookup::const_iterator KernelLookup::end(void) const
164  {
165    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
166  }
167
168
169  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
170  {
171    return const_column_iterator(const_column_iterator::iterator_type(*this, 
172                                                                      rows(),i),
173                                 columns());
174  }
175
176
177  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
178  {
179    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
180  }
181
182
183  size_t KernelLookup::rows(void) const
184  {
185    return row_index_.size();
186  }
187
188
189  const KernelLookup* 
190  KernelLookup::selected(const utility::Index& inputs) const
191  {
192    const Kernel* kernel;
193    if (kernel_->weighted()){
194      utility::SmartPtr<const MatrixLookupWeighted> ml = data_weighted();
195      const MatrixLookupWeighted* ms = 
196        new MatrixLookupWeighted(*ml,inputs,true);
197      kernel = kernel_->make_kernel(*ms, false);
198    }
199    else {
200      utility::SmartPtr<const MatrixLookup> m = data();
201      // matrix with selected features
202      const MatrixLookup* ms = new MatrixLookup(*m,inputs,true);
203      kernel = kernel_->make_kernel(*ms,true);
204    }
205    return new KernelLookup(*kernel, true);
206  }
207
208
209  const KernelLookup* KernelLookup::test_kernel(const MatrixLookup& data) const
210  {
211
212    if (!weighted()){
213      assert(data.rows()==kernel_->data().rows());
214      utility::Matrix* data_all = 
215        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
216
217      for (size_t i=0; i<data_all->rows(); ++i) {
218
219        // first some columns from data in kernel_
220        for (size_t j=0; j<row_index_.size(); ++j){
221          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
222        }
223       
224        // last columns are equal to new data
225        for (size_t j=0;j<data.columns(); ++j){
226          (*data_all)(i,j+row_index_.size()) = data(i,j);
227        }
228      }
229      std::vector<size_t> column_index;
230      column_index.reserve(data.columns());
231      for (size_t i=0;i<data.columns(); ++i)
232        column_index.push_back(i+row_index_.size());
233
234      std::vector<size_t> row_index;
235      row_index.reserve(row_index_.size());
236      for (size_t i=0;i<row_index_.size(); ++i)
237        row_index.push_back(i);
238
239      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
240
241      const Kernel* kernel = 
242        kernel_->make_kernel(*tmp, true);
243
244      return new KernelLookup(*kernel, utility::Index(row_index), 
245                              utility::Index(column_index), true);
246    }
247
248    assert(data.rows()==kernel_->data_weighted().rows());
249    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
250    // should hold a MatrixLookupweighted.
251    utility::Matrix* data_all = 
252      new utility::Matrix(data.rows(), rows()+data.columns());
253    utility::Matrix* weight_all = 
254      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
255    const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
256
257    for (size_t i=0; i<data.rows(); ++i){
258
259      // first some columns from data in kernel_
260      for (size_t j=0; j<row_index_.size(); ++j){
261        (*data_all)(i,j) = kernel_data.data(i,row_index_[j]); 
262        (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
263      }
264
265      // last columns are equal to new data
266      for (size_t j=0;j<data.columns(); ++j){
267        (*data_all)(i,j+row_index_.size()) = data(i,j);
268      }
269    }
270    std::vector<size_t> column_index;
271    column_index.reserve(data.columns());
272    for (size_t i=0;i<data.columns(); ++i)
273      column_index.push_back(i+row_index_.size());
274
275    std::vector<size_t> row_index;
276    row_index.reserve(row_index_.size());
277    for (size_t i=0;i<row_index_.size(); ++i)
278      row_index.push_back(i);
279
280    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*data_all, 
281                                                         *weight_all, true);
282    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
283    return new KernelLookup(*kernel, row_index_, 
284                            utility::Index(column_index), true);
285  }
286
287
288
289  const KernelLookup* 
290  KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
291  {
292    utility::Matrix* data_all = 
293      new utility::Matrix(data.rows(), rows()+data.columns());
294    utility::Matrix* weight_all = 
295      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
296
297    if (weighted()){
298      const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
299   
300      for (size_t i=0; i<data.rows(); ++i){
301        // first columns are equal to data in kernel_
302        for (size_t j=0; j<row_index_.size(); ++j){
303          (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
304          (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
305        }
306      }
307    }
308    else {
309
310      for (size_t i=0; i<data.rows(); ++i){
311        // first columns are equal to data in kernel_
312        for (size_t j=0; j<row_index_.size(); ++j)
313          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
314      }
315    }
316
317    // last columns are equal to new data
318    for (size_t i=0; i<data.rows(); ++i){
319      for (size_t j=0;j<data.columns(); ++j){
320        (*data_all)(i,j+row_index_.size()) = data.data(i,j);
321        (*weight_all)(i,j+row_index_.size()) = data.weight(i,j);
322      }
323    }
324   
325    std::vector<size_t> column_index;
326    column_index.reserve(data.columns());
327    for (size_t i=0;i<data.columns(); ++i)
328      column_index.push_back(i+row_index_.size());
329    const Kernel* kernel = 
330      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
331    return new KernelLookup(*kernel, row_index_, 
332                            utility::Index(column_index), true);
333  }
334
335
336  /*
337  const KernelLookup*
338  KernelLookup::training_data(const utility::Index& train) const
339  {
340    return new KernelLookup(*this,train,train);
341  }
342  */
343
344
345  bool KernelLookup::validate(const utility::Index& index) const
346  {
347    for (size_t i=0; i<index.size(); ++i)
348      if (index[i]>=kernel_->size())
349        return false;
350    return true;
351  }
352
353
354  bool KernelLookup::weighted(void) const
355  {
356    return kernel_->weighted();
357  }
358
359
360  double KernelLookup::operator()(size_t row, size_t column) const
361  {
362    return (*kernel_)(row_index_[row],column_index_[column]);
363  }
364
365}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.