source: trunk/yat/classifier/NBC.cc @ 1182

Last change on this file since 1182 was 1182, checked in by Peter, 14 years ago

working on #335. Fixed the weighted test-data case. Left to fix is the case of missing features in training — in other words, what should happen when complete training cannot be done because of a lack of data. The current behavior is probably not optimal, but I have to look into it in more detail.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 6.1 KB
Line 
1// $Id: NBC.cc 1182 2008-02-28 12:27:37Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
#include "NBC.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "Target.h"
#include "yat/statistics/Averager.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/Matrix.h"

#include <cassert>
#include <cmath>
#include <limits>
#include <stdexcept>
#include <vector>
37
38namespace theplu {
39namespace yat {
40namespace classifier {
41
  // Default constructor; delegates to the SupervisedClassifier base.
  // The classifier is unusable until train() has been called.
  NBC::NBC() 
    : SupervisedClassifier()
  {
  }
46
47
  // Destructor; no resources to release (members clean up themselves).
  NBC::~NBC()   
  {
  }
51
52
  // Factory method creating a new, untrained NBC. The caller takes
  // ownership of the returned pointer.
  NBC* NBC::make_classifier() const 
  {     
    return new NBC();
  }
57
58
59  void NBC::train(const MatrixLookup& data, const Target& target)
60  {   
61    sigma2_.resize(data.rows(), target.nof_classes());
62    centroids_.resize(data.rows(), target.nof_classes());
63    utility::Matrix nof_in_class(data.rows(), target.nof_classes());
64   
65    for(size_t i=0; i<data.rows(); ++i) {
66      std::vector<statistics::Averager> aver(target.nof_classes());
67      for(size_t j=0; j<data.columns(); ++j) 
68        aver[target(j)].add(data(i,j));
69     
70      assert(centroids_.columns()==target.nof_classes());
71      for (size_t j=0; j<target.nof_classes(); ++j){
72        assert(i<centroids_.rows());
73        assert(j<centroids_.columns());
74        centroids_(i,j) = aver[j].mean();
75        assert(i<sigma2_.rows());
76        assert(j<sigma2_.columns());
77        if (aver[j].n()>1){
78          sigma2_(i,j) = aver[j].variance();
79          centroids_(i,j) = aver[j].mean();
80        }
81          else {
82            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
83            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
84          }
85      }
86    }
87  }   
88
89
90  void NBC::train(const MatrixLookupWeighted& data, const Target& target)
91  {   
92    sigma2_.resize(data.rows(), target.nof_classes());
93    centroids_.resize(data.rows(), target.nof_classes());
94    utility::Matrix nof_in_class(data.rows(), target.nof_classes());
95
96    for(size_t i=0; i<data.rows(); ++i) {
97      std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
98      for(size_t j=0; j<data.columns(); ++j) 
99        aver[target(j)].add(data.data(i,j), data.weight(i,j));
100     
101      assert(centroids_.columns()==target.nof_classes());
102      for (size_t j=0; j<target.nof_classes(); ++j) {
103        assert(i<centroids_.rows());
104        assert(j<centroids_.columns());
105        assert(i<sigma2_.rows());
106        assert(j<sigma2_.columns());
107        if (aver[j].n()>1){
108          sigma2_(i,j) = aver[j].variance();
109          centroids_(i,j) = aver[j].mean();
110        }
111        else {
112          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
113          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
114        }
115      }
116    }
117  }
118
119
120  void NBC::predict(const MatrixLookup& ml,                     
121                    utility::Matrix& prediction) const
122  {   
123    assert(ml.rows()==sigma2_.rows());
124    assert(ml.rows()==centroids_.rows());
125    // each row in prediction corresponds to a sample label (class)
126    prediction.resize(centroids_.columns(), ml.columns(), 0);
127
128    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
129    for (size_t label=0; label<centroids_.columns(); ++label) {
130      double sum_log_sigma = sum_logsigma(label);
131      for (size_t sample=0; sample<prediction.rows(); ++sample) {
132        prediction(label,sample) = sum_log_sigma;
133        for (size_t i=0; i<ml.rows(); ++i) 
134          // Ignoring missing features
135          if (!std::isnan(sigma2_(i, label)))
136            prediction(label, sample) += 
137              std::pow(ml(i, label)-centroids_(i, label),2)/
138              sigma2_(i, label);
139      }
140    }
141    standardize_lnP(prediction);
142  }
143
144     
145  void NBC::predict(const MatrixLookupWeighted& mlw,                   
146                    utility::Matrix& prediction) const
147  {   
148    assert(mlw.rows()==sigma2_.rows());
149    assert(mlw.rows()==centroids_.rows());
150   
151    // each row in prediction corresponds to a sample label (class)
152    prediction.resize(centroids_.columns(), mlw.columns(), 0);
153
154    // first calculate -lnP = sum (sigma_i) +
155    // N sum w_i(x_i-m_i)^2/2sigma_i^2 / sum w_i
156    for (size_t label=0; label<centroids_.columns(); ++label) {
157      double sum_log_sigma = sum_logsigma(label);
158      for (size_t sample=0; sample<prediction.rows(); ++sample) {
159        statistics::AveragerWeighted aw;
160        for (size_t i=0; i<mlw.rows(); ++i) 
161          // missing training features
162          if (!std::isnan(sigma2_(i, label))) 
163            aw.add(std::pow(mlw.data(i, label)-centroids_(i, label),2)/
164                   sigma2_(i, label), mlw.weight(i, label));
165        prediction(label,sample) = sum_log_sigma + mlw.rows()*aw.mean()/2;
166      }
167    }
168    standardize_lnP(prediction);
169  }
170
171  void NBC::standardize_lnP(utility::Matrix& prediction) const
172  {
173    // -lnP might be a large number, in order to avoid out of bound
174    // problems when calculating P = exp(- -lnP), we centralize matrix
175    // by adding a constant.
176    statistics::Averager a;
177    add(a, prediction.begin(), prediction.end());
178    prediction -= a.mean();
179   
180    // exponentiate
181    for (size_t i=0; i<prediction.rows(); ++i)
182      for (size_t j=0; j<prediction.columns(); ++j)
183        prediction(i,j) = std::exp(prediction(i,j));
184   
185    // normalize each row (label) to sum up to unity (probability)
186    for (size_t i=0; i<prediction.rows(); ++i){
187      prediction.row_view(i) *= 1.0/sum(prediction.row_const_view(i));
188    }
189  }
190
191
192  double NBC::sum_logsigma(size_t label) const
193  {
194    double sum_log_sigma=0;
195    assert(label<sigma2_.columns());
196    for (size_t i=0; i<sigma2_.rows(); ++i) {
197      if (!std::isnan(sigma2_(i,label))) 
198        sum_log_sigma += std::log(sigma2_(i, label));
199    }
200    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
201  }
202
203}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.