source: trunk/yat/classifier/NBC.cc @ 812

Last change on this file since 812 was 812, checked in by Peter, 15 years ago

implemented NBC predict. Have to check what happens in the weighted case though, Refs #57

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.0 KB
// $Id: NBC.cc 812 2007-03-16 01:02:07Z peter $

/*
  Copyright (C) The authors contributing to this file.

  This file is part of the yat library, http://lev.thep.lu.se/trac/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 2 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
  02111-1307, USA.
*/

#include "NBC.h"
#include "DataLookup2D.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "Target.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/matrix.h"

#include <cassert>
#include <cmath>
#include <vector>

namespace theplu {
namespace yat {
namespace classifier {

  NBC::NBC(const MatrixLookup& data, const Target& target)
    : SupervisedClassifier(target), data_(data)
  {
  }

  NBC::NBC(const MatrixLookupWeighted& data, const Target& target)
    : SupervisedClassifier(target), data_(data)
  {
  }

  NBC::~NBC()
  {
  }


  const DataLookup2D& NBC::data(void) const
  {
    return data_;
  }


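  // Create a new NBC for the given data and target; the weighted or
  // unweighted constructor is chosen depending on whether the lookup
  // carries weights.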
  SupervisedClassifier*
  NBC::make_classifier(const DataLookup2D& data, const Target& target) const
  {
    NBC* ncc=0;
    if(data.weighted()) {
      ncc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target);
    }
    else {
      ncc=new NBC(dynamic_cast<const MatrixLookup&>(data),target);
    }
    return ncc;
  }


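  // Train the classifier: for every feature (row in data_) and every
  // class, estimate the mean (centroid) and variance of the feature
  // within that class. Weighted data contribute to the averager with
  // their weights; unweighted data get weight 1.0.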
  bool NBC::train()
  {
    sigma2_.resize(data_.rows(), target_.nof_classes());
    centroids_.resize(data_.rows(), target_.nof_classes());
    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());

    for(size_t i=0; i<data_.rows(); ++i) {
      // one averager per class for feature i
      std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
      for(size_t j=0; j<data_.columns(); ++j) {
        if (data_.weighted()){
          const MatrixLookupWeighted& data =
            dynamic_cast<const MatrixLookupWeighted&>(data_);
          aver[target_(j)].add(data.data(i,j), data.weight(i,j));
        }
        else
          aver[target_(j)].add(data_(i,j),1.0);
      }
      for (size_t j=0; j<target_.nof_classes(); ++j){
        centroids_(i,j) = aver[j].mean();
        sigma2_(i,j) = aver[j].variance();
      }
    }
    trained_=true;
    return trained_;
  }


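  // Predict class membership for each column (sample) in x. For each
  // class the negative log-likelihood under the class-conditional
  // Gaussians,
  //   -lnP = sum_i ( ln(sigma_i) + (x_i - m_i)^2 / (2 sigma_i^2) ),
  // is accumulated and then transformed into probabilities. How
  // weighted data should enter here is still an open question (Refs #57).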
  void NBC::predict(const DataLookup2D& x,
                    utility::matrix& prediction) const
  {
    assert(data_.rows()==x.rows());

    // each row in prediction corresponds to a label (class), each
    // column to a sample
    prediction.resize(centroids_.columns(), x.columns(), 0);
    // first calculate -lnP = sum ln(sigma_i) + (x_i-m_i)^2/(2 sigma_i^2)
    for (size_t label=0; label<prediction.rows(); ++label) {
      double sum_ln_sigma=0;
      for (size_t i=0; i<x.rows(); ++i)
        sum_ln_sigma += std::log(sigma2_(i, label));
      sum_ln_sigma /= 2; // taking sum of log(sigma) not sigma2
      for (size_t sample=0; sample<prediction.columns(); ++sample) {
        for (size_t i=0; i<x.rows(); ++i) {
          prediction(label, sample) +=
            std::pow(x(i, sample)-centroids_(i, label),2) /
            (2*sigma2_(i, label));
        }
        prediction(label, sample) += sum_ln_sigma;
      }
    }

    // -lnP might be a large number; in order to avoid out of bound
    // problems when calculating P = exp(-(-lnP)), we centralize the
    // matrix by adding a constant (subtracting the mean).
    double m=0;
    for (size_t i=0; i<prediction.rows(); ++i)
      for (size_t j=0; j<prediction.columns(); ++j)
        m+=prediction(i,j);
    prediction -= m/prediction.rows()/prediction.columns();

    // exponentiate: prediction holds -lnP, so P is proportional to
    // exp(-prediction)
    for (size_t i=0; i<prediction.rows(); ++i)
      for (size_t j=0; j<prediction.columns(); ++j)
        prediction(i,j) = std::exp(-prediction(i,j));

    // normalize each row (label) to sum up to unity (probability)
    for (size_t i=0; i<prediction.rows(); ++i)
      utility::vector(prediction,i) *=
        1.0/utility::sum(utility::vector(prediction,i));

  }


}}} // of namespace classifier, yat, and theplu