source: trunk/yat/regression/MultiDimensionalWeighted.cc @ 1650

Last change on this file since 1650 was 1487, checked in by Jari Häkkinen, 13 years ago

Addresses #436. GPL license copy reference should also be updated.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 3.1 KB
Line 
1// $Id: MultiDimensionalWeighted.cc 1487 2008-09-10 08:41:36Z jari $
2
3/*
4  Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson
5
6  This file is part of the yat library, http://dev.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 3 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with yat. If not, see <http://www.gnu.org/licenses/>.
20*/
21
#include "MultiDimensionalWeighted.h"
#include "yat/statistics/AveragerWeighted.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/Vector.h"

#include <cassert>
#include <sstream>
#include <string>
28
29namespace theplu {
30namespace yat {
31namespace regression {
32
33  MultiDimensionalWeighted::MultiDimensionalWeighted(void)
34    : chisquare_(0), work_(NULL)
35  {
36  }
37
38  MultiDimensionalWeighted::~MultiDimensionalWeighted(void)
39  {
40    if (work_)
41      gsl_multifit_linear_free(work_);
42  }
43
44
45  double MultiDimensionalWeighted::chisq() const
46  {
47    return chisquare_;
48  }
49
50
51  void MultiDimensionalWeighted::fit(const utility::Matrix& x, 
52                                     const utility::VectorBase& y,
53                                     const utility::VectorBase& w)
54  {
55    assert(y.size()==w.size());
56    assert(x.rows()==y.size());
57
58    covariance_.resize(x.columns(),x.columns());
59    fit_parameters_ = utility::Vector(x.columns());
60    if (work_)
61      gsl_multifit_linear_free(work_);
62    if (!(work_=gsl_multifit_linear_alloc(x.rows(),fit_parameters_.size())))
63      throw utility::GSL_error("MultiDimensionalWeighted::fit failed to allocate memory");
64    int status = gsl_multifit_wlinear(x.gsl_matrix_p(), w.gsl_vector_p(),
65                                      y.gsl_vector_p(), 
66                                      fit_parameters_.gsl_vector_p(),
67                                      covariance_.gsl_matrix_p(), &chisquare_,
68                                      work_);
69    if (status)
70      throw utility::GSL_error(std::string("MultiDimensionalWeighted::fit",
71                                           status));
72
73    statistics::AveragerWeighted aw;
74    add(aw, y.begin(), y.end(), w.begin());
75    s2_ = chisquare_ / (aw.n()-fit_parameters_.size());
76    covariance_ *= s2_;
77  }
78
79
80  const utility::Vector& MultiDimensionalWeighted::fit_parameters(void) const
81  {
82    return fit_parameters_;
83  }
84
85
86  double MultiDimensionalWeighted::predict(const utility::VectorBase& x) const
87  {
88    assert(x.size()==fit_parameters_.size());
89    return fit_parameters_ * x;
90  }
91
92
93  double MultiDimensionalWeighted::prediction_error2(const utility::VectorBase& x,
94                                                     const double w) const
95  {
96    return standard_error2(x) + s2(w);
97  }
98
99
100  double MultiDimensionalWeighted::s2(const double w) const
101  {
102    return s2_/w;
103  }
104
105
106  double 
107  MultiDimensionalWeighted::standard_error2(const utility::VectorBase& x) const
108  {
109    double c = 0;
110    for (size_t i=0; i<x.size(); ++i){
111      c += covariance_(i,i)*x(i)*x(i);
112      for (size_t j=i+1; j<x.size(); ++j)
113        c += 2*covariance_(i,j)*x(i)*x(j);
114    }
115    return c;
116  }
117
118}}} // of namespaces regression, yat, and theplu
Note: See TracBrowser for help on using the repository browser.