source: trunk/yat/classifier/KernelLookup.cc @ 1127

Last change on this file since 1127 was 1127, checked in by Peter, 14 years ago

undoing [1126] which did not compile. forgot to run make check. sorry for any...

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 10.0 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005, 2006, 2007 Jari Häkkinen, Peter Johansson
5
6  This file is part of the yat library, http://trac.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "KernelLookup.h"
25#include "DataLookup2D.h"
26#include "MatrixLookup.h"
27#include "MatrixLookupWeighted.h"
28#include "yat/utility/Matrix.h"
29
30#include <cassert>
31#ifndef NDEBUG
32#include <algorithm>
33#endif
34
35namespace theplu {
36namespace yat {
37namespace classifier {
38
39  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
40    : DataLookup2D(own), kernel_(&kernel)
41  {
42    column_index_.reserve(kernel.size());
43    for(size_t i=0; i<kernel.size(); i++)
44      column_index_.push_back(i);
45    row_index_=column_index_;
46  }
47
48
49  KernelLookup::KernelLookup(const Kernel& kernel,
50                             const std::vector<size_t>& row, 
51                             const std::vector<size_t>& column,
52                             const bool owner)
53    : DataLookup2D(row,column,owner), kernel_(&kernel)
54  {
55    // Checking that each row index is less than kernel.rows()
56    assert(row.empty() || 
57           *(std::max_element(row.begin(),row.end()))<kernel_->size());
58    // Checking that each column index is less than kernel.column()
59    assert(column.empty() || 
60           *(std::max_element(column.begin(),column.end()))<kernel_->size());
61  }
62
63
64  KernelLookup::KernelLookup(const KernelLookup& other, 
65                             const std::vector<size_t>& row, 
66                             const std::vector<size_t>& column)
67    : DataLookup2D(other,row,column), kernel_(other.kernel_)
68  {
69    ref_count_=other.ref_count_;
70    if (ref_count_)
71      ++(*ref_count_);
72  }
73 
74
75  KernelLookup::KernelLookup(const KernelLookup& other)
76    : DataLookup2D(other), kernel_(other.kernel_)
77  {
78    // Checking that no index is out of range
79    assert(row_index_.empty() || 
80           *(max_element(row_index_.begin(), row_index_.end()))<
81           kernel_->size());
82    assert(column_index_.empty() || 
83           *(max_element(column_index_.begin(), column_index_.end()))<
84           kernel_->size());
85    ref_count_=other.ref_count_;
86    if (ref_count_)
87      ++(*ref_count_);
88  }
89 
90
91  KernelLookup::KernelLookup(const KernelLookup& other, 
92                             const std::vector<size_t>& index, 
93                             const bool row)
94    : DataLookup2D(other,index,row), kernel_(other.kernel_)
95  {
96    assert(kernel_->size());
97
98    // Checking that no index is out of range
99    assert(row_index_.empty() || 
100           *(max_element(row_index_.begin(), row_index_.end()))<
101           kernel_->size());
102    assert(column_index_.empty() || 
103           *(max_element(column_index_.begin(), column_index_.end()))<
104           kernel_->size());
105    ref_count_=other.ref_count_;
106    if (ref_count_)
107      ++(*ref_count_);
108  }
109 
110
111  KernelLookup::~KernelLookup(void)
112  {
113    if (ref_count_)
114      if (!--(*ref_count_))
115        delete kernel_;
116  }
117
118
119  KernelLookup::const_iterator KernelLookup::begin(void) const
120  {
121    return const_iterator(const_iterator::iterator_type(*this, 0, 0), 1);
122  }
123
124
125  KernelLookup::const_column_iterator
126  KernelLookup::begin_column(size_t i) const
127  {
128    return const_column_iterator(const_column_iterator::iterator_type(*this,
129                                                                      0,i), 
130                                 columns());
131  }
132
133
134  KernelLookup::const_row_iterator KernelLookup::begin_row(size_t i) const
135  {
136    return const_row_iterator(const_row_iterator::iterator_type(*this,i,0), 1);
137  }
138
139
140  const DataLookup2D* KernelLookup::data(void) const
141  {
142    return kernel_->data().training_data(column_index_);
143  }
144
145
146  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
147  {
148    return kernel_->element(vec, row_index_[i]);
149  }
150
151
152  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
153  {
154    return kernel_->element(vec, row_index_[i]);
155  }
156
157
158  KernelLookup::const_iterator KernelLookup::end(void) const
159  {
160    return const_iterator(const_iterator::iterator_type(*this, rows(), 0), 1);
161  }
162
163
164  KernelLookup::const_column_iterator KernelLookup::end_column(size_t i) const
165  {
166    return const_column_iterator(const_column_iterator::iterator_type(*this, 
167                                                                      rows(),i),
168                                 columns());
169  }
170
171
172  KernelLookup::const_row_iterator KernelLookup::end_row(size_t i) const
173  {
174    return const_row_iterator(const_row_iterator::iterator_type(*this,i+1,0),1);
175  }
176
177
178  const KernelLookup* 
179  KernelLookup::selected(const std::vector<size_t>& inputs) const
180  {
181    const Kernel* kernel;
182    if (kernel_->weighted()){
183      const MatrixLookupWeighted* ml = 
184        dynamic_cast<const MatrixLookupWeighted*>(data());
185      assert(ml);
186      const MatrixLookupWeighted* ms = 
187        new MatrixLookupWeighted(*ml,inputs,true);
188      kernel = kernel_->make_kernel(*ms, false);
189    }
190    else {
191      const MatrixLookup* m = 
192        dynamic_cast<const MatrixLookup*>(data());
193      assert(m);
194      // matrix with selected features
195      const MatrixLookup* ms = new MatrixLookup(*m,inputs,true);
196      kernel = kernel_->make_kernel(*ms,true);
197    }
198    return new KernelLookup(*kernel, true);
199  }
200
201
202  const KernelLookup* KernelLookup::test_kernel(const MatrixLookup& data) const
203  {
204
205    assert(data.rows()==kernel_->data().rows());
206    if (!weighted()){
207      utility::Matrix* data_all = 
208        new utility::Matrix(data.rows(), row_index_.size()+data.columns());
209
210      for (size_t i=0; i<data_all->rows(); ++i) {
211
212        // first some columns from data in kernel_
213        for (size_t j=0; j<row_index_.size(); ++j){
214          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
215        }
216       
217        // last columns are equal to new data
218        for (size_t j=0;j<data.columns(); ++j){
219          (*data_all)(i,j+row_index_.size()) = data(i,j);
220        }
221      }
222      std::vector<size_t> column_index;
223      column_index.reserve(data.columns());
224      for (size_t i=0;i<data.columns(); ++i)
225        column_index.push_back(i+row_index_.size());
226
227      std::vector<size_t> row_index;
228      row_index.reserve(row_index_.size());
229      for (size_t i=0;i<row_index_.size(); ++i)
230        row_index.push_back(i);
231
232      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
233
234      const Kernel* kernel = 
235        kernel_->make_kernel(*tmp, true);
236
237      return new KernelLookup(*kernel, row_index, column_index, true);
238    }
239
240    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
241    // should hold a MatrixLookupweighted.
242    utility::Matrix* data_all = 
243      new utility::Matrix(data.rows(), rows()+data.columns());
244    utility::Matrix* weight_all = 
245      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
246    const MatrixLookupWeighted& kernel_data = 
247      dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
248
249    for (size_t i=0; i<data.rows(); ++i){
250
251      // first some columns from data in kernel_
252      for (size_t j=0; j<row_index_.size(); ++j){
253        (*data_all)(i,j) = kernel_data.data(i,row_index_[j]); 
254        (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
255      }
256
257      // last columns are equal to new data
258      for (size_t j=0;j<data.columns(); ++j){
259        (*data_all)(i,j+row_index_.size()) = data(i,j);
260      }
261    }
262    std::vector<size_t> column_index;
263    column_index.reserve(data.columns());
264    for (size_t i=0;i<data.columns(); ++i)
265      column_index.push_back(i+row_index_.size());
266
267    std::vector<size_t> row_index;
268    row_index.reserve(row_index_.size());
269    for (size_t i=0;i<row_index_.size(); ++i)
270      row_index.push_back(i);
271
272    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*data_all, 
273                                                         *weight_all, true);
274    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
275    return new KernelLookup(*kernel, row_index_, column_index, true);
276  }
277
278
279
280  const KernelLookup* 
281  KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
282  {
283    utility::Matrix* data_all = 
284      new utility::Matrix(data.rows(), rows()+data.columns());
285    utility::Matrix* weight_all = 
286      new utility::Matrix(data.rows(), rows()+data.columns(), 1.0);
287
288    if (weighted()){
289      const MatrixLookupWeighted& kernel_data = 
290        dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
291   
292      for (size_t i=0; i<data.rows(); ++i){
293        // first columns are equal to data in kernel_
294        for (size_t j=0; j<row_index_.size(); ++j){
295          (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
296          (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
297        }
298      }
299    }
300    else {
301
302        dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
303   
304      for (size_t i=0; i<data.rows(); ++i){
305        // first columns are equal to data in kernel_
306        for (size_t j=0; j<row_index_.size(); ++j)
307          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
308      }
309    }
310
311    // last columns are equal to new data
312    for (size_t i=0; i<data.rows(); ++i){
313      for (size_t j=0;j<data.columns(); ++j){
314        (*data_all)(i,j+row_index_.size()) = data.data(i,j);
315        (*weight_all)(i,j+row_index_.size()) = data.weight(i,j);
316      }
317    }
318   
319    std::vector<size_t> column_index;
320    column_index.reserve(data.columns());
321    for (size_t i=0;i<data.columns(); ++i)
322      column_index.push_back(i+row_index_.size());
323    const Kernel* kernel = 
324      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
325    return new KernelLookup(*kernel, row_index_, column_index, true);
326  }
327
328
329  const KernelLookup*
330  KernelLookup::training_data(const std::vector<size_t>& train) const
331  {
332    return new KernelLookup(*this,train,train);
333  }
334
335
336  const KernelLookup*
337  KernelLookup::validation_data(const std::vector<size_t>& train,
338                                const std::vector<size_t>& validation) const
339  {
340    return new KernelLookup(*this,train,validation);
341  }
342
343
344  bool KernelLookup::weighted(void) const
345  {
346    return kernel_->weighted();
347  }
348
349
350  double KernelLookup::operator()(size_t row, size_t column) const
351  {
352    return (*kernel_)(row_index_[row],column_index_[column]);
353  }
354
355}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.