source: trunk/yat/classifier/NBC.cc @ 1580

// $Id: NBC.cc 1487 2008-09-10 08:41:36Z jari $

/*
  Copyright (C) 2006, 2007 Jari Häkkinen, Peter Johansson, Markus Ringnér
  Copyright (C) 2008 Peter Johansson, Markus Ringnér

  This file is part of the yat library, http://dev.thep.lu.se/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 3 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with yat. If not, see <http://www.gnu.org/licenses/>.
*/

#include "NBC.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "Target.h"
#include "yat/statistics/Averager.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/Matrix.h"

#include <cassert>
#include <cmath>
#include <limits>
#include <stdexcept>
#include <vector>
namespace theplu {
namespace yat {
namespace classifier {

  NBC::NBC()
    : SupervisedClassifier()
  {
  }


  NBC::~NBC()
  {
  }


  NBC* NBC::make_classifier() const
  {
    return new NBC();
  }


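  // Train on unweighted data: for each feature (row in data) and each
  // class in target, estimate the class-conditional mean (centroids_)
  // and variance (sigma2_). Classes with fewer than two samples get
  // NaN for both estimates.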
  void NBC::train(const MatrixLookup& data, const Target& target)
  {
    sigma2_.resize(data.rows(), target.nof_classes());
    centroids_.resize(data.rows(), target.nof_classes());

    for(size_t i=0; i<data.rows(); ++i) {
      std::vector<statistics::Averager> aver(target.nof_classes());
      for(size_t j=0; j<data.columns(); ++j)
        aver[target(j)].add(data(i,j));

      assert(centroids_.columns()==target.nof_classes());
      for (size_t j=0; j<target.nof_classes(); ++j){
        assert(i<centroids_.rows());
        assert(j<centroids_.columns());
        assert(i<sigma2_.rows());
        assert(j<sigma2_.columns());
        if (aver[j].n()>1){
          sigma2_(i,j) = aver[j].variance();
          centroids_(i,j) = aver[j].mean();
        }
        else {
          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
        }
      }
    }
  }


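  // Train on weighted data: as above, but every data point contributes
  // according to its weight, using AveragerWeighted for the
  // class-conditional mean and variance estimates.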
  void NBC::train(const MatrixLookupWeighted& data, const Target& target)
  {
    sigma2_.resize(data.rows(), target.nof_classes());
    centroids_.resize(data.rows(), target.nof_classes());

    for(size_t i=0; i<data.rows(); ++i) {
      std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
      for(size_t j=0; j<data.columns(); ++j)
        aver[target(j)].add(data.data(i,j), data.weight(i,j));

      assert(centroids_.columns()==target.nof_classes());
      for (size_t j=0; j<target.nof_classes(); ++j) {
        assert(i<centroids_.rows());
        assert(j<centroids_.columns());
        assert(i<sigma2_.rows());
        assert(j<sigma2_.columns());
        if (aver[j].n()>1){
          sigma2_(i,j) = aver[j].variance();
          centroids_(i,j) = aver[j].mean();
        }
        else {
          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
        }
      }
    }
  }


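  // Predict from unweighted data by evaluating, for each class and each
  // sample, the negative log-likelihood of a diagonal Gaussian (up to
  // an additive constant):
  //
  //   -lnP(x|label) = sum_i [ log(sigma_i) + (x_i - m_i)^2 / (2*sigma_i^2) ]
  //
  // where m_i and sigma_i^2 are the centroid and variance of feature i
  // for that label. standardize_lnP then maps these values to
  // probabilities.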
  void NBC::predict(const MatrixLookup& ml,
                    utility::Matrix& prediction) const
  {
    assert(ml.rows()==sigma2_.rows());
    assert(ml.rows()==centroids_.rows());
    // each row in prediction corresponds to a sample label (class)
    prediction.resize(centroids_.columns(), ml.columns(), 0);

    // first calculate -lnP = sum_i [ log(sigma_i) + (x_i-m_i)^2/(2*sigma_i^2) ]
    for (size_t label=0; label<centroids_.columns(); ++label) {
      double sum_log_sigma = sum_logsigma(label);
      for (size_t sample=0; sample<prediction.columns(); ++sample) {
        prediction(label,sample) = sum_log_sigma;
        for (size_t i=0; i<ml.rows(); ++i)
          prediction(label, sample) +=
            std::pow(ml(i, sample)-centroids_(i, label),2)/
            (2*sigma2_(i, label));
      }
    }
    standardize_lnP(prediction);
  }


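  // Predict from weighted data: the squared deviations are averaged
  // with the data weights (AveragerWeighted) and rescaled by the number
  // of features, so missing or down-weighted values contribute less to
  // -lnP.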
  void NBC::predict(const MatrixLookupWeighted& mlw,
                    utility::Matrix& prediction) const
  {
    assert(mlw.rows()==sigma2_.rows());
    assert(mlw.rows()==centroids_.rows());

    // each row in prediction corresponds to a sample label (class)
    prediction.resize(centroids_.columns(), mlw.columns(), 0);

    // first calculate -lnP = sum_i log(sigma_i) +
    // N * [sum_i w_i*(x_i-m_i)^2/(2*sigma_i^2)] / sum_i w_i
    for (size_t label=0; label<centroids_.columns(); ++label) {
      double sum_log_sigma = sum_logsigma(label);
      for (size_t sample=0; sample<prediction.columns(); ++sample) {
        statistics::AveragerWeighted aw;
        for (size_t i=0; i<mlw.rows(); ++i)
          aw.add(std::pow(mlw.data(i, sample)-centroids_(i, label),2)/
                 sigma2_(i, label), mlw.weight(i, sample));
        prediction(label,sample) = sum_log_sigma + mlw.rows()*aw.mean()/2;
      }
    }
    standardize_lnP(prediction);
  }

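  // Turn the matrix of -lnP values into probabilities: shift by the
  // (NaN-ignoring) mean to keep exp() within range, exponentiate, and
  // normalize each row to unit sum, skipping NaN entries.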
  void NBC::standardize_lnP(utility::Matrix& prediction) const
  {
    /// -lnP might be a large number; to avoid out-of-bound problems
    /// when calculating P = exp(-(-lnP)), we centralize the matrix by
    /// shifting it with a constant.
    // lookup of prediction with zero weights for NaNs
    MatrixLookupWeighted mlw(prediction);
    statistics::AveragerWeighted a;
    add(a, mlw.begin(), mlw.end());
    prediction -= a.mean();

    // exponentiate: the matrix holds -lnP, so P is proportional to exp(-value)
    for (size_t i=0; i<prediction.rows(); ++i)
      for (size_t j=0; j<prediction.columns(); ++j)
        prediction(i,j) = std::exp(-prediction(i,j));

    // normalize each row (label) to sum up to unity (probability)
    for (size_t i=0; i<prediction.rows(); ++i){
      // calculate sum of row ignoring NaNs
      statistics::AveragerWeighted a;
      add(a, mlw.begin_row(i), mlw.end_row(i));
      prediction.row_view(i) *= 1.0/a.sum_wx();
    }
  }


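  // Sum of log(sigma) over all features for a given class. sigma2_
  // holds variances, so sum_i log(sigma_i) equals
  // sum_i log(sigma2_(i,label)) / 2.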
  double NBC::sum_logsigma(size_t label) const
  {
    double sum_log_sigma=0;
    assert(label<sigma2_.columns());
    for (size_t i=0; i<sigma2_.rows(); ++i) {
      sum_log_sigma += std::log(sigma2_(i, label));
    }
    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
  }

}}} // of namespace classifier, yat, and theplu