source: trunk/yat/classifier/NBC.cc @ 1162

Last change on this file since 1162 was 1162, checked in by Peter, 14 years ago

removing trained_

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 6.1 KB
Line 
1// $Id: NBC.cc 1162 2008-02-26 16:24:11Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "NBC.h"
26#include "MatrixLookup.h"
27#include "MatrixLookupWeighted.h"
28#include "Target.h"
29#include "yat/statistics/Averager.h"
30#include "yat/statistics/AveragerWeighted.h"
31#include "yat/utility/Matrix.h"
32
33#include <cassert>
34#include <cmath>
35#include <stdexcept>
36#include <vector>
37
38namespace theplu {
39namespace yat {
40namespace classifier {
41
42  NBC::NBC() 
43    : SupervisedClassifier()
44  {
45  }
46
47
48  NBC::~NBC()   
49  {
50  }
51
52
53  NBC* NBC::make_classifier() const 
54  {     
55    return new NBC();
56  }
57
58
59  void NBC::train(const MatrixLookup& data, const Target& target)
60  {   
61    sigma2_.resize(data.rows(), target.nof_classes());
62    centroids_.resize(data.rows(), target.nof_classes());
63    utility::Matrix nof_in_class(data.rows(), target.nof_classes());
64   
65    for(size_t i=0; i<data.rows(); ++i) {
66      std::vector<statistics::Averager> aver(target.nof_classes());
67      for(size_t j=0; j<data.columns(); ++j) 
68        aver[target(j)].add(data(i,j));
69     
70      assert(centroids_.columns()==target.nof_classes());
71      for (size_t j=0; j<target.nof_classes(); ++j){
72        assert(i<centroids_.rows());
73        assert(j<centroids_.columns());
74        centroids_(i,j) = aver[j].mean();
75        assert(i<sigma2_.rows());
76        assert(j<sigma2_.columns());
77        if (aver[j].n()>1){
78          sigma2_(i,j) = aver[j].variance();
79          centroids_(i,j) = aver[j].mean();
80        }
81          else {
82            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
83            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
84          }
85      }
86    }
87  }   
88
89
90  void NBC::train(const MatrixLookupWeighted& data, const Target& target)
91  {   
92    sigma2_.resize(data.rows(), target.nof_classes());
93    centroids_.resize(data.rows(), target.nof_classes());
94    utility::Matrix nof_in_class(data.rows(), target.nof_classes());
95
96    for(size_t i=0; i<data.rows(); ++i) {
97      std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
98      for(size_t j=0; j<data.columns(); ++j) 
99        aver[target(j)].add(data.data(i,j), data.weight(i,j));
100     
101      assert(centroids_.columns()==target.nof_classes());
102      for (size_t j=0; j<target.nof_classes(); ++j) {
103        assert(i<centroids_.rows());
104        assert(j<centroids_.columns());
105        assert(i<sigma2_.rows());
106        assert(j<sigma2_.columns());
107        if (aver[j].n()>1){
108          sigma2_(i,j) = aver[j].variance();
109          centroids_(i,j) = aver[j].mean();
110        }
111        else {
112          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
113          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
114        }
115      }
116    }
117  }
118
119
120  void NBC::predict(const MatrixLookup& ml,                     
121                    utility::Matrix& prediction) const
122  {   
123    assert(ml.rows()==sigma2_.rows());
124    assert(ml.rows()==centroids_.rows());
125    // each row in prediction corresponds to a sample label (class)
126    prediction.resize(centroids_.columns(), ml.columns(), 0);
127
128    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
129    for (size_t label=0; label<centroids_.columns(); ++label) {
130      double sum_log_sigma = sum_logsigma(label);
131      for (size_t sample=0; sample<prediction.rows(); ++sample) {
132        prediction(label,sample) = sum_log_sigma;
133        for (size_t i=0; i<ml.rows(); ++i) 
134          // Ignoring missing features
135          if (!std::isnan(sigma2_(i, label)))
136            prediction(label, sample) += 
137              std::pow(ml(i, label)-centroids_(i, label),2)/
138              sigma2_(i, label);
139      }
140    }
141    standardize_lnP(prediction);
142  }
143
144     
145  void NBC::predict(const MatrixLookupWeighted& mlw,                   
146                    utility::Matrix& prediction) const
147  {   
148    assert(mlw.rows()==sigma2_.rows());
149    assert(mlw.rows()==centroids_.rows());
150   
151    // each row in prediction corresponds to a sample label (class)
152    prediction.resize(centroids_.columns(), mlw.columns(), 0);
153
154    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
155    for (size_t label=0; label<centroids_.columns(); ++label) {
156      double sum_log_sigma = sum_logsigma(label);
157      for (size_t sample=0; sample<prediction.rows(); ++sample) {
158        prediction(label,sample) = sum_log_sigma;
159        for (size_t i=0; i<mlw.rows(); ++i) 
160          // taking care of NaN and missing training features
161          if (mlw.weight(i, label) && !std::isnan(sigma2_(i, label))) {
162            prediction(label, sample) += mlw.weight(i, label)*
163              std::pow(mlw.data(i, label)-centroids_(i, label),2)/
164              sigma2_(i, label);
165          }
166       
167      }
168    }
169    standardize_lnP(prediction);
170  }
171
172  void NBC::standardize_lnP(utility::Matrix& prediction) const
173  {
174    // -lnP might be a large number, in order to avoid out of bound
175    // problems when calculating P = exp(- -lnP), we centralize matrix
176    // by adding a constant.
177    statistics::Averager a;
178    add(a, prediction.begin(), prediction.end());
179    prediction -= a.mean();
180   
181    // exponentiate
182    for (size_t i=0; i<prediction.rows(); ++i)
183      for (size_t j=0; j<prediction.columns(); ++j)
184        prediction(i,j) = std::exp(prediction(i,j));
185   
186    // normalize each row (label) to sum up to unity (probability)
187    for (size_t i=0; i<prediction.rows(); ++i){
188      prediction.row_view(i) *= 1.0/sum(prediction.row_const_view(i));
189    }
190  }
191
192
193  double NBC::sum_logsigma(size_t label) const
194  {
195    double sum_log_sigma=0;
196    assert(label<sigma2_.columns());
197    for (size_t i=0; i<sigma2_.rows(); ++i) {
198      if (!std::isnan(sigma2_(i,label))) 
199        sum_log_sigma += std::log(sigma2_(i, label));
200    }
201    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
202  }
203
204}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.