source: trunk/yat/classifier/NBC.cc @ 813

Last change on this file since 813 was 813, checked in by Peter, 15 years ago

Predict in NBC. Fixes #57

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.7 KB
// $Id: NBC.cc 813 2007-03-16 19:30:02Z peter $

/*
  Copyright (C) The authors contributing to this file.

  This file is part of the yat library, http://lev.thep.lu.se/trac/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 2 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
  02111-1307, USA.
*/

#include "NBC.h"
#include "DataLookup2D.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "Target.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/matrix.h"

#include <cassert>
#include <cmath>
#include <vector>

namespace theplu {
namespace yat {
namespace classifier {

  NBC::NBC(const MatrixLookup& data, const Target& target)
    : SupervisedClassifier(target), data_(data)
  {
  }

  NBC::NBC(const MatrixLookupWeighted& data, const Target& target)
    : SupervisedClassifier(target), data_(data)
  {
  }

  NBC::~NBC()
  {
  }


  const DataLookup2D& NBC::data(void) const
  {
    return data_;
  }


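  // Creates a new NBC for the given data and target. The lookup is
  // inspected at runtime: a weighted lookup is passed to the weighted
  // constructor, otherwise the unweighted constructor is used.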
  SupervisedClassifier*
  NBC::make_classifier(const DataLookup2D& data, const Target& target) const
  {
    NBC* ncc=0;
    if(data.weighted()) {
      ncc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target);
    }
    else {
      ncc=new NBC(dynamic_cast<const MatrixLookup&>(data),target);
    }
    return ncc;
  }


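  // Estimate the model parameters: for every feature (row in data_)
  // and every class, an AveragerWeighted collects the (weighted)
  // values of the training samples belonging to that class, and the
  // resulting mean and variance are stored in centroids_(feature,
  // class) and sigma2_(feature, class), respectively.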
  bool NBC::train()
  {
    sigma2_.resize(data_.rows(), target_.nof_classes());
    centroids_.resize(data_.rows(), target_.nof_classes());
    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());

    for(size_t i=0; i<data_.rows(); ++i) {
      std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
      for(size_t j=0; j<data_.columns(); ++j) {
        if (data_.weighted()){
          const MatrixLookupWeighted& data =
            dynamic_cast<const MatrixLookupWeighted&>(data_);
          aver[target_(j)].add(data.data(i,j), data.weight(i,j));
        }
        else
          aver[target_(j)].add(data_(i,j),1.0);
      }
      assert(centroids_.columns()==target_.nof_classes());
      for (size_t j=0; j<target_.nof_classes(); ++j){
        assert(i<centroids_.rows());
        assert(j<centroids_.columns());
        centroids_(i,j) = aver[j].mean();
        assert(i<sigma2_.rows());
        assert(j<sigma2_.columns());
        sigma2_(i,j) = aver[j].variance();
      }
    }
    trained_=true;
    return trained_;
  }


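  // Fill prediction with one row per class (label) and one column per
  // sample in x. For each entry the negative log-likelihood under
  // independent Gaussians,
  //
  //   -lnP = sum_i [ ln(sigma_i) + (x_i - m_i)^2 / (2*sigma_i^2) ],
  //
  // is evaluated using the centroids and variances estimated in
  // train(); the values are then exponentiated and normalized.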
  void NBC::predict(const DataLookup2D& x,
                    utility::matrix& prediction) const
  {
    assert(data_.rows()==x.rows());
    assert(x.rows()==sigma2_.rows());
    assert(x.rows()==centroids_.rows());

    const MatrixLookupWeighted* w =
      dynamic_cast<const MatrixLookupWeighted*>(&x);

    // each row in prediction corresponds to a class (label), each
    // column to a sample in x
    prediction.resize(centroids_.columns(), x.columns(), 0);
    // first calculate -lnP = sum ln(sigma_i) + (x_i-m_i)^2/(2*sigma_i^2)
    for (size_t label=0; label<centroids_.columns(); ++label) {
      double sum_ln_sigma=0;
      assert(label<sigma2_.columns());
      for (size_t i=0; i<x.rows(); ++i) {
        assert(i<sigma2_.rows());
        sum_ln_sigma += std::log(sigma2_(i, label));
      }
      sum_ln_sigma /= 2; // ln(sigma) = ln(sigma2)/2
      for (size_t sample=0; sample<prediction.columns(); ++sample) {
        // start from the class-dependent normalization term
        prediction(label, sample) = sum_ln_sigma;
        for (size_t i=0; i<x.rows(); ++i) {
          // weighted calculation
          if (w){
            // taking care of NaN (zero weight)
            if (w->weight(i, sample)){
              prediction(label, sample) += w->weight(i, sample)*
                std::pow(w->data(i, sample)-centroids_(i, label),2)/
                (2*sigma2_(i, label));
            }
          }
          // no weights
          else {
            prediction(label, sample) +=
              std::pow(x(i, sample)-centroids_(i, label),2)/
              (2*sigma2_(i, label));
          }
        }
      }
    }

    // -lnP might be a large number; to avoid overflow/underflow when
    // calculating P = exp(-(-lnP)), centralize the matrix by
    // subtracting its mean.
    double m=0;
    for (size_t i=0; i<prediction.rows(); ++i)
      for (size_t j=0; j<prediction.columns(); ++j)
        m+=prediction(i,j);
    prediction -= m/prediction.rows()/prediction.columns();

    // exponentiate: prediction holds -lnP, so P is exp(-prediction)
    for (size_t i=0; i<prediction.rows(); ++i)
      for (size_t j=0; j<prediction.columns(); ++j)
        prediction(i,j) = std::exp(-prediction(i,j));

    // normalize each row (label) to sum up to unity (probability)
    for (size_t i=0; i<prediction.rows(); ++i)
      utility::vector(prediction,i) *=
        1.0/utility::sum(utility::vector(prediction,i));

  }


}}} // of namespace classifier, yat, and theplu
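
For orientation, here is a minimal usage sketch of the class defined in this file. The NBC constructor, train() and predict() calls mirror the definitions above; the header paths, the MatrixLookup and Target constructors, and all variable names are assumptions made for illustration and are not confirmed by this revision.

#include "yat/classifier/MatrixLookup.h"
#include "yat/classifier/NBC.h"
#include "yat/classifier/Target.h"
#include "yat/utility/matrix.h"

#include <string>
#include <vector>

using namespace theplu::yat;

int main()
{
  // toy data: 2 features (rows) x 4 samples (columns)
  utility::matrix raw(2, 4);
  raw(0,0)=0.1; raw(0,1)=0.2; raw(0,2)=2.1; raw(0,3)=2.2;
  raw(1,0)=1.0; raw(1,1)=1.1; raw(1,2)=3.0; raw(1,3)=3.1;
  classifier::MatrixLookup data(raw);            // assumed constructor

  std::vector<std::string> labels;               // one label per sample
  labels.push_back("A"); labels.push_back("A");
  labels.push_back("B"); labels.push_back("B");
  classifier::Target target(labels);             // assumed constructor

  classifier::NBC nbc(data, target);
  nbc.train();                                   // estimate centroids and variances

  utility::matrix prediction(1, 1);              // resized by predict
  nbc.predict(data, prediction);                 // rows: classes, columns: samples
  return 0;
}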