source: trunk/yat/classifier/NBC.cc @ 2881

Last change on this file since 2881 was 2881, checked in by Peter, 9 years ago

Define PP variables in config.h rather than in CPPFLAGS. Include
config.h into all source files. Only amend CXXFLAGS with '-Wall
-pedantic' when --enable-debug. In default mode we respect CXXFLAGS
value set by user, or set to default value '-O3'.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 5.8 KB
Line 
1// $Id: NBC.cc 2881 2012-11-18 01:28:05Z peter $
2
3/*
4  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
5
6  This file is part of the yat library, http://dev.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 3 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with yat. If not, see <http://www.gnu.org/licenses/>.
20*/
21
22#include <config.h>
23
24#include "NBC.h"
25#include "MatrixLookup.h"
26#include "MatrixLookupWeighted.h"
27#include "Target.h"
28#include "yat/statistics/Averager.h"
29#include "yat/statistics/AveragerWeighted.h"
30#include "yat/utility/Matrix.h"
31
32#include <cassert>
33#include <cmath>
34#include <limits>
35#include <stdexcept>
36#include <vector>
37
38namespace theplu {
39namespace yat {
40namespace classifier {
41
42  NBC::NBC() 
43    : SupervisedClassifier()
44  {
45  }
46
47
48  NBC::~NBC()   
49  {
50  }
51
52
53  NBC* NBC::make_classifier() const 
54  {     
55    return new NBC();
56  }
57
58
59  void NBC::train(const MatrixLookup& data, const Target& target)
60  {   
61    sigma2_.resize(data.rows(), target.nof_classes());
62    centroids_.resize(data.rows(), target.nof_classes());
63   
64    for(size_t i=0; i<data.rows(); ++i) {
65      std::vector<statistics::Averager> aver(target.nof_classes());
66      for(size_t j=0; j<data.columns(); ++j) 
67        aver[target(j)].add(data(i,j));
68     
69      assert(centroids_.columns()==target.nof_classes());
70      for (size_t j=0; j<target.nof_classes(); ++j){
71        assert(i<centroids_.rows());
72        assert(j<centroids_.columns());
73        assert(i<sigma2_.rows());
74        assert(j<sigma2_.columns());
75        if (aver[j].n()>1){
76          sigma2_(i,j) = aver[j].variance();
77          centroids_(i,j) = aver[j].mean();
78        }
79        else {
80            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
81            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
82        }
83      }
84    }
85  }   
86
87
88  void NBC::train(const MatrixLookupWeighted& data, const Target& target)
89  {   
90    sigma2_.resize(data.rows(), target.nof_classes());
91    centroids_.resize(data.rows(), target.nof_classes());
92
93    for(size_t i=0; i<data.rows(); ++i) {
94      std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
95      for(size_t j=0; j<data.columns(); ++j) 
96        aver[target(j)].add(data.data(i,j), data.weight(i,j));
97     
98      assert(centroids_.columns()==target.nof_classes());
99      for (size_t j=0; j<target.nof_classes(); ++j) {
100        assert(i<centroids_.rows());
101        assert(j<centroids_.columns());
102        assert(i<sigma2_.rows());
103        assert(j<sigma2_.columns());
104        if (aver[j].n()>1){
105          sigma2_(i,j) = aver[j].variance();
106          centroids_(i,j) = aver[j].mean();
107        }
108        else {
109          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
110          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
111        }
112      }
113    }
114  }
115
116
117  void NBC::predict(const MatrixLookup& ml,                     
118                    utility::Matrix& prediction) const
119  {   
120    assert(ml.rows()==sigma2_.rows());
121    assert(ml.rows()==centroids_.rows());
122    // each row in prediction corresponds to a sample label (class)
123    prediction.resize(centroids_.columns(), ml.columns(), 0);
124
125    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
126    for (size_t label=0; label<centroids_.columns(); ++label) {
127      double sum_log_sigma = sum_logsigma(label);
128      for (size_t sample=0; sample<prediction.rows(); ++sample) {
129        prediction(label,sample) = sum_log_sigma;
130        for (size_t i=0; i<ml.rows(); ++i) 
131          prediction(label, sample) += 
132            std::pow(ml(i, label)-centroids_(i, label),2)/
133            sigma2_(i, label);
134      }
135    }
136    standardize_lnP(prediction);
137  }
138
139     
140  void NBC::predict(const MatrixLookupWeighted& mlw,                   
141                    utility::Matrix& prediction) const
142  {   
143    assert(mlw.rows()==sigma2_.rows());
144    assert(mlw.rows()==centroids_.rows());
145   
146    // each row in prediction corresponds to a sample label (class)
147    prediction.resize(centroids_.columns(), mlw.columns(), 0);
148
149    // first calculate -lnP = sum (sigma_i) +
150    // N sum w_i(x_i-m_i)^2/2sigma_i^2 / sum w_i
151    for (size_t label=0; label<centroids_.columns(); ++label) {
152      double sum_log_sigma = sum_logsigma(label);
153      for (size_t sample=0; sample<prediction.rows(); ++sample) {
154        statistics::AveragerWeighted aw;
155        for (size_t i=0; i<mlw.rows(); ++i) 
156          aw.add(std::pow(mlw.data(i, label)-centroids_(i, label),2)/
157                 sigma2_(i, label), mlw.weight(i, label));
158        prediction(label,sample) = sum_log_sigma + mlw.rows()*aw.mean()/2;
159      }
160    }
161    standardize_lnP(prediction);
162  }
163
164  void NBC::standardize_lnP(utility::Matrix& prediction) const
165  {
166    /// -lnP might be a large number, in order to avoid out of bound
167    /// problems when calculating P = exp(- -lnP), we centralize matrix
168    /// by adding a constant.
169    // lookup of prediction with zero weights for NaNs
170    MatrixLookupWeighted mlw(prediction);
171    statistics::AveragerWeighted a;
172    add(a, mlw.begin(), mlw.end());
173    prediction -= a.mean();
174   
175    // exponentiate
176    for (size_t i=0; i<prediction.rows(); ++i)
177      for (size_t j=0; j<prediction.columns(); ++j)
178        prediction(i,j) = std::exp(prediction(i,j));
179   
180    // normalize each row (label) to sum up to unity (probability)
181    for (size_t i=0; i<prediction.rows(); ++i){
182      // calculate sum of row ignoring NaNs
183      statistics::AveragerWeighted a;
184      add(a, mlw.begin_row(i), mlw.end_row(i));
185      prediction.row_view(i) *= 1.0/a.sum_wx();
186    }
187  }
188
189
190  double NBC::sum_logsigma(size_t label) const
191  {
192    double sum_log_sigma=0;
193    assert(label<sigma2_.columns());
194    for (size_t i=0; i<sigma2_.rows(); ++i) {
195      sum_log_sigma += std::log(sigma2_(i, label));
196    }
197    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
198  }
199
200}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.