source: trunk/yat/classifier/NBC.cc @ 961

Last change on this file since 961 was 961, checked in by Peter, 14 years ago

correcting NBC train. refs ticket:271

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 6.6 KB
Line 
1// $Id: NBC.cc 961 2007-10-10 17:51:25Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "NBC.h"
26#include "DataLookup2D.h"
27#include "MatrixLookup.h"
28#include "MatrixLookupWeighted.h"
29#include "Target.h"
30#include "yat/statistics/AveragerWeighted.h"
31#include "yat/utility/matrix.h"
32
33#include <cassert>
34#include <cmath>
35#include <vector>
36
37namespace theplu {
38namespace yat {
39namespace classifier {
40
41  NBC::NBC(const MatrixLookup& data, const Target& target) 
42    : SupervisedClassifier(target), data_(data)
43  {
44  }
45
46  NBC::NBC(const MatrixLookupWeighted& data, const Target& target) 
47    : SupervisedClassifier(target), data_(data)
48  {
49  }
50
51  NBC::~NBC()   
52  {
53  }
54
55
56  const DataLookup2D& NBC::data(void) const
57  {
58    return data_;
59  }
60
61
62  SupervisedClassifier* 
63  NBC::make_classifier(const DataLookup2D& data, const Target& target) const 
64  {     
65    NBC* nbc=0;
66    try {
67      if(data.weighted()) {
68        nbc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target);
69      }
70      else {
71        nbc=new NBC(dynamic_cast<const MatrixLookup&>(data),target);
72      }     
73    }
74    catch (std::bad_cast) {
75      std::string str = 
76        "Error in NBC::make_classifier: DataLookup2D of unexpected class.";
77      throw std::runtime_error(str);
78    }
79    return nbc;
80  }
81
82
83  bool NBC::train()
84  {   
85    sigma2_.resize(data_.rows(), target_.nof_classes());
86    centroids_.resize(data_.rows(), target_.nof_classes());
87    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
88   
89    // unweighted
90    if (data_.weighted()){
91      const MatrixLookupWeighted& data = 
92        dynamic_cast<const MatrixLookupWeighted&>(data_);
93      for(size_t i=0; i<data_.rows(); ++i) {
94        std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
95        for(size_t j=0; j<data_.columns(); ++j) 
96          aver[target_(j)].add(data.data(i,j), data.weight(i,j));
97
98        assert(centroids_.columns()==target_.nof_classes());
99        for (size_t j=0; j<target_.nof_classes(); ++j){
100          assert(i<centroids_.rows());
101          assert(j<centroids_.columns());
102          centroids_(i,j) = aver[j].mean();
103          assert(i<sigma2_.rows());
104          assert(j<sigma2_.columns());
105          if (aver[j].n()>1)
106            sigma2_(i,j) = aver[j].variance();
107          else 
108            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
109        }
110      }
111    }
112    else { 
113      const MatrixLookup& data = dynamic_cast<const MatrixLookup&>(data_);
114      for(size_t i=0; i<data_.rows(); ++i) {
115        std::vector<statistics::Averager> aver(target_.nof_classes());
116        for(size_t j=0; j<data_.columns(); ++j) 
117          aver[target_(j)].add(data(i,j));
118
119        assert(centroids_.columns()==target_.nof_classes());
120        for (size_t j=0; j<target_.nof_classes(); ++j){
121          assert(i<centroids_.rows());
122          assert(j<centroids_.columns());
123          centroids_(i,j) = aver[j].mean();
124          assert(i<sigma2_.rows());
125          assert(j<sigma2_.columns());
126          if (aver[j].n()>1)
127            sigma2_(i,j) = aver[j].variance();
128          else 
129            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
130        }
131      }
132    }   
133    trained_=true;
134    return trained_;
135  }
136
137
138  void NBC::predict(const DataLookup2D& x,                   
139                    utility::matrix& prediction) const
140  {   
141    assert(data_.rows()==x.rows());
142    assert(x.rows()==sigma2_.rows());
143    assert(x.rows()==centroids_.rows());
144
145   
146   
147    // each row in prediction corresponds to a sample label (class)
148    prediction.resize(centroids_.columns(), x.columns(), 0);
149    // weighted calculation
150    if (const MatrixLookupWeighted* mlw = 
151        dynamic_cast<const MatrixLookupWeighted*>(&x)) {
152      // first calculate -lnP = sum ln_sigma_i + (x_i-m_i)^2/2sigma_i^2
153      for (size_t label=0; label<centroids_.columns(); ++label) {
154        double sum_log_sigma = sum_logsigma(label);
155        for (size_t sample=0; sample<prediction.rows(); ++sample) {
156          prediction(label,sample) = sum_log_sigma;
157          for (size_t i=0; i<x.rows(); ++i) 
158            // taking care of NaN and missing training features
159            if (mlw->weight(i, label) && !std::isnan(sigma2_(i, label))) {
160              prediction(label, sample) += mlw->weight(i, label)*
161                std::pow(mlw->data(i, label)-centroids_(i, label),2)/
162                sigma2_(i, label);
163            }
164     
165        }
166      }
167    }
168      // no weights
169    else if (const MatrixLookup* ml = dynamic_cast<const MatrixLookup*>(&x)) {
170      // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
171      for (size_t label=0; label<centroids_.columns(); ++label) {
172        double sum_log_sigma = sum_logsigma(label);
173        for (size_t sample=0; sample<prediction.rows(); ++sample) {
174          prediction(label,sample) = sum_log_sigma;
175          for (size_t i=0; i<ml->rows(); ++i) 
176            // Ignoring missing features
177            if (!std::isnan(sigma2_(i, label)))
178              prediction(label, sample) += 
179                std::pow((*ml)(i, label)-centroids_(i, label),2)/
180                sigma2_(i, label);
181        }
182      }
183    }
184    else {
185      std::string str = 
186        "Error in NBC::predict: DataLookup2D of unexpected class.";
187      throw std::runtime_error(str);
188    }
189
190
191    // -lnP might be a large number, in order to avoid out of bound
192    // problems when calculating P = exp(- -lnP), we centralize matrix
193    // by adding a constant.
194    double m=0;
195    for (size_t i=0; i<prediction.rows(); ++i)
196      for (size_t j=0; j<prediction.columns(); ++j)
197        m+=prediction(i,j);
198    prediction -= m/prediction.rows()/prediction.columns();
199
200    // exponentiate
201    for (size_t i=0; i<prediction.rows(); ++i)
202      for (size_t j=0; j<prediction.columns(); ++j)
203        prediction(i,j) = std::exp(prediction(i,j));
204
205    // normalize each row (label) to sum up to unity (probability)
206    for (size_t i=0; i<prediction.rows(); ++i)
207      utility::vector(prediction,i) *= 
208        1.0/utility::sum(utility::vector(prediction,i));
209
210  }
211
212
213  double NBC::sum_logsigma(size_t label) const
214  {
215    double sum_log_sigma=0;
216    assert(label<sigma2_.columns());
217    for (size_t i=0; i<sigma2_.rows(); ++i) {
218      if (!std::isnan(sigma2_(i,label))) 
219        sum_log_sigma += std::log(sigma2_(i, label));
220    }
221    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
222  }
223
224}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.