source: trunk/yat/classifier/NBC.cc

Last change on this file was 2919, checked in by Peter, 9 years ago

update copyright years

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 6.1 KB
Line 
1// $Id: NBC.cc 2919 2012-12-19 06:54:23Z peter $
2
3/*
4  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
5  Copyright (C) 2012 Peter Johansson
6
7  This file is part of the yat library, http://dev.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 3 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with yat. If not, see <http://www.gnu.org/licenses/>.
21*/
22
23#include <config.h>
24
25#include "NBC.h"
26#include "MatrixLookup.h"
27#include "MatrixLookupWeighted.h"
28#include "Target.h"
29#include "yat/statistics/Averager.h"
30#include "yat/statistics/AveragerWeighted.h"
31#include "yat/utility/Matrix.h"
32#include "yat/utility/WeightedIterator.h"
33
34#include <cassert>
35#include <cmath>
36#include <limits>
37#include <stdexcept>
38#include <vector>
39
40namespace theplu {
41namespace yat {
42namespace classifier {
43
  // Default constructor. Delegates to the SupervisedClassifier base;
  // the centroid and variance matrices (centroids_, sigma2_) stay
  // empty until train() is called.
  NBC::NBC() 
    : SupervisedClassifier()
  {
  }
48
49
  // Destructor. Nothing to release explicitly; members clean up via RAII.
  NBC::~NBC()   
  {
  }
53
54
  // Factory method (virtual constructor idiom inherited from
  // SupervisedClassifier): returns a new, untrained NBC. The caller
  // takes ownership of the returned raw pointer — the base-class
  // interface fixes the raw-pointer return type.
  NBC* NBC::make_classifier() const 
  {     
    return new NBC();
  }
59
60
61  void NBC::train(const MatrixLookup& data, const Target& target)
62  {   
63    sigma2_.resize(data.rows(), target.nof_classes());
64    centroids_.resize(data.rows(), target.nof_classes());
65   
66    for(size_t i=0; i<data.rows(); ++i) {
67      std::vector<statistics::Averager> aver(target.nof_classes());
68      for(size_t j=0; j<data.columns(); ++j) 
69        aver[target(j)].add(data(i,j));
70     
71      assert(centroids_.columns()==target.nof_classes());
72      for (size_t j=0; j<target.nof_classes(); ++j){
73        assert(i<centroids_.rows());
74        assert(j<centroids_.columns());
75        assert(i<sigma2_.rows());
76        assert(j<sigma2_.columns());
77        if (aver[j].n()>1){
78          sigma2_(i,j) = aver[j].variance();
79          centroids_(i,j) = aver[j].mean();
80        }
81        else {
82            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
83            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
84        }
85      }
86    }
87  }   
88
89
90  void NBC::train(const MatrixLookupWeighted& data, const Target& target)
91  {   
92    sigma2_.resize(data.rows(), target.nof_classes());
93    centroids_.resize(data.rows(), target.nof_classes());
94
95    for(size_t i=0; i<data.rows(); ++i) {
96      std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
97      for(size_t j=0; j<data.columns(); ++j) 
98        aver[target(j)].add(data.data(i,j), data.weight(i,j));
99     
100      assert(centroids_.columns()==target.nof_classes());
101      for (size_t j=0; j<target.nof_classes(); ++j) {
102        assert(i<centroids_.rows());
103        assert(j<centroids_.columns());
104        assert(i<sigma2_.rows());
105        assert(j<sigma2_.columns());
106        if (aver[j].n()>1){
107          sigma2_(i,j) = aver[j].variance();
108          centroids_(i,j) = aver[j].mean();
109        }
110        else {
111          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
112          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
113        }
114      }
115    }
116  }
117
118
119  void NBC::predict(const MatrixLookup& ml,                     
120                    utility::Matrix& prediction) const
121  {   
122    assert(ml.rows()==sigma2_.rows());
123    assert(ml.rows()==centroids_.rows());
124    // each row in prediction corresponds to a sample label (class)
125    prediction.resize(centroids_.columns(), ml.columns(), 0);
126
127    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
128    for (size_t label=0; label<centroids_.columns(); ++label) {
129      double sum_log_sigma = sum_logsigma(label);
130      for (size_t sample=0; sample<prediction.rows(); ++sample) {
131        prediction(label,sample) = sum_log_sigma;
132        for (size_t i=0; i<ml.rows(); ++i) 
133          prediction(label, sample) += 
134            std::pow(ml(i, label)-centroids_(i, label),2)/
135            sigma2_(i, label);
136      }
137    }
138    standardize_lnP(prediction);
139  }
140
141     
142  void NBC::predict(const MatrixLookupWeighted& mlw,                   
143                    utility::Matrix& prediction) const
144  {   
145    assert(mlw.rows()==sigma2_.rows());
146    assert(mlw.rows()==centroids_.rows());
147   
148    // each row in prediction corresponds to a sample label (class)
149    prediction.resize(centroids_.columns(), mlw.columns(), 0);
150
151    // first calculate -lnP = sum (sigma_i) +
152    // N sum w_i(x_i-m_i)^2/2sigma_i^2 / sum w_i
153    for (size_t label=0; label<centroids_.columns(); ++label) {
154      double sum_log_sigma = sum_logsigma(label);
155      for (size_t sample=0; sample<prediction.rows(); ++sample) {
156        statistics::AveragerWeighted aw;
157        for (size_t i=0; i<mlw.rows(); ++i) 
158          aw.add(std::pow(mlw.data(i, label)-centroids_(i, label),2)/
159                 sigma2_(i, label), mlw.weight(i, label));
160        prediction(label,sample) = sum_log_sigma + mlw.rows()*aw.mean()/2;
161      }
162    }
163    standardize_lnP(prediction);
164  }
165
166  void NBC::standardize_lnP(utility::Matrix& prediction) const
167  {
168    /// -lnP might be a large number, in order to avoid out of bound
169    /// problems when calculating P = exp(- -lnP), we centralize matrix
170    /// by adding a constant.
171    utility::Matrix weights;
172    // create zero/unity weight matrix (w=0 if NaN)
173    nan(prediction, weights);
174    using utility::weighted_iterator;
175    statistics::AveragerWeighted a;
176    add(a, weighted_iterator(prediction.begin(), weights.begin()),
177        weighted_iterator(prediction.end(), weights.end()));
178    prediction -= a.mean();
179
180    // exponentiate
181    for (size_t i=0; i<prediction.rows(); ++i)
182      for (size_t j=0; j<prediction.columns(); ++j)
183        prediction(i,j) = std::exp(prediction(i,j));
184
185    // normalize each row (label) to sum up to unity (probability)
186    for (size_t i=0; i<prediction.rows(); ++i){
187      // calculate sum of row ignoring NaNs
188      statistics::AveragerWeighted a;
189      add(a, weighted_iterator(prediction.begin_row(i), weights.begin_row(i)),
190          weighted_iterator(prediction.end_row(i), weights.end_row(i)));
191      prediction.row_view(i) *= 1.0/a.sum_wx();
192    }
193  }
194
195
196  double NBC::sum_logsigma(size_t label) const
197  {
198    double sum_log_sigma=0;
199    assert(label<sigma2_.columns());
200    for (size_t i=0; i<sigma2_.rows(); ++i) {
201      sum_log_sigma += std::log(sigma2_(i, label));
202    }
203    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
204  }
205
206}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.