source: trunk/yat/classifier/Perceptron.cc @ 3709

Last change on this file since 3709 was 3709, checked in by Peter, 4 years ago

new class that implements a Perceptron - special case of multivariate logistic regression. closes #901

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 3.2 KB
Line 
1// $Id: Perceptron.cc 3709 2017-11-08 22:49:06Z peter $
2
3/*
4  Copyright (C) 2017 Peter Johansson
5
6  This file is part of the yat library, http://dev.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 3 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with yat. If not, see <http://www.gnu.org/licenses/>.
20*/
21
22#include <config.h>
23
24#include "Perceptron.h"
25
26#include "Target.h"
27
28#include "yat/utility/DiagonalMatrix.h"
29#include "yat/utility/Matrix.h"
30#include "yat/utility/Vector.h"
31
32#include <gsl/gsl_cdf.h>
33
34#include <cassert>
35#include <cmath>
36#include <cmath>
37
38namespace theplu {
39namespace yat {
40namespace classifier {
41
42  const utility::Matrix& Perceptron::covariance(void) const
43  {
44    return covariance_;
45  }
46
47
48  double Perceptron::margin(size_t i, double alpha) const
49  {
50    return gsl_cdf_ugaussian_Qinv(alpha/2) * std::sqrt(covariance_(i, i));
51  }
52
53
54  double Perceptron::oddsratio(size_t i) const
55  {
56    return std::exp(weight_(i));
57  }
58
59
60  double Perceptron::oddsratio_lower_CI(size_t i, double alpha) const
61  {
62    return std::exp(weight_(i) - margin(i, alpha));
63  }
64
65
66  double Perceptron::oddsratio_upper_CI(size_t i, double alpha) const
67  {
68    return std::exp(weight_(i) + margin(i, alpha));
69  }
70
71
72  double Perceptron::p_value(size_t i) const
73  {
74    double z = weight_(i) / std::sqrt(covariance_(i, i));
75    return 2*gsl_cdf_ugaussian_Q(std::abs(z));
76  }
77
78
79  double Perceptron::predict(const utility::VectorBase& x) const
80  {
81    assert(x.size() == weight_.size());
82    const double f = weight_ * x;
83    return 1.0 / (1 + std::exp(-f));
84  }
85
86
87  void Perceptron::train(const utility::Matrix& X, const Target& target)
88  {
89    size_t n = X.rows();
90    size_t p = X.columns();
91
92    assert(target.size() == n);
93    weight_.resize(p);
94    covariance_.resize(p, p);
95
96    // weight vector is updated as
97    // w = (X'SX)^-1 X' (SXw + y - mu)
98    // X is n x p
99    // mu is vector of (trained) expected values (see predict(1))
100    utility::Vector mu(n);
101    // S is diagonal n x n with S_ii = mu_i (1 - mu_i)
102    utility::DiagonalMatrix S(n, n);
103    // y is binary vector
104    utility::Vector y(n);
105    for (size_t i=0; i<n; ++i)
106      if (target.binary(i))
107        y(i) = 1.0;
108
109    size_t max_epochs = 100;
110    double sum_squared = 1.0; // some (relatively) large number
111    for (size_t epoch=0; sum_squared > 1e-8 && epoch < max_epochs; ++epoch) {
112      for (size_t i=0; i<mu.size(); ++i) {
113        mu(i) = predict(X.row_const_view(i));
114        S(i) = mu(i) * (1.0 - mu(i));
115      }
116
117      // w = (X'SX)^-1 X' (SXw + y - mu)
118      assert(X.rows() == S.rows());
119      assert(S.columns() == X.rows());
120      utility::inverse_svd(transpose(X)*S*X, covariance_);
121
122      assert(y.size() == mu.size());
123      utility::Vector delta = covariance_ * (transpose(X) * (y - mu));
124      weight_ += delta;
125      sum_squared = delta * delta;
126    }
127  }
128
129
130  const utility::Vector& Perceptron::weight(void) const
131  {
132    return weight_;
133  }
134
135
136}}}
Note: See TracBrowser for help on using the repository browser.