source: trunk/yat/classifier/NBC.cc @ 2909

Last change on this file since 2909 was 2909, checked in by Peter, 9 years ago

avoid creating a MatrixLookupWeighted from a Matrix as it involves copying the Matrix.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 6.1 KB
Line 
1// $Id: NBC.cc 2909 2012-12-16 12:57:13Z peter $
2
3/*
4  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
5
6  This file is part of the yat library, http://dev.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 3 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with yat. If not, see <http://www.gnu.org/licenses/>.
20*/
21
22#include <config.h>
23
24#include "NBC.h"
25#include "MatrixLookup.h"
26#include "MatrixLookupWeighted.h"
27#include "Target.h"
28#include "yat/statistics/Averager.h"
29#include "yat/statistics/AveragerWeighted.h"
30#include "yat/utility/Matrix.h"
31#include "yat/utility/WeightedIterator.h"
32
33#include <cassert>
34#include <cmath>
35#include <limits>
36#include <stdexcept>
37#include <vector>
38
39namespace theplu {
40namespace yat {
41namespace classifier {
42
43  NBC::NBC() 
44    : SupervisedClassifier()
45  {
46  }
47
48
49  NBC::~NBC()   
50  {
51  }
52
53
54  NBC* NBC::make_classifier() const 
55  {     
56    return new NBC();
57  }
58
59
60  void NBC::train(const MatrixLookup& data, const Target& target)
61  {   
62    sigma2_.resize(data.rows(), target.nof_classes());
63    centroids_.resize(data.rows(), target.nof_classes());
64   
65    for(size_t i=0; i<data.rows(); ++i) {
66      std::vector<statistics::Averager> aver(target.nof_classes());
67      for(size_t j=0; j<data.columns(); ++j) 
68        aver[target(j)].add(data(i,j));
69     
70      assert(centroids_.columns()==target.nof_classes());
71      for (size_t j=0; j<target.nof_classes(); ++j){
72        assert(i<centroids_.rows());
73        assert(j<centroids_.columns());
74        assert(i<sigma2_.rows());
75        assert(j<sigma2_.columns());
76        if (aver[j].n()>1){
77          sigma2_(i,j) = aver[j].variance();
78          centroids_(i,j) = aver[j].mean();
79        }
80        else {
81            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
82            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
83        }
84      }
85    }
86  }   
87
88
89  void NBC::train(const MatrixLookupWeighted& data, const Target& target)
90  {   
91    sigma2_.resize(data.rows(), target.nof_classes());
92    centroids_.resize(data.rows(), target.nof_classes());
93
94    for(size_t i=0; i<data.rows(); ++i) {
95      std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
96      for(size_t j=0; j<data.columns(); ++j) 
97        aver[target(j)].add(data.data(i,j), data.weight(i,j));
98     
99      assert(centroids_.columns()==target.nof_classes());
100      for (size_t j=0; j<target.nof_classes(); ++j) {
101        assert(i<centroids_.rows());
102        assert(j<centroids_.columns());
103        assert(i<sigma2_.rows());
104        assert(j<sigma2_.columns());
105        if (aver[j].n()>1){
106          sigma2_(i,j) = aver[j].variance();
107          centroids_(i,j) = aver[j].mean();
108        }
109        else {
110          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
111          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
112        }
113      }
114    }
115  }
116
117
118  void NBC::predict(const MatrixLookup& ml,                     
119                    utility::Matrix& prediction) const
120  {   
121    assert(ml.rows()==sigma2_.rows());
122    assert(ml.rows()==centroids_.rows());
123    // each row in prediction corresponds to a sample label (class)
124    prediction.resize(centroids_.columns(), ml.columns(), 0);
125
126    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
127    for (size_t label=0; label<centroids_.columns(); ++label) {
128      double sum_log_sigma = sum_logsigma(label);
129      for (size_t sample=0; sample<prediction.rows(); ++sample) {
130        prediction(label,sample) = sum_log_sigma;
131        for (size_t i=0; i<ml.rows(); ++i) 
132          prediction(label, sample) += 
133            std::pow(ml(i, label)-centroids_(i, label),2)/
134            sigma2_(i, label);
135      }
136    }
137    standardize_lnP(prediction);
138  }
139
140     
141  void NBC::predict(const MatrixLookupWeighted& mlw,                   
142                    utility::Matrix& prediction) const
143  {   
144    assert(mlw.rows()==sigma2_.rows());
145    assert(mlw.rows()==centroids_.rows());
146   
147    // each row in prediction corresponds to a sample label (class)
148    prediction.resize(centroids_.columns(), mlw.columns(), 0);
149
150    // first calculate -lnP = sum (sigma_i) +
151    // N sum w_i(x_i-m_i)^2/2sigma_i^2 / sum w_i
152    for (size_t label=0; label<centroids_.columns(); ++label) {
153      double sum_log_sigma = sum_logsigma(label);
154      for (size_t sample=0; sample<prediction.rows(); ++sample) {
155        statistics::AveragerWeighted aw;
156        for (size_t i=0; i<mlw.rows(); ++i) 
157          aw.add(std::pow(mlw.data(i, label)-centroids_(i, label),2)/
158                 sigma2_(i, label), mlw.weight(i, label));
159        prediction(label,sample) = sum_log_sigma + mlw.rows()*aw.mean()/2;
160      }
161    }
162    standardize_lnP(prediction);
163  }
164
165  void NBC::standardize_lnP(utility::Matrix& prediction) const
166  {
167    /// -lnP might be a large number, in order to avoid out of bound
168    /// problems when calculating P = exp(- -lnP), we centralize matrix
169    /// by adding a constant.
170    utility::Matrix weights;
171    // create zero/unity weight matrix (w=0 if NaN)
172    nan(prediction, weights);
173    using utility::weighted_iterator;
174    statistics::AveragerWeighted a;
175    add(a, weighted_iterator(prediction.begin(), weights.begin()),
176        weighted_iterator(prediction.end(), weights.end()));
177    prediction -= a.mean();
178
179    // exponentiate
180    for (size_t i=0; i<prediction.rows(); ++i)
181      for (size_t j=0; j<prediction.columns(); ++j)
182        prediction(i,j) = std::exp(prediction(i,j));
183
184    // normalize each row (label) to sum up to unity (probability)
185    for (size_t i=0; i<prediction.rows(); ++i){
186      // calculate sum of row ignoring NaNs
187      statistics::AveragerWeighted a;
188      add(a, weighted_iterator(prediction.begin_row(i), weights.begin_row(i)),
189          weighted_iterator(prediction.end_row(i), weights.end_row(i)));
190      prediction.row_view(i) *= 1.0/a.sum_wx();
191    }
192  }
193
194
195  double NBC::sum_logsigma(size_t label) const
196  {
197    double sum_log_sigma=0;
198    assert(label<sigma2_.columns());
199    for (size_t i=0; i<sigma2_.rows(); ++i) {
200      sum_log_sigma += std::log(sigma2_(i, label));
201    }
202    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
203  }
204
205}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.