source: branches/0.4-stable/yat/classifier/KernelLookup.cc

Last change on this file was 1392, checked in by Peter, 13 years ago

trac has moved

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 9.8 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2008 Peter Johansson
6
7  This file is part of the yat library, http://dev.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "KernelLookup.h"
26#include "MatrixLookup.h"
27#include "MatrixLookupWeighted.h"
28#include "yat/utility/Matrix.h"
29
30#include <cassert>
31
32namespace theplu {
33namespace yat {
34namespace classifier {
35
36  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
37    : kernel_(utility::SmartPtr<const Kernel>(&kernel, own))
38  {
39    column_index_ = utility::Index(kernel.size());
40    row_index_=column_index_;
41  }
42
43
44  KernelLookup::KernelLookup(const Kernel& kernel,
45                             const utility::Index& row, 
46                             const utility::Index& column,
47                             const bool owner)
48    : column_index_(column), 
49      kernel_(utility::SmartPtr<const Kernel>(&kernel, owner)),
50      row_index_(row)
51  {
52    // Checking that each row index is less than kernel.rows()
53    assert(validate(row_index_));
54    // Checking that each column index is less than kernel.column()
55    assert(validate(column_index_));
56  }
57
58
59  KernelLookup::KernelLookup(const KernelLookup& other, 
60                             const utility::Index& row, 
61                             const utility::Index& column)
62    : kernel_(other.kernel_)
63  {
64    row_index_ = utility::Index(other.row_index_, row);
65    column_index_ = utility::Index(other.column_index_, column);
66  }
67 
68
69  KernelLookup::KernelLookup(const KernelLookup& other)
70    : column_index_(other.column_index_), kernel_(other.kernel_), 
71      row_index_(other.row_index_)
72  {
73    // Checking that each row index is less than kernel.rows()
74    assert(validate(row_index_));
75    // Checking that each column index is less than kernel.column()
76    assert(validate(column_index_));
77  }
78 
79
80  KernelLookup::KernelLookup(const KernelLookup& other, 
81                             const utility::Index& index, 
82                             const bool row)
83    : kernel_(other.kernel_)
84  {
85    if (row){
86      row_index_ = utility::Index(other.row_index_, index);
87      column_index_= other.column_index_;
88    }
89    else{
90      column_index_ = utility::Index(other.column_index_, index);
91      row_index_= other.row_index_;
92    }
93    assert(kernel_->size());
94
95    // Checking that each row index is less than kernel.rows()
96    assert(validate(row_index_));
97    // Checking that each column index is less than kernel.column()
98    assert(validate(column_index_));
99  }
100 
101
102  KernelLookup::~KernelLookup(void)
103  {
104  }
105
106
107  KernelLookup::const_iterator KernelLookup::begin(void) const
108  {
109    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
110  }
111
112
113  KernelLookup::const_column_iterator
114  KernelLookup::begin_column(size_t i) const
115  {
116    return const_column_iterator(const_column_iterator::iterator_type(*this,
117                                                                      0,i), 
118                                 columns());
119  }
120
121
122  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
123  {
124    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
125  }
126
127
128  size_t KernelLookup::columns(void) const
129  {
130    return column_index_.size();
131  }
132
133
134  MatrixLookup KernelLookup::data(void) const
135  {
136    assert(!weighted());
137    return MatrixLookup(kernel_->data(), column_index_, false);
138  }
139
140
141  MatrixLookupWeighted KernelLookup::data_weighted(void) const
142  {
143    assert(weighted());
144    return MatrixLookupWeighted(kernel_->data_weighted(),column_index_,false);
145  }
146
147
148  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
149  {
150    return kernel_->element(vec, row_index_[i]);
151  }
152
153
154  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
155  {
156    return kernel_->element(vec, row_index_[i]);
157  }
158
159
160  KernelLookup::const_iterator KernelLookup::end(void) const
161  {
162    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
163  }
164
165
166  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
167  {
168    return const_column_iterator(const_column_iterator::iterator_type(*this, 
169                                                                      rows(),i),
170                                 columns());
171  }
172
173
174  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
175  {
176    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
177  }
178
179
180  size_t KernelLookup::rows(void) const
181  {
182    return row_index_.size();
183  }
184
185
186  KernelLookup KernelLookup::selected(const utility::Index& inputs) const
187  {
188    const Kernel* kernel;
189    if (kernel_->weighted()){
190      const MatrixLookupWeighted* ms = 
191        new MatrixLookupWeighted(data_weighted(),inputs,true);
192      kernel = kernel_->make_kernel(*ms, true);
193    }
194    else {
195      // matrix with selected features
196      const MatrixLookup* ms = new MatrixLookup(data(),inputs,true);
197      kernel = kernel_->make_kernel(*ms,true);
198    }
199    return KernelLookup(*kernel, true);
200  }
201
202
203  KernelLookup KernelLookup::test_kernel(const MatrixLookup& data) const
204  {
205    if (!weighted()){
206      assert(data.rows()==kernel_->data().rows());
207      utility::Matrix* data_all = 
208        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
209
210      for (size_t i=0; i<data_all->rows(); ++i) {
211
212        // first some columns from data in kernel_
213        for (size_t j=0; j<row_index_.size(); ++j){
214          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
215        }
216       
217        // last columns are equal to new data
218        for (size_t j=0;j<data.columns(); ++j){
219          (*data_all)(i,j+row_index_.size()) = data(i,j);
220        }
221      }
222      std::vector<size_t> column_index;
223      column_index.reserve(data.columns());
224      for (size_t i=0;i<data.columns(); ++i)
225        column_index.push_back(i+row_index_.size());
226
227      std::vector<size_t> row_index;
228      row_index.reserve(row_index_.size());
229      for (size_t i=0;i<row_index_.size(); ++i)
230        row_index.push_back(i);
231
232      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
233
234      const Kernel* kernel = 
235        kernel_->make_kernel(*tmp, true);
236
237      return KernelLookup(*kernel, utility::Index(row_index), 
238                          utility::Index(column_index), true);
239    }
240
241    assert(data.rows()==kernel_->data_weighted().rows());
242    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
243    // should hold a MatrixLookupweighted.
244    utility::Matrix* data_all = 
245      new utility::Matrix(data.rows(), rows()+data.columns());
246    utility::Matrix* weight_all = 
247      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
248    const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
249
250    for (size_t i=0; i<data.rows(); ++i){
251
252      // first some columns from data in kernel_
253      for (size_t j=0; j<row_index_.size(); ++j){
254        (*data_all)(i,j) = kernel_data.data(i,row_index_[j]); 
255        (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
256      }
257
258      // last columns are equal to new data
259      for (size_t j=0;j<data.columns(); ++j){
260        (*data_all)(i,j+row_index_.size()) = data(i,j);
261      }
262    }
263    std::vector<size_t> column_index;
264    column_index.reserve(data.columns());
265    for (size_t i=0;i<data.columns(); ++i)
266      column_index.push_back(i+row_index_.size());
267
268    std::vector<size_t> row_index;
269    row_index.reserve(row_index_.size());
270    for (size_t i=0;i<row_index_.size(); ++i)
271      row_index.push_back(i);
272
273    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*data_all, 
274                                                         *weight_all, true);
275    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
276
277
278    return KernelLookup(*kernel, row_index_, 
279                        utility::Index(column_index), true);
280  }
281
282
283
284  KernelLookup KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
285  {
286    utility::Matrix* data_all = 
287      new utility::Matrix(data.rows(), rows()+data.columns());
288    utility::Matrix* weight_all = 
289      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
290
291    if (weighted()){
292      const MatrixLookupWeighted& kernel_data = kernel_->data_weighted();
293   
294      for (size_t i=0; i<data.rows(); ++i){
295        // first columns are equal to data in kernel_
296        for (size_t j=0; j<row_index_.size(); ++j){
297          (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
298          (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
299        }
300      }
301    }
302    else {
303
304      for (size_t i=0; i<data.rows(); ++i){
305        // first columns are equal to data in kernel_
306        for (size_t j=0; j<row_index_.size(); ++j)
307          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
308      }
309    }
310
311    // last columns are equal to new data
312    for (size_t i=0; i<data.rows(); ++i){
313      for (size_t j=0;j<data.columns(); ++j){
314        (*data_all)(i,j+row_index_.size()) = data.data(i,j);
315        (*weight_all)(i,j+row_index_.size()) = data.weight(i,j);
316      }
317    }
318   
319    std::vector<size_t> column_index;
320    column_index.reserve(data.columns());
321    for (size_t i=0;i<data.columns(); ++i)
322      column_index.push_back(i+row_index_.size());
323    const Kernel* kernel = 
324      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
325    return KernelLookup(*kernel, row_index_, 
326                        utility::Index(column_index), true);
327  }
328
329
330  /*
331  const KernelLookup*
332  KernelLookup::training_data(const utility::Index& train) const
333  {
334    return new KernelLookup(*this,train,train);
335  }
336  */
337
338
339  bool KernelLookup::validate(const utility::Index& index) const
340  {
341    for (size_t i=0; i<index.size(); ++i)
342      if (index[i]>=kernel_->size())
343        return false;
344    return true;
345  }
346
347
348  bool KernelLookup::weighted(void) const
349  {
350    return kernel_->weighted();
351  }
352
353
354  double KernelLookup::operator()(size_t row, size_t column) const
355  {
356    return (*kernel_)(row_index_[row],column_index_[column]);
357  }
358
359}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.