source: trunk/yat/classifier/NBC.cc @ 865

Last change on this file since 865 was 865, checked in by Peter, 14 years ago

changing URL to http://trac.thep.lu.se/trac/yat

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.8 KB
Line 
1// $Id: NBC.cc 865 2007-09-10 19:41:04Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "NBC.h"
26#include "DataLookup2D.h"
27#include "MatrixLookup.h"
28#include "MatrixLookupWeighted.h"
29#include "Target.h"
30#include "yat/statistics/AveragerWeighted.h"
31#include "yat/utility/matrix.h"
32
33#include <cassert>
34#include <cmath>
35#include <vector>
36
37namespace theplu {
38namespace yat {
39namespace classifier {
40
41  NBC::NBC(const MatrixLookup& data, const Target& target) 
42    : SupervisedClassifier(target), data_(data)
43  {
44  }
45
46  NBC::NBC(const MatrixLookupWeighted& data, const Target& target) 
47    : SupervisedClassifier(target), data_(data)
48  {
49  }
50
51  NBC::~NBC()   
52  {
53  }
54
55
56  const DataLookup2D& NBC::data(void) const
57  {
58    return data_;
59  }
60
61
62  SupervisedClassifier* 
63  NBC::make_classifier(const DataLookup2D& data, const Target& target) const 
64  {     
65    NBC* ncc=0;
66    if(data.weighted()) {
67      ncc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target);
68    }
69    else {
70      ncc=new NBC(dynamic_cast<const MatrixLookup&>(data),target);
71    }
72    return ncc;
73  }
74
75
76  bool NBC::train()
77  {   
78    sigma2_.resize(data_.rows(), target_.nof_classes());
79    centroids_.resize(data_.rows(), target_.nof_classes());
80    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
81   
82    for(size_t i=0; i<data_.rows(); ++i) {
83      std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
84      for(size_t j=0; j<data_.columns(); ++j) {
85        if (data_.weighted()){
86          const MatrixLookupWeighted& data = 
87            dynamic_cast<const MatrixLookupWeighted&>(data_);
88          aver[target_(j)].add(data.data(i,j), data.weight(i,j));
89        }
90        else
91          aver[target_(j)].add(data_(i,j),1.0);
92      }
93      assert(centroids_.columns()==target_.nof_classes());
94      for (size_t j=0; j<target_.nof_classes(); ++j){
95        assert(i<centroids_.rows());
96        assert(j<centroids_.columns());
97        centroids_(i,j) = aver[j].mean();
98        assert(i<sigma2_.rows());
99        assert(j<sigma2_.columns());
100        sigma2_(i,j) = aver[j].variance();
101      }
102    }   
103    trained_=true;
104    return trained_;
105  }
106
107
108  void NBC::predict(const DataLookup2D& x,                   
109                    utility::matrix& prediction) const
110  {   
111    assert(data_.rows()==x.rows());
112    assert(x.rows()==sigma2_.rows());
113    assert(x.rows()==centroids_.rows());
114
115    const MatrixLookupWeighted* w = 
116      dynamic_cast<const MatrixLookupWeighted*>(&x);
117
118    // each row in prediction corresponds to a sample label (class)
119    prediction.resize(centroids_.columns(), x.columns(), 0);
120    // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
121    for (size_t label=0; label<centroids_.columns(); ++label) {
122      double sum_ln_sigma=0;
123      assert(label<sigma2_.columns());
124      for (size_t i=0; i<x.rows(); ++i) {
125        assert(i<sigma2_.rows());
126        sum_ln_sigma += std::log(sigma2_(i, label));
127      }
128      sum_ln_sigma /= 2; // taking sum of log(sigma) not sigma2
129      for (size_t sample=0; sample<prediction.rows(); ++sample) {
130        for (size_t i=0; i<x.rows(); ++i) {
131          // weighted calculation
132          if (w){
133            // taking care of NaN
134            if (w->weight(i, label)){
135            prediction(label, sample) += w->weight(i, label)*
136              std::pow(w->data(i, label)-centroids_(i, label),2)/
137              sigma2_(i, label);
138            }
139          }
140          // no weights
141          else {
142            prediction(label, sample) += 
143              std::pow(x(i, label)-centroids_(i, label),2)/sigma2_(i, label);
144          }
145        }
146      }
147    }
148
149    // -lnP might be a large number, in order to avoid out of bound
150    // problems when calculating P = exp(- -lnP), we centralize matrix
151    // by adding a constant.
152    double m=0;
153    for (size_t i=0; i<prediction.rows(); ++i)
154      for (size_t j=0; j<prediction.columns(); ++j)
155        m+=prediction(i,j);
156    prediction -= m/prediction.rows()/prediction.columns();
157
158    // exponentiate
159    for (size_t i=0; i<prediction.rows(); ++i)
160      for (size_t j=0; j<prediction.columns(); ++j)
161        prediction(i,j) = std::exp(prediction(i,j));
162
163    // normalize each row (label) to sum up to unity (probability)
164    for (size_t i=0; i<prediction.rows(); ++i)
165      utility::vector(prediction,i) *= 
166        1.0/utility::sum(utility::vector(prediction,i));
167
168  }
169
170
171}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.