// $Id: NBC.cc 1392 2008-07-28 19:35:30Z peter $

/*
  Copyright (C) 2006, 2007 Jari Häkkinen, Peter Johansson, Markus Ringnér
  Copyright (C) 2008 Peter Johansson, Markus Ringnér

  This file is part of the yat library, http://dev.thep.lu.se/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 2 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
  02111-1307, USA.
*/

#include "NBC.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "Target.h"
#include "yat/statistics/Averager.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/Matrix.h"

#include <cassert>
#include <cmath>
#include <limits>   // for std::numeric_limits
#include <stdexcept>
#include <vector>

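// Minimal usage sketch. The caller-side objects below (the raw data
// matrix, the label vector, and the test data) are assumptions for
// illustration only and are not defined in this file:
//
//   utility::Matrix x(...);                 // features x training samples
//   Target y(labels);                       // one class label per column of x
//   NBC nbc;
//   nbc.train(MatrixLookup(x), y);
//   utility::Matrix p;
//   nbc.predict(MatrixLookup(x_test), p);   // p(class, sample)
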
namespace theplu {
namespace yat {
namespace classifier {

  NBC::NBC()
    : SupervisedClassifier()
  {
  }


  NBC::~NBC()
  {
  }


  NBC* NBC::make_classifier() const
  {
    return new NBC();
  }


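  // Train the classifier: for every feature (row in data) and every
  // class in target, estimate the class-conditional mean (centroid)
  // and variance from the training samples of that class. A class
  // represented by fewer than two samples gets NaN for both, so it
  // contributes nothing for that feature.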
  void NBC::train(const MatrixLookup& data, const Target& target)
  {
    sigma2_.resize(data.rows(), target.nof_classes());
    centroids_.resize(data.rows(), target.nof_classes());

    for (size_t i=0; i<data.rows(); ++i) {
      std::vector<statistics::Averager> aver(target.nof_classes());
      for (size_t j=0; j<data.columns(); ++j)
        aver[target(j)].add(data(i,j));

      assert(centroids_.columns()==target.nof_classes());
      for (size_t j=0; j<target.nof_classes(); ++j) {
        assert(i<centroids_.rows());
        assert(j<centroids_.columns());
        assert(i<sigma2_.rows());
        assert(j<sigma2_.columns());
        if (aver[j].n()>1) {
          sigma2_(i,j) = aver[j].variance();
          centroids_(i,j) = aver[j].mean();
        }
        else {
          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
        }
      }
    }
  }


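  // Weighted counterpart of train(const MatrixLookup&, const Target&):
  // means and variances are computed with AveragerWeighted, so data
  // points with zero weight (e.g. missing values) do not contribute.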
  void NBC::train(const MatrixLookupWeighted& data, const Target& target)
  {
    sigma2_.resize(data.rows(), target.nof_classes());
    centroids_.resize(data.rows(), target.nof_classes());

    for (size_t i=0; i<data.rows(); ++i) {
      std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
      for (size_t j=0; j<data.columns(); ++j)
        aver[target(j)].add(data.data(i,j), data.weight(i,j));

      assert(centroids_.columns()==target.nof_classes());
      for (size_t j=0; j<target.nof_classes(); ++j) {
        assert(i<centroids_.rows());
        assert(j<centroids_.columns());
        assert(i<sigma2_.rows());
        assert(j<sigma2_.columns());
        if (aver[j].n()>1) {
          sigma2_(i,j) = aver[j].variance();
          centroids_(i,j) = aver[j].mean();
        }
        else {
          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
        }
      }
    }
  }


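  // Predict class membership for each column (sample) in ml. Under the
  // naive Bayes assumption of independent Gaussian features, the
  // negative log-likelihood of a sample x for class c is, up to an
  // additive constant,
  //
  //   -ln P(x|c) = sum_i [ log(sigma_ic) + (x_i - m_ic)^2 / (2*sigma_ic^2) ]
  //
  // where m_ic and sigma_ic^2 are the centroid and variance estimated
  // in train(). The resulting matrix of -lnP values is converted to
  // probabilities in standardize_lnP().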
  void NBC::predict(const MatrixLookup& ml,
                    utility::Matrix& prediction) const
  {
    assert(ml.rows()==sigma2_.rows());
    assert(ml.rows()==centroids_.rows());
    // each row in prediction corresponds to a class (label), each
    // column to a sample
    prediction.resize(centroids_.columns(), ml.columns(), 0);

    // first calculate -lnP = sum log(sigma_i) + sum (x_i-m_i)^2/(2*sigma_i^2)
    for (size_t label=0; label<centroids_.columns(); ++label) {
      double sum_log_sigma = sum_logsigma(label);
      for (size_t sample=0; sample<prediction.columns(); ++sample) {
        prediction(label,sample) = sum_log_sigma;
        for (size_t i=0; i<ml.rows(); ++i)
          prediction(label, sample) +=
            std::pow(ml(i, sample)-centroids_(i, label),2) /
            (2*sigma2_(i, label));
      }
    }
    standardize_lnP(prediction);
  }


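  // Weighted counterpart of predict(const MatrixLookup&, ...): the
  // squared deviations are averaged with the data weights and rescaled
  // by the number of features, so that with unit weights the result
  // coincides with the unweighted calculation.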
  void NBC::predict(const MatrixLookupWeighted& mlw,
                    utility::Matrix& prediction) const
  {
    assert(mlw.rows()==sigma2_.rows());
    assert(mlw.rows()==centroids_.rows());

    // each row in prediction corresponds to a class (label), each
    // column to a sample
    prediction.resize(centroids_.columns(), mlw.columns(), 0);

    // first calculate -lnP = sum log(sigma_i) +
    //   N * [sum w_i*(x_i-m_i)^2/(2*sigma_i^2)] / sum w_i
    for (size_t label=0; label<centroids_.columns(); ++label) {
      double sum_log_sigma = sum_logsigma(label);
      for (size_t sample=0; sample<prediction.columns(); ++sample) {
        statistics::AveragerWeighted aw;
        for (size_t i=0; i<mlw.rows(); ++i)
          aw.add(std::pow(mlw.data(i, sample)-centroids_(i, label),2) /
                 sigma2_(i, label), mlw.weight(i, sample));
        prediction(label,sample) = sum_log_sigma + mlw.rows()*aw.mean()/2;
      }
    }
    standardize_lnP(prediction);
  }

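  // Convert the matrix of -lnP values into probabilities: center the
  // values to avoid overflow in exp(), exponentiate, and normalize so
  // that, for each sample, the class probabilities sum to one. NaN
  // entries (classes lacking training data for a feature) are ignored
  // via zero weights.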
  void NBC::standardize_lnP(utility::Matrix& prediction) const
  {
    /// -lnP might be a large number; to avoid out of bound problems
    /// when calculating P = exp(-(-lnP)), we center the matrix by
    /// subtracting a constant.
    // lookup of prediction with zero weights for NaNs
    MatrixLookupWeighted mlw(prediction);
    statistics::AveragerWeighted a;
    add(a, mlw.begin(), mlw.end());
    prediction -= a.mean();

    // exponentiate: prediction holds -lnP, so P = exp(-prediction)
    for (size_t i=0; i<prediction.rows(); ++i)
      for (size_t j=0; j<prediction.columns(); ++j)
        prediction(i,j) = std::exp(-prediction(i,j));

    // normalize each column (sample) so the class probabilities sum
    // to unity
    for (size_t j=0; j<prediction.columns(); ++j) {
      // calculate sum of column ignoring NaNs
      statistics::AveragerWeighted a;
      add(a, mlw.begin_column(j), mlw.end_column(j));
      for (size_t i=0; i<prediction.rows(); ++i)
        prediction(i,j) *= 1.0/a.sum_wx();
    }
  }


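  // Return sum_i log(sigma_i) for the given class, computed as half
  // the sum of log variances since sigma2_ stores variances.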
  double NBC::sum_logsigma(size_t label) const
  {
    double sum_log_sigma=0;
    assert(label<sigma2_.columns());
    for (size_t i=0; i<sigma2_.rows(); ++i) {
      sum_log_sigma += std::log(sigma2_(i, label));
    }
    // dividing the sum of log(sigma2) by 2 gives the sum of log(sigma)
    return sum_log_sigma / 2;
  }

}}} // of namespace classifier, yat, and theplu