source: trunk/c++_tools/classifier/KernelLookup.cc @ 658

Last change on this file since 658 was 658, checked in by Peter, 16 years ago

added function in KernelLookup? to create a KernelLookup? from inner data and outer (passed) data.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 6.7 KB
Line 
1// $Id$
2
3#include <c++_tools/classifier/KernelLookup.h>
4#include <c++_tools/classifier/DataLookup2D.h>
5#include <c++_tools/classifier/MatrixLookup.h>
6#include <c++_tools/classifier/MatrixLookupWeighted.h>
7
8
9#include <cassert>
10#ifndef NDEBUG
11#include <algorithm>
12#endif
13
14namespace theplu {
15namespace classifier { 
16
17  KernelLookup::KernelLookup(const Kernel& kernel, const bool own)
18    : DataLookup2D(), kernel_(&kernel)
19  {
20    if (own)
21      ref_count_ = new u_int(1);
22
23    column_index_.reserve(kernel.size());
24    for(size_t i=0; i<kernel.size(); i++)
25      column_index_.push_back(i);
26    row_index_=column_index_;
27  }
28 
29  KernelLookup::KernelLookup(const Kernel& kernel,
30                             const std::vector<size_t>& row, 
31                             const std::vector<size_t>& column,
32                             const bool owner)
33    : DataLookup2D(row,column), kernel_(&kernel)
34  {
35    if (owner)
36      ref_count_ = new u_int(1);
37
38    // Checking that each row index is less than kernel.rows()
39    assert(row.empty() || 
40           *(std::max_element(row.begin(),row.end()))<kernel_->size());
41    // Checking that each column index is less than kernel.column()
42    assert(column.empty() || 
43           *(std::max_element(column.begin(),column.end()))<kernel_->size());
44
45  }
46
47
48  KernelLookup::KernelLookup(const KernelLookup& other, 
49                             const std::vector<size_t>& row, 
50                             const std::vector<size_t>& column)
51    : DataLookup2D(other,row,column), kernel_(other.kernel_)
52  {
53    ref_count_=other.ref_count_;
54    if (ref_count_)
55      ++(*ref_count_);
56  }
57 
58
59  KernelLookup::KernelLookup(const KernelLookup& other)
60    : DataLookup2D(other), kernel_(other.kernel_)
61  {
62    // Checking that no index is out of range
63    assert(row_index_.empty() || 
64           *(max_element(row_index_.begin(), row_index_.end()))<
65           kernel_->size());
66    assert(column_index_.empty() || 
67           *(max_element(column_index_.begin(), column_index_.end()))<
68           kernel_->size());
69    ref_count_=other.ref_count_;
70    if (ref_count_)
71      ++(*ref_count_);
72
73  }
74 
75
76  KernelLookup::KernelLookup(const KernelLookup& other, 
77                             const std::vector<size_t>& index, 
78                             const bool row)
79    : DataLookup2D(other,index,row), kernel_(other.kernel_)
80  {
81    // Checking that no index is out of range
82    assert(row_index_.empty() || 
83           *(max_element(row_index_.begin(), row_index_.end()))<
84           kernel_->size());
85    assert(column_index_.empty() || 
86           *(max_element(column_index_.begin(), column_index_.end()))<
87           kernel_->size());
88    ref_count_=other.ref_count_;
89    if (ref_count_)
90      ++(*ref_count_);
91  }
92 
93
94  KernelLookup::~KernelLookup(void)
95  {
96    if (ref_count_)
97      if (!--(*ref_count_))
98        delete kernel_;
99  }
100
101  const KernelLookup* 
102  KernelLookup::training_data(const std::vector<size_t>& train) const
103  { 
104    return new KernelLookup(*this,train,train); 
105  } 
106
107
108  const KernelLookup* 
109  KernelLookup::validation_data(const std::vector<size_t>& train, 
110                                const std::vector<size_t>& validation) const
111  { 
112    return new KernelLookup(*this,train,validation); 
113  } 
114
115
116  const KernelLookup* 
117  KernelLookup::selected(const std::vector<size_t>& inputs) const
118  {
119    const Kernel* kernel = kernel_->selected(inputs);
120    return new KernelLookup(*kernel, row_index_, column_index_, true);
121  }
122
123
124  const KernelLookup* KernelLookup::test_kernel(const MatrixLookup& data) const
125  {
126    assert(data.rows()==kernel_->data().rows());
127    if (!weighted()){
128      utility::matrix* data_all = 
129        new utility::matrix(data.rows(), rows()+data.columns());
130      for (size_t i=0; i<data.rows(); ++i){
131        // first columns are equal to data in kernel_
132        for (size_t j=0; j<row_index_.size(); ++j)
133          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
134        // last columns are equal to new data
135        for (size_t j=0;j<data.columns(); ++j)
136          (*data_all)(i,j+row_index_.size()) = data(i,j);
137      }
138      std::vector<size_t> column_index;
139      column_index.reserve(data.columns());
140      for (size_t i=0;i<data.columns(); ++i)
141        column_index.push_back(i+row_index_.size());
142      const Kernel* kernel = 
143        kernel_->make_kernel(MatrixLookup(*data_all, true), true);
144      return new KernelLookup(*kernel, row_index_, column_index, true);
145    }
146
147    // kernel_ holds MatrixLookupWeighted, hence new Kernel also
148    // should hold a MatrixLookupweighted.
149    utility::matrix* data_all = 
150      new utility::matrix(data.rows(), rows()+data.columns());
151    utility::matrix* weight_all = 
152      new utility::matrix(data.rows(), rows()+data.columns(), 1.0);
153    const MatrixLookupWeighted& kernel_data = 
154      dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
155
156    for (size_t i=0; i<data.rows(); ++i){
157      // first columns are equal to data in kernel_
158      for (size_t j=0; j<row_index_.size(); ++j){
159        (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
160        (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
161      }
162      // last columns are equal to new data
163      for (size_t j=0;j<data.columns(); ++j){
164        (*data_all)(i,j+row_index_.size()) = data(i,j);
165      }
166    }
167    std::vector<size_t> column_index;
168    column_index.reserve(data.columns());
169    for (size_t i=0;i<data.columns(); ++i)
170      column_index.push_back(i+row_index_.size());
171    const Kernel* kernel = 
172      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
173    return new KernelLookup(*kernel, row_index_, column_index, true);
174  }
175
176
177
178  const KernelLookup* 
179  KernelLookup::test_kernel(const MatrixLookupWeighted& data) const
180  {
181    utility::matrix* data_all = 
182      new utility::matrix(data.rows(), rows()+data.columns());
183    utility::matrix* weight_all = 
184      new utility::matrix(data.rows(), rows()+data.columns(), 1.0);
185
186    if (weighted()){
187      const MatrixLookupWeighted& kernel_data = 
188        dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
189   
190      for (size_t i=0; i<data.rows(); ++i){
191        // first columns are equal to data in kernel_
192        for (size_t j=0; j<row_index_.size(); ++j){
193          (*data_all)(i,j) = kernel_data.data(i,row_index_[j]);
194          (*weight_all)(i,j) = kernel_data.weight(i,row_index_[j]);
195        }
196      }
197    }
198    else {
199
200        dynamic_cast<const MatrixLookupWeighted&>(kernel_->data());
201   
202      for (size_t i=0; i<data.rows(); ++i){
203        // first columns are equal to data in kernel_
204        for (size_t j=0; j<row_index_.size(); ++j)
205          (*data_all)(i,j) = kernel_->data()(i,row_index_[j]);
206      }
207    }
208
209    // last columns are equal to new data
210    for (size_t i=0; i<data.rows(); ++i){
211      for (size_t j=0;j<data.columns(); ++j){
212        (*data_all)(i,j+row_index_.size()) = data.data(i,j);
213        (*weight_all)(i,j+row_index_.size()) = data.weight(i,j);
214      }
215    }
216   
217    std::vector<size_t> column_index;
218    column_index.reserve(data.columns());
219    for (size_t i=0;i<data.columns(); ++i)
220      column_index.push_back(i+row_index_.size());
221    const Kernel* kernel = 
222      kernel_->make_kernel(MatrixLookupWeighted(*data_all, *weight_all, true));
223    return new KernelLookup(*kernel, row_index_, column_index, true);
224  }
225
226}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.