source: trunk/yat/classifier/KernelLookup.cc @ 1132

Last change on this file since 1132 was 1132, checked in by Peter, 14 years ago

KernelLookup? is not inherited from DataLookup2D - fixes #234

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 11.1 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007 Jari Häkkinen, Peter Johansson
5
6  This file is part of the yat library, http://trac.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "KernelLookup.h"
25#include "DataLookup2D.h"
26#include "MatrixLookup.h"
27#include "MatrixLookupWeighted.h"
28#include "yat/utility/Matrix.h"
29
30#include <cassert>
31#ifndef NDEBUG
32#include <algorithm>
33#endif
34
35namespace theplu {
36namespace yat {
37namespace classifier {
38
39  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
40    : kernel_(&kernel)
41  {
42    if (own)
43      ref_count_ = new u_int(1);
44    else
45      ref_count_ = NULL;
46    column_index_.reserve(kernel.size());
47    for(size_t i=0; i<kernel.size(); i++)
48      column_index_.push_back(i);
49    row_index_=column_index_;
50  }
51
52
53  KernelLookup::KernelLookup(const Kernel& kernel,
54                             const std::vector<size_t>& row, 
55                             const std::vector<size_t>& column,
56                             const bool owner)
57    : column_index_(column), kernel_(&kernel), ref_count_(NULL),
58      row_index_(row)
59  {
60    // Checking that each row index is less than kernel.rows()
61    assert(row.empty() || 
62           *(std::max_element(row.begin(),row.end()))<kernel_->size());
63    // Checking that each column index is less than kernel.column()
64    assert(column.empty() || 
65           *(std::max_element(column.begin(),column.end()))<kernel_->size());
66  }
67
68
69  KernelLookup::KernelLookup(const KernelLookup& other, 
70                             const std::vector<size_t>& row, 
71                             const std::vector<size_t>& column)
72    : kernel_(other.kernel_)
73  {
74    assert(row_index_.empty());
75    row_index_.reserve(row.size());
76    for (size_t i=0; i<row.size(); i++) {
77      assert(row[i]<other.row_index_.size());
78      row_index_.push_back(other.row_index_[row[i]]);
79    }
80    assert(column_index_.empty());
81    column_index_.reserve(column.size());
82    for (size_t i=0; i<column.size(); i++) {
83      assert(column[i]<other.column_index_.size());
84      column_index_.push_back(other.column_index_[column[i]]);
85    }
86    ref_count_=other.ref_count_;
87    if (ref_count_)
88      ++(*ref_count_);
89  }
90 
91
92  KernelLookup::KernelLookup(const KernelLookup& other)
93    : column_index_(other.column_index_), kernel_(other.kernel_), 
94      row_index_(other.row_index_)
95  {
96    // Checking that no index is out of range
97    assert(row_index_.empty() || 
98           *(max_element(row_index_.begin(), row_index_.end()))<
99           kernel_->size());
100    assert(column_index_.empty() || 
101           *(max_element(column_index_.begin(), column_index_.end()))<
102           kernel_->size());
103    ref_count_=other.ref_count_;
104    if (ref_count_)
105      ++(*ref_count_);
106  }
107 
108
109  KernelLookup::KernelLookup(const KernelLookup& other, 
110                             const std::vector<size_t>& index, 
111                             const bool row)
112    : kernel_(other.kernel_)
113  {
114    if (row){
115      assert(row_index_.empty());
116      row_index_.reserve(index.size());
117      for (size_t i=0; i<index.size(); i++) {
118        assert(index[i]<other.row_index_.size());
119        row_index_.push_back(other.row_index_[index[i]]);
120      }
121      column_index_= other.column_index_;
122    }
123    else{
124      assert(column_index_.empty());
125      column_index_.reserve(index.size());
126      for (size_t i=0; i<index.size(); i++) {
127        column_index_.push_back(other.column_index_[index[i]]);
128      }
129      row_index_= other.row_index_;
130    }
131    assert(kernel_->size());
132
133    // Checking that no index is out of range
134    assert(row_index_.empty() || 
135           *(max_element(row_index_.begin(), row_index_.end()))<
136           kernel_->size());
137    assert(column_index_.empty() || 
138           *(max_element(column_index_.begin(), column_index_.end()))<
139           kernel_->size());
140    ref_count_=other.ref_count_;
141    if (ref_count_)
142      ++(*ref_count_);
143  }
144 
145
146  KernelLookup::~KernelLookup(void)
147  {
148    if (ref_count_)
149      if (!--(*ref_count_))
150        delete kernel_;
151  }
152
153
154  KernelLookup::const_iterator KernelLookup::begin(void) const
155  {
156    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
157  }
158
159
160  KernelLookup::const_column_iterator
161  KernelLookup::begin_column(size_t i) const
162  {
163    return const_column_iterator(const_column_iterator::iterator_type(*this,
164                                                                      0,i), 
165                                 columns());
166  }
167
168
169  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
170  {
171    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
172  }
173
174
175  size_t KernelLookup::columns(void) const
176  {
177    return column_index_.size();
178  }
179
180
181  const DataLookup2D* KernelLookup::data(void) const
182  {
183    return kernel_->data().training_data(column_index_);
184  }
185
186
187  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
188  {
189    return kernel_->element(vec, row_index_[i]);
190  }
191
192
193  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
194  {
195    return kernel_->element(vec, row_index_[i]);
196  }
197
198
199  KernelLookup::const_iterator KernelLookup::end(void) const
200  {
201    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
202  }
203
204
205  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
206  {
207    return const_column_iterator(const_column_iterator::iterator_type(*this, 
208                                                                      rows(),i),
209                                 columns());
210  }
211
212
213  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
214  {
215    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
216  }
217
218
219  size_t KernelLookup::rows(void) const
220  {
221    return row_index_.size();
222  }
223
224
225  const KernelLookup* 
226  KernelLookup::selected(const std::vector<size_t>& inputs) const
227  {
228    const Kernel* kernel;
229    if (kernel_->weighted()){
230      const MatrixLookupWeighted* ml = 
231        dynamic_cast<const MatrixLookupWeighted*>(data());
232      assert(ml);
233      const MatrixLookupWeighted* ms = 
234        new MatrixLookupWeighted(*ml,inputs,true);
235      kernel = kernel_->make_kernel(*ms, false);
236    }
237    else {
238      const MatrixLookup* m = 
239        dynamic_cast<const MatrixLookup*>(data());
240      assert(m);
241      // matrix with selected features
242      const MatrixLookup* ms = new MatrixLookup(*m,inputs,true);
243      kernel = kernel_->make_kernel(*ms,true);
244    }
245    return new KernelLookup(*kernel, true);
246  }
247
248
249  const KernelLookup* KernelLookup::test_kernel(const MatrixLookup& data) const
250  {
251
252    assert(data.rows()==kernel_->data().rows());
253    if (!weighted()){
254      utility::Matrix* data_all = 
255        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
256
257      for (size_t i=0; i<data_all->rows(); ++i) {
258
259        // first some columns from data in kernel_
260        for (size_t j=0; j<row_index_.size(); ++j){
261          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
262        }
263       
264        // last columns are equal to new data
265        for (size_t j=0;j<data.columns(); ++j){
266          (*data_all)(i,j+row_index_.size()) = data(i,j);
267        }
268      }
269      std::vector<size_t> column_index;
270      column_index.reserve(data.columns());
271      for (size_t i=0;i<data.columns(); ++i)
272        column_index.push_back(i+row_index_.size());
273
274      std::vector<size_t> row_index;
275      row_index.reserve(row_index_.size());
276      for (size_t i=0;i<row_index_.size(); ++i)
277        row_index.push_back(i);
278
279      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
280
281      const Kernel* kernel = 
282        kernel_->make_kernel(*tmp, true);
283
284      return new KernelLookup(*kernel, row_index, column_index, true);
285    }
286
287    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
288    // should hold a MatrixLookupweighted.
289    utility::Matrix* data_all = 
290      new utility::Matrix(data.rows(), rows()+data.columns());
291    utility::Matrix* weight_all = 
292      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
293    const MatrixLookupWeighted& kernel_data = 
294      dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
295
296    for (size_t i=0; i<data.rows(); ++i){
297
298      // first some columns from data in kernel_
299      for (size_t j=0; j<row_index_.size(); ++j){
300        (*data_all)(i,j) = kernel_data.data(i,row_index_[j]); 
301        (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
302      }
303
304      // last columns are equal to new data
305      for (size_t j=0;j<data.columns(); ++j){
306        (*data_all)(i,j+row_index_.size()) = data(i,j);
307      }
308    }
309    std::vector<size_t> column_index;
310    column_index.reserve(data.columns());
311    for (size_t i=0;i<data.columns(); ++i)
312      column_index.push_back(i+row_index_.size());
313
314    std::vector<size_t> row_index;
315    row_index.reserve(row_index_.size());
316    for (size_t i=0;i<row_index_.size(); ++i)
317      row_index.push_back(i);
318
319    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*data_all, 
320                                                         *weight_all, true);
321    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
322    return new KernelLookup(*kernel, row_index_, column_index, true);
323  }
324
325
326
327  const KernelLookup* 
328  KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
329  {
330    utility::Matrix* data_all = 
331      new utility::Matrix(data.rows(), rows()+data.columns());
332    utility::Matrix* weight_all = 
333      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
334
335    if (weighted()){
336      const MatrixLookupWeighted& kernel_data = 
337        dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
338   
339      for (size_t i=0; i<data.rows(); ++i){
340        // first columns are equal to data in kernel_
341        for (size_t j=0; j<row_index_.size(); ++j){
342          (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
343          (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
344        }
345      }
346    }
347    else {
348
349        dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
350   
351      for (size_t i=0; i<data.rows(); ++i){
352        // first columns are equal to data in kernel_
353        for (size_t j=0; j<row_index_.size(); ++j)
354          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
355      }
356    }
357
358    // last columns are equal to new data
359    for (size_t i=0; i<data.rows(); ++i){
360      for (size_t j=0;j<data.columns(); ++j){
361        (*data_all)(i,j+row_index_.size()) = data.data(i,j);
362        (*weight_all)(i,j+row_index_.size()) = data.weight(i,j);
363      }
364    }
365   
366    std::vector<size_t> column_index;
367    column_index.reserve(data.columns());
368    for (size_t i=0;i<data.columns(); ++i)
369      column_index.push_back(i+row_index_.size());
370    const Kernel* kernel = 
371      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
372    return new KernelLookup(*kernel, row_index_, column_index, true);
373  }
374
375
376  const KernelLookup*
377  KernelLookup::training_data(const std::vector<size_t>& train) const
378  {
379    return new KernelLookup(*this,train,train);
380  }
381
382
383  const KernelLookup*
384  KernelLookup::validation_data(const std::vector<size_t>& train,
385                                const std::vector<size_t>& validation) const
386  {
387    return new KernelLookup(*this,train,validation);
388  }
389
390
391  bool KernelLookup::weighted(void) const
392  {
393    return kernel_->weighted();
394  }
395
396
397  double KernelLookup::operator()(size_t row, size_t column) const
398  {
399    return (*kernel_)(row_index_[row],column_index_[column]);
400  }
401
402}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.