source: trunk/yat/classifier/NBC.cc @ 959

Last change on this file since 959 was 959, checked in by Peter, 14 years ago

Fixed so NBC and SVM throw when an unexpected DataLookup2D is
passed to make_classifier or predict.

Speeding up NBC::predict by separating weighted code from
unweighted. Also fixed some bugs in NBC.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 5.6 KB
Line 
1// $Id: NBC.cc 959 2007-10-10 16:49:39Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
#include "NBC.h"
#include "DataLookup2D.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "Target.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/matrix.h"

#include <cassert>
#include <cmath>
#include <stdexcept>
#include <string>
#include <vector>
36
37namespace theplu {
38namespace yat {
39namespace classifier {
40
41  NBC::NBC(const MatrixLookup& data, const Target& target) 
42    : SupervisedClassifier(target), data_(data)
43  {
44  }
45
46  NBC::NBC(const MatrixLookupWeighted& data, const Target& target) 
47    : SupervisedClassifier(target), data_(data)
48  {
49  }
50
51  NBC::~NBC()   
52  {
53  }
54
55
56  const DataLookup2D& NBC::data(void) const
57  {
58    return data_;
59  }
60
61
62  SupervisedClassifier* 
63  NBC::make_classifier(const DataLookup2D& data, const Target& target) const 
64  {     
65    NBC* nbc=0;
66    try {
67      if(data.weighted()) {
68        nbc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target);
69      }
70      else {
71        nbc=new NBC(dynamic_cast<const MatrixLookup&>(data),target);
72      }     
73    }
74    catch (std::bad_cast) {
75      std::string str = 
76        "Error in NBC::make_classifier: DataLookup2D of unexpected class.";
77      throw std::runtime_error(str);
78    }
79    return nbc;
80  }
81
82
83  bool NBC::train()
84  {   
85    sigma2_.resize(data_.rows(), target_.nof_classes());
86    centroids_.resize(data_.rows(), target_.nof_classes());
87    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
88   
89    for(size_t i=0; i<data_.rows(); ++i) {
90      std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
91      for(size_t j=0; j<data_.columns(); ++j) {
92        if (data_.weighted()){
93          const MatrixLookupWeighted& data = 
94            dynamic_cast<const MatrixLookupWeighted&>(data_);
95          aver[target_(j)].add(data.data(i,j), data.weight(i,j));
96        }
97        else
98          aver[target_(j)].add(data_(i,j),1.0);
99      }
100      assert(centroids_.columns()==target_.nof_classes());
101      for (size_t j=0; j<target_.nof_classes(); ++j){
102        assert(i<centroids_.rows());
103        assert(j<centroids_.columns());
104        centroids_(i,j) = aver[j].mean();
105        assert(i<sigma2_.rows());
106        assert(j<sigma2_.columns());
107        sigma2_(i,j) = aver[j].variance();
108      }
109    }   
110    trained_=true;
111    return trained_;
112  }
113
114
115  void NBC::predict(const DataLookup2D& x,                   
116                    utility::matrix& prediction) const
117  {   
118    assert(data_.rows()==x.rows());
119    assert(x.rows()==sigma2_.rows());
120    assert(x.rows()==centroids_.rows());
121
122   
123   
124    // each row in prediction corresponds to a sample label (class)
125    prediction.resize(centroids_.columns(), x.columns(), 0);
126    // weighted calculation
127    if (const MatrixLookupWeighted* mlw = 
128        dynamic_cast<const MatrixLookupWeighted*>(&x)) {
129      // first calculate -lnP = sum ln_sigma_i + (x_i-m_i)^2/2sigma_i^2
130      for (size_t label=0; label<centroids_.columns(); ++label) {
131        double sum_log_sigma = sum_logsigma(label);
132        for (size_t sample=0; sample<prediction.rows(); ++sample) {
133          prediction(label,sample) = sum_log_sigma;
134          for (size_t i=0; i<x.rows(); ++i) 
135            // taking care of NaN
136            if (mlw->weight(i, label)) {
137              prediction(label, sample) += mlw->weight(i, label)*
138                std::pow(mlw->data(i, label)-centroids_(i, label),2)/
139                sigma2_(i, label);
140            }
141     
142        }
143      }
144    }
145      // no weights
146    else if (const MatrixLookup* ml = dynamic_cast<const MatrixLookup*>(&x)) {
147      // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
148      for (size_t label=0; label<centroids_.columns(); ++label) {
149        double sum_log_sigma = sum_logsigma(label);
150        for (size_t sample=0; sample<prediction.rows(); ++sample) {
151          prediction(label,sample) = sum_log_sigma;
152          for (size_t i=0; i<ml->rows(); ++i) 
153            prediction(label, sample) += 
154              std::pow((*ml)(i, label)-centroids_(i, label),2)/sigma2_(i, label);
155        }
156      }
157    }
158    else {
159      std::string str = 
160        "Error in NBC::predict: DataLookup2D of unexpected class.";
161      throw std::runtime_error(str);
162    }
163
164
165    // -lnP might be a large number, in order to avoid out of bound
166    // problems when calculating P = exp(- -lnP), we centralize matrix
167    // by adding a constant.
168    double m=0;
169    for (size_t i=0; i<prediction.rows(); ++i)
170      for (size_t j=0; j<prediction.columns(); ++j)
171        m+=prediction(i,j);
172    prediction -= m/prediction.rows()/prediction.columns();
173
174    // exponentiate
175    for (size_t i=0; i<prediction.rows(); ++i)
176      for (size_t j=0; j<prediction.columns(); ++j)
177        prediction(i,j) = std::exp(prediction(i,j));
178
179    // normalize each row (label) to sum up to unity (probability)
180    for (size_t i=0; i<prediction.rows(); ++i)
181      utility::vector(prediction,i) *= 
182        1.0/utility::sum(utility::vector(prediction,i));
183
184  }
185
186
187  double NBC::sum_logsigma(size_t label) const
188  {
189    double sum_log_sigma=0;
190    assert(label<sigma2_.columns());
191    for (size_t i=0; i<sigma2_.rows(); ++i) {
192      sum_log_sigma += std::log(sigma2_(i, label));
193    }
194    return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
195  }
196
197}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.