source: trunk/yat/classifier/KernelLookup.cc @ 1201

Last change on this file since 1201 was 1201, checked in by Peter, 14 years ago

returning SmartPtr? rather than conventional pointer when object is dynamically allocated - also fixed bug in SubsetGenerator? for Kernel with feature selection

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.2 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007 Jari Häkkinen, Peter Johansson
5
6  This file is part of the yat library, http://trac.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "KernelLookup.h"
25#include "MatrixLookup.h"
26#include "MatrixLookupWeighted.h"
27#include "yat/utility/Matrix.h"
28
29#include <cassert>
30
31namespace theplu {
32namespace yat {
33namespace classifier {
34
35  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
36    : kernel_(utility::SmartPtr<const Kernel>(&kernel, own))
37  {
38    column_index_ = utility::Index(kernel.size());
39    row_index_=column_index_;
40  }
41
42
43  KernelLookup::KernelLookup(const Kernel& kernel,
44                             const utility::Index& row, 
45                             const utility::Index& column,
46                             const bool owner)
47    : column_index_(column), 
48      kernel_(utility::SmartPtr<const Kernel>(&kernel, owner)),
49      row_index_(row)
50  {
51    // Checking that each row index is less than kernel.rows()
52    assert(validate(row_index_));
53    // Checking that each column index is less than kernel.column()
54    assert(validate(column_index_));
55  }
56
57
58  KernelLookup::KernelLookup(const KernelLookup& other, 
59                             const utility::Index& row, 
60                             const utility::Index& column)
61    : kernel_(other.kernel_)
62  {
63    row_index_ = utility::Index(other.row_index_, row);
64    column_index_ = utility::Index(other.column_index_, column);
65  }
66 
67
68  KernelLookup::KernelLookup(const KernelLookup& other)
69    : column_index_(other.column_index_), kernel_(other.kernel_), 
70      row_index_(other.row_index_)
71  {
72    // Checking that each row index is less than kernel.rows()
73    assert(validate(row_index_));
74    // Checking that each column index is less than kernel.column()
75    assert(validate(column_index_));
76  }
77 
78
79  KernelLookup::KernelLookup(const KernelLookup& other, 
80                             const utility::Index& index, 
81                             const bool row)
82    : kernel_(other.kernel_)
83  {
84    if (row){
85      row_index_ = utility::Index(other.row_index_, index);
86      column_index_= other.column_index_;
87    }
88    else{
89      column_index_ = utility::Index(other.column_index_, index);
90      row_index_= other.row_index_;
91    }
92    assert(kernel_->size());
93
94    // Checking that each row index is less than kernel.rows()
95    assert(validate(row_index_));
96    // Checking that each column index is less than kernel.column()
97    assert(validate(column_index_));
98  }
99 
100
101  KernelLookup::~KernelLookup(void)
102  {
103  }
104
105
106  KernelLookup::const_iterator KernelLookup::begin(void) const
107  {
108    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
109  }
110
111
112  KernelLookup::const_column_iterator
113  KernelLookup::begin_column(size_t i) const
114  {
115    return const_column_iterator(const_column_iterator::iterator_type(*this,
116                                                                      0,i), 
117                                 columns());
118  }
119
120
121  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
122  {
123    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
124  }
125
126
127  size_t KernelLookup::columns(void) const
128  {
129    return column_index_.size();
130  }
131
132
133  utility::SmartPtr<const MatrixLookup> KernelLookup::data(void) const
134  {
135    return utility::SmartPtr<const MatrixLookup>
136      (new MatrixLookup(kernel_->data(), column_index_, false));
137  }
138
139
140  utility::SmartPtr<const MatrixLookupWeighted> 
141  KernelLookup::data_weighted(void) const
142  {
143    return utility::SmartPtr<const MatrixLookupWeighted>
144      (new MatrixLookupWeighted(kernel_->data_weighted(),column_index_,false));
145  }
146
147
148  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
149  {
150    return kernel_->element(vec, row_index_[i]);
151  }
152
153
154  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
155  {
156    return kernel_->element(vec, row_index_[i]);
157  }
158
159
160  KernelLookup::const_iterator KernelLookup::end(void) const
161  {
162    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
163  }
164
165
166  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
167  {
168    return const_column_iterator(const_column_iterator::iterator_type(*this, 
169                                                                      rows(),i),
170                                 columns());
171  }
172
173
174  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
175  {
176    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
177  }
178
179
180  size_t KernelLookup::rows(void) const
181  {
182    return row_index_.size();
183  }
184
185
186  utility::SmartPtr<const KernelLookup> 
187  KernelLookup::selected(const utility::Index& inputs) const
188  {
189    const Kernel* kernel;
190    if (kernel_->weighted()){
191      utility::SmartPtr<const MatrixLookupWeighted> ml = data_weighted();
192      const MatrixLookupWeighted* ms = 
193        new MatrixLookupWeighted(*ml,inputs,true);
194      kernel = kernel_->make_kernel(*ms, false);
195    }
196    else {
197      utility::SmartPtr<const MatrixLookup> m = data();
198      // matrix with selected features
199      const MatrixLookup* ms = new MatrixLookup(*m,inputs,true);
200      kernel = kernel_->make_kernel(*ms,true);
201    }
202    return utility::SmartPtr<const KernelLookup>
203      (new KernelLookup(*kernel, true));
204  }
205
206
207  utility::SmartPtr<const KernelLookup> 
208  KernelLookup::test_kernel(const MatrixLookup& data) const
209  {
210    if (!weighted()){
211      assert(data.rows()==kernel_->data().rows());
212      utility::Matrix* data_all = 
213        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
214
215      for (size_t i=0; i<data_all->rows(); ++i) {
216
217        // first some columns from data in kernel_
218        for (size_t j=0; j<row_index_.size(); ++j){
219          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
220        }
221       
222        // last columns are equal to new data
223        for (size_t j=0;j<data.columns(); ++j){
224          (*data_all)(i,j+row_index_.size()) = data(i,j);
225        }
226      }
227      std::vector<size_t> column_index;
228      column_index.reserve(data.columns());
229      for (size_t i=0;i<data.columns(); ++i)
230        column_index.push_back(i+row_index_.size());
231
232      std::vector<size_t> row_index;
233      row_index.reserve(row_index_.size());
234      for (size_t i=0;i<row_index_.size(); ++i)
235        row_index.push_back(i);
236
237      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
238
239      const Kernel* kernel = 
240        kernel_->make_kernel(*tmp, true);
241
242      return utility::SmartPtr<const KernelLookup>
243        (new KernelLookup(*kernel, utility::Index(row_index), 
244                          utility::Index(column_index), true));
245    }
246
247    assert(data.rows()==kernel_->data_weighted().rows());
248    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
249    // should hold a MatrixLookupweighted.
250    utility::Matrix* data_all = 
251      new utility::Matrix(data.rows(), rows()+data.columns());
252    utility::Matrix* weight_all = 
253      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
254    const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
255
256    for (size_t i=0; i<data.rows(); ++i){
257
258      // first some columns from data in kernel_
259      for (size_t j=0; j<row_index_.size(); ++j){
260        (*data_all)(i,j) = kernel_data.data(i,row_index_[j]); 
261        (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
262      }
263
264      // last columns are equal to new data
265      for (size_t j=0;j<data.columns(); ++j){
266        (*data_all)(i,j+row_index_.size()) = data(i,j);
267      }
268    }
269    std::vector<size_t> column_index;
270    column_index.reserve(data.columns());
271    for (size_t i=0;i<data.columns(); ++i)
272      column_index.push_back(i+row_index_.size());
273
274    std::vector<size_t> row_index;
275    row_index.reserve(row_index_.size());
276    for (size_t i=0;i<row_index_.size(); ++i)
277      row_index.push_back(i);
278
279    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*data_all, 
280                                                         *weight_all, true);
281    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
282
283
284    return utility::SmartPtr<const KernelLookup>
285      (new KernelLookup(*kernel, row_index_, 
286                        utility::Index(column_index), true));
287  }
288
289
290
291  utility::SmartPtr<const KernelLookup>
292  KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
293  {
294    utility::Matrix* data_all = 
295      new utility::Matrix(data.rows(), rows()+data.columns());
296    utility::Matrix* weight_all = 
297      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
298
299    if (weighted()){
300      const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
301   
302      for (size_t i=0; i<data.rows(); ++i){
303        // first columns are equal to data in kernel_
304        for (size_t j=0; j<row_index_.size(); ++j){
305          (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
306          (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
307        }
308      }
309    }
310    else {
311
312      for (size_t i=0; i<data.rows(); ++i){
313        // first columns are equal to data in kernel_
314        for (size_t j=0; j<row_index_.size(); ++j)
315          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
316      }
317    }
318
319    // last columns are equal to new data
320    for (size_t i=0; i<data.rows(); ++i){
321      for (size_t j=0;j<data.columns(); ++j){
322        (*data_all)(i,j+row_index_.size()) = data.data(i,j);
323        (*weight_all)(i,j+row_index_.size()) = data.weight(i,j);
324      }
325    }
326   
327    std::vector<size_t> column_index;
328    column_index.reserve(data.columns());
329    for (size_t i=0;i<data.columns(); ++i)
330      column_index.push_back(i+row_index_.size());
331    const Kernel* kernel = 
332      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
333    return utility::SmartPtr<const KernelLookup>
334      (new KernelLookup(*kernel, row_index_, 
335                        utility::Index(column_index), true));
336  }
337
338
339  /*
340  const KernelLookup*
341  KernelLookup::training_data(const utility::Index& train) const
342  {
343    return new KernelLookup(*this,train,train);
344  }
345  */
346
347
348  bool KernelLookup::validate(const utility::Index& index) const
349  {
350    for (size_t i=0; i<index.size(); ++i)
351      if (index[i]>=kernel_->size())
352        return false;
353    return true;
354  }
355
356
357  bool KernelLookup::weighted(void) const
358  {
359    return kernel_->weighted();
360  }
361
362
363  double KernelLookup::operator()(size_t row, size_t column) const
364  {
365    return (*kernel_)(row_index_[row],column_index_[column]);
366  }
367
368}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.