source: trunk/yat/classifier/NBC.cc @ 1144

Last change on this file since 1144 was 1144, checked in by Markus Ringnér, 14 years ago

Fixes #334

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 6.7 KB
Line 
1// $Id: NBC.cc 1144 2008-02-25 16:51:58Z markus $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "NBC.h"
26#include "DataLookup2D.h"
27#include "MatrixLookup.h"
28#include "MatrixLookupWeighted.h"
29#include "Target.h"
30#include "yat/statistics/AveragerWeighted.h"
31#include "yat/utility/Matrix.h"
32
33#include <cassert>
34#include <cmath>
35#include <stdexcept>
36#include <vector>
37
38namespace theplu {
39namespace yat {
40namespace classifier {
41
42  NBC::NBC(const MatrixLookup& data, const Target& target) 
43    : SupervisedClassifier(target), data_(data)
44  {
45  }
46
47  NBC::NBC(const MatrixLookupWeighted& data, const Target& target) 
48    : SupervisedClassifier(target), data_(data)
49  {
50  }
51
52  NBC::~NBC()   
53  {
54  }
55
56
57  const DataLookup2D& NBC::data(void) const
58  {
59    return data_;
60  }
61
62
63  NBC* 
64  NBC::make_classifier(const DataLookup2D& data, const Target& target) const 
65  {     
66    NBC* nbc=0;
67    try {
68      if(data.weighted()) {
69        nbc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target);
70      }
71      else {
72        nbc=new NBC(dynamic_cast<const MatrixLookup&>(data),target);
73      }     
74    }
75    catch (std::bad_cast) {
76      std::string str = 
77        "Error in NBC::make_classifier: DataLookup2D of unexpected class.";
78      throw std::runtime_error(str);
79    }
80    return nbc;
81  }
82
83
84  void NBC::train()
85  {   
86    sigma2_.resize(data_.rows(), target_.nof_classes());
87    centroids_.resize(data_.rows(), target_.nof_classes());
88    utility::Matrix nof_in_class(data_.rows(), target_.nof_classes());
89   
90    // unweighted
91    if (data_.weighted()){
92      const MatrixLookupWeighted& data = 
93        dynamic_cast<const MatrixLookupWeighted&>(data_);
94      for(size_t i=0; i<data_.rows(); ++i) {
95        std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
96        for(size_t j=0; j<data_.columns(); ++j) 
97          aver[target_(j)].add(data.data(i,j), data.weight(i,j));
98
99        assert(centroids_.columns()==target_.nof_classes());
100        for (size_t j=0; j<target_.nof_classes(); ++j){
101          assert(i<centroids_.rows());
102          assert(j<centroids_.columns());
103          assert(i<sigma2_.rows());
104          assert(j<sigma2_.columns());
105          if (aver[j].n()>1){
106            sigma2_(i,j) = aver[j].variance();
107            centroids_(i,j) = aver[j].mean();
108          }
109          else {
110            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
111            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
112          }
113        }
114      }
115    }
116    else { 
117      const MatrixLookup& data = dynamic_cast<const MatrixLookup&>(data_);
118      for(size_t i=0; i<data_.rows(); ++i) {
119        std::vector<statistics::Averager> aver(target_.nof_classes());
120        for(size_t j=0; j<data_.columns(); ++j) 
121          aver[target_(j)].add(data(i,j));
122
123        assert(centroids_.columns()==target_.nof_classes());
124        for (size_t j=0; j<target_.nof_classes(); ++j){
125          assert(i<centroids_.rows());
126          assert(j<centroids_.columns());
127          centroids_(i,j) = aver[j].mean();
128          assert(i<sigma2_.rows());
129          assert(j<sigma2_.columns());
130          if (aver[j].n()>1){
131            sigma2_(i,j) = aver[j].variance();
132            centroids_(i,j) = aver[j].mean();
133          }
134          else {
135            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
136            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
137          }
138        }
139      }
140    }   
141    trained_=true;
142  }
143
144
145  void NBC::predict(const DataLookup2D& x,                   
146                    utility::Matrix& prediction) const
147  {   
148    assert(data_.rows()==x.rows());
149    assert(x.rows()==sigma2_.rows());
150    assert(x.rows()==centroids_.rows());
151   
152   
153    // each row in prediction corresponds to a sample label (class)
154    prediction.resize(centroids_.columns(), x.columns(), 0);
155    // weighted calculation
156    if (const MatrixLookupWeighted* mlw = 
157        dynamic_cast<const MatrixLookupWeighted*>(&x)) {
158      // first calculate -lnP = sum ln_sigma_i + (x_i-m_i)^2/2sigma_i^2
159      for (size_t label=0; label<centroids_.columns(); ++label) {
160        double sum_log_sigma = sum_logsigma(label);
161        for (size_t sample=0; sample<prediction.rows(); ++sample) {
162          prediction(label,sample) = sum_log_sigma;
163          for (size_t i=0; i<x.rows(); ++i) 
164            // taking care of NaN and missing training features
165            if (mlw->weight(i, label) && !std::isnan(sigma2_(i, label))) {
166              prediction(label, sample) += mlw->weight(i, label)*
167                std::pow(mlw->data(i, label)-centroids_(i, label),2)/
168                sigma2_(i, label);
169            }
170     
171        }
172      }
173    }
174      // no weights
175    else if (const MatrixLookup* ml = dynamic_cast<const MatrixLookup*>(&x)) {
176      // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
177      for (size_t label=0; label<centroids_.columns(); ++label) {
178        double sum_log_sigma = sum_logsigma(label);
179        for (size_t sample=0; sample<prediction.rows(); ++sample) {
180          prediction(label,sample) = sum_log_sigma;
181          for (size_t i=0; i<ml->rows(); ++i) 
182            // Ignoring missing features
183            if (!std::isnan(sigma2_(i, label)))
184              prediction(label, sample) += 
185                std::pow((*ml)(i, label)-centroids_(i, label),2)/
186                sigma2_(i, label);
187        }
188      }
189    }
190    else {
191      std::string str = 
192        "Error in NBC::predict: DataLookup2D of unexpected class.";
193      throw std::runtime_error(str);
194    }
195
196
197    // -lnP might be a large number, in order to avoid out of bound
198    // problems when calculating P = exp(- -lnP), we centralize matrix
199    // by adding a constant.
200    statistics::Averager a;
201    add(a, prediction.begin(), prediction.end());
202    prediction -= a.mean();
203
204    // exponentiate
205    for (size_t i=0; i<prediction.rows(); ++i)
206      for (size_t j=0; j<prediction.columns(); ++j)
207        prediction(i,j) = std::exp(prediction(i,j));
208
209    // normalize each row (label) to sum up to unity (probability)
210    for (size_t i=0; i<prediction.rows(); ++i){
211      prediction.row_view(i) *= 1.0/sum(prediction.row_const_view(i));
212    }
213  }
214
215
216  double NBC::sum_logsigma(size_t label) const
217  {
218    double sum_log_sigma=0;
219    assert(label<sigma2_.columns());
220    for (size_t i=0; i<sigma2_.rows(); ++i) {
221      if (!std::isnan(sigma2_(i,label))) 
222        sum_log_sigma += std::log(sigma2_(i, label));
223    }
224    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
225  }
226
227}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.