source: trunk/yat/classifier/KernelLookup.cc @ 720

Last change on this file since 720 was 720, checked in by Jari Häkkinen, 16 years ago

Fixes #170. Almost all inlines removed, some classes have no cc file.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 8.5 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "KernelLookup.h"
25#include "DataLookup2D.h"
26#include "MatrixLookup.h"
27#include "MatrixLookupWeighted.h"
28
29#include <cassert>
30#ifndef NDEBUG
31#include <algorithm>
32#endif
33
34namespace theplu {
35namespace yat {
36namespace classifier {
37
38  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
39    : DataLookup2D(own), kernel_(&kernel)
40  {
41    column_index_.reserve(kernel.size());
42    for(size_t i=0; i<kernel.size(); i++)
43      column_index_.push_back(i);
44    row_index_=column_index_;
45  }
46
47
48  KernelLookup::KernelLookup(const Kernel& kernel,
49                             const std::vector<size_t>& row, 
50                             const std::vector<size_t>& column,
51                             const bool owner)
52    : DataLookup2D(row,column,owner), kernel_(&kernel)
53  {
54    // Checking that each row index is less than kernel.rows()
55    assert(row.empty() || 
56           *(std::max_element(row.begin(),row.end()))<kernel_->size());
57    // Checking that each column index is less than kernel.column()
58    assert(column.empty() || 
59           *(std::max_element(column.begin(),column.end()))<kernel_->size());
60  }
61
62
63  KernelLookup::KernelLookup(const KernelLookup& other, 
64                             const std::vector<size_t>& row, 
65                             const std::vector<size_t>& column)
66    : DataLookup2D(other,row,column), kernel_(other.kernel_)
67  {
68    ref_count_=other.ref_count_;
69    if (ref_count_)
70      ++(*ref_count_);
71  }
72 
73
74  KernelLookup::KernelLookup(const KernelLookup& other)
75    : DataLookup2D(other), kernel_(other.kernel_)
76  {
77    // Checking that no index is out of range
78    assert(row_index_.empty() || 
79           *(max_element(row_index_.begin(), row_index_.end()))<
80           kernel_->size());
81    assert(column_index_.empty() || 
82           *(max_element(column_index_.begin(), column_index_.end()))<
83           kernel_->size());
84    ref_count_=other.ref_count_;
85    if (ref_count_)
86      ++(*ref_count_);
87  }
88 
89
90  KernelLookup::KernelLookup(const KernelLookup& other, 
91                             const std::vector<size_t>& index, 
92                             const bool row)
93    : DataLookup2D(other,index,row), kernel_(other.kernel_)
94  {
95    assert(kernel_->size());
96
97    // Checking that no index is out of range
98    assert(row_index_.empty() || 
99           *(max_element(row_index_.begin(), row_index_.end()))<
100           kernel_->size());
101    assert(column_index_.empty() || 
102           *(max_element(column_index_.begin(), column_index_.end()))<
103           kernel_->size());
104    ref_count_=other.ref_count_;
105    if (ref_count_)
106      ++(*ref_count_);
107  }
108 
109
110  KernelLookup::~KernelLookup(void)
111  {
112    if (ref_count_)
113      if (!--(*ref_count_))
114        delete kernel_;
115  }
116
117
118  const DataLookup2D* KernelLookup::data(void) const
119  {
120    return kernel_->data().training_data(column_index_);
121  }
122
123
124  double KernelLookup::element(const DataLookup1D& vec, size_t i) const
125  {
126    return kernel_->element(vec, row_index_[i]);
127  }
128
129
130  double KernelLookup::element(const DataLookupWeighted1D& vec, size_t i) const
131  {
132    return kernel_->element(vec, row_index_[i]);
133  }
134
135
136  const Kernel* KernelLookup::kernel(void) const
137  {
138    return kernel_;
139  }
140
141
142  const KernelLookup* 
143  KernelLookup::selected(const std::vector<size_t>& inputs) const
144  {
145    const Kernel* kernel = kernel_->selected(inputs);
146    return new KernelLookup(*kernel, row_index_, column_index_, true);
147  }
148
149
150  const KernelLookup* KernelLookup::test_kernel(const MatrixLookup& data) const
151  {
152
153    assert(data.rows()==kernel_->data().rows());
154    if (!weighted()){
155      utility::matrix* data_all = 
156        new utility::matrix(data.rows(), row_index_.size()+data.columns());
157
158      for (size_t i=0; i<data_all->rows(); ++i) {
159
160        // first some columns from data in kernel_
161        for (size_t j=0; j<row_index_.size(); ++j){
162          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]); 
163        }
164       
165        // last columns are equal to new data
166        for (size_t j=0;j<data.columns(); ++j){
167          (*data_all)(i,j+row_index_.size()) = data(i,j);
168        }
169      }
170      std::vector<size_t> column_index;
171      column_index.reserve(data.columns());
172      for (size_t i=0;i<data.columns(); ++i)
173        column_index.push_back(i+row_index_.size());
174
175      std::vector<size_t> row_index;
176      row_index.reserve(row_index_.size());
177      for (size_t i=0;i<row_index_.size(); ++i)
178        row_index.push_back(i);
179
180      const MatrixLookup* tmp = new MatrixLookup(*data_all, true);
181
182      const Kernel* kernel = 
183        kernel_->make_kernel(*tmp, true);
184
185      return new KernelLookup(*kernel, row_index, column_index, true);
186    }
187
188    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
189    // should hold a MatrixLookupweighted.
190    utility::matrix* data_all = 
191      new utility::matrix(data.rows(), rows()+data.columns());
192    utility::matrix* weight_all = 
193      new utility::matrix(data.rows(), rows()+data.columns(), 1.0);
194    const MatrixLookupWeighted& kernel_data = 
195      dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
196
197    for (size_t i=0; i<data.rows(); ++i){
198
199      // first some columns from data in kernel_
200      for (size_t j=0; j<row_index_.size(); ++j){
201        (*data_all)(i,j) = kernel_data.data(i,row_index_[j]); 
202        (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
203      }
204
205      // last columns are equal to new data
206      for (size_t j=0;j<data.columns(); ++j){
207        (*data_all)(i,j+row_index_.size()) = data(i,j);
208      }
209    }
210    std::vector<size_t> column_index;
211    column_index.reserve(data.columns());
212    for (size_t i=0;i<data.columns(); ++i)
213      column_index.push_back(i+row_index_.size());
214
215    std::vector<size_t> row_index;
216    row_index.reserve(row_index_.size());
217    for (size_t i=0;i<row_index_.size(); ++i)
218      row_index.push_back(i);
219
220    MatrixLookupWeighted* tmp = new MatrixLookupWeighted(*data_all, 
221                                                         *weight_all, true);
222    const Kernel* kernel = kernel_->make_kernel(*tmp, true);
223    return new KernelLookup(*kernel, row_index_, column_index, true);
224  }
225
226
227
228  const KernelLookup* 
229  KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
230  {
231    utility::matrix* data_all = 
232      new utility::matrix(data.rows(), rows()+data.columns());
233    utility::matrix* weight_all = 
234      new utility::matrix(data.rows(), rows()+data.columns(), 1.0);
235
236    if (weighted()){
237      const MatrixLookupWeighted& kernel_data = 
238        dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
239   
240      for (size_t i=0; i<data.rows(); ++i){
241        // first columns are equal to data in kernel_
242        for (size_t j=0; j<row_index_.size(); ++j){
243          (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
244          (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
245        }
246      }
247    }
248    else {
249
250        dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
251   
252      for (size_t i=0; i<data.rows(); ++i){
253        // first columns are equal to data in kernel_
254        for (size_t j=0; j<row_index_.size(); ++j)
255          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
256      }
257    }
258
259    // last columns are equal to new data
260    for (size_t i=0; i<data.rows(); ++i){
261      for (size_t j=0;j<data.columns(); ++j){
262        (*data_all)(i,j+row_index_.size()) = data.data(i,j);
263        (*weight_all)(i,j+row_index_.size()) = data.weight(i,j);
264      }
265    }
266   
267    std::vector<size_t> column_index;
268    column_index.reserve(data.columns());
269    for (size_t i=0;i<data.columns(); ++i)
270      column_index.push_back(i+row_index_.size());
271    const Kernel* kernel = 
272      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
273    return new KernelLookup(*kernel, row_index_, column_index, true);
274  }
275
276
277  const KernelLookup*
278  KernelLookup::training_data(const std::vector<size_t>& train) const
279  {
280    return new KernelLookup(*this,train,train);
281  }
282
283
284  const KernelLookup*
285  KernelLookup::validation_data(const std::vector<size_t>& train,
286                                const std::vector<size_t>& validation) const
287  {
288    return new KernelLookup(*this,train,validation);
289  }
290
291
292  bool KernelLookup::weighted(void) const
293  {
294    return kernel_->weighted();
295  }
296
297
298  double KernelLookup::operator()(size_t row, size_t column) const
299  {
300    return (*kernel_)(row_index_[row],column_index_[column]);
301  }
302
303}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.