Changeset 3648 for trunk/test/poisson.cc


Ignore:
Timestamp:
Jun 2, 2017, 8:01:27 AM (6 years ago)
Author:
Peter
Message:

add some tests for regression::Poisson. refs #882

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/poisson.cc

    r3615 r3648  
    22
    33/*
    4   Copyright (C) 2017 Jari Häkkinen
     4  Copyright (C) 2017 Peter Johansson
    55
    66  This file is part of the yat library, http://dev.thep.lu.se/yat
     
    2525
    2626#include "yat/regression/Poisson.h"
     27#include "yat/random/random.h"
     28#include "yat/statistics/Averager.h"
    2729#include "yat/utility/Matrix.h"
    2830#include "yat/utility/Vector.h"
    2931
     32#include <vector>
     33
    3034using namespace theplu::yat;
     35
     36void analyse(const utility::Vector& b,
     37             std::vector<statistics::Averager>& stats,
     38             std::vector<statistics::Averager>& stats2);
    3139
    3240int main(int argc, char* argv[])
     
    4856  model.predict(X.row_const_view(0));
    4957
     58  utility::Vector b(4);
     59  b(0) = 0.5;
     60  b(1) = 2.0;
     61  b(2) = 0.75;
     62  b(3) = -1.25;
     63  std::vector<statistics::Averager> stats(b.size());
     64  std::vector<statistics::Averager> stats2(b.size());
     65  for (size_t i=0; i<100; ++i)
     66    analyse(b, stats, stats2);
     67
     68  for (size_t i=0; i<b.size(); ++i) {
     69    suite.out() << i << " " << b(i) << " " << stats[i].mean() << " "
     70                << stats[i].standard_error() << "\n";
     71    if (stats[i].standard_error() == 0.0) {
     72      suite.xadd(false);
     73      suite.err() << "error: standard error is 0.0\n";
     74    }
     75    else if (std::abs(stats[i].mean() - b(i)) > 5*stats[i].standard_error()) {
     76      suite.xadd(false);
     77      suite.err() << "error: average for param " << i << ": "
     78                  << stats[i].mean() << "\n";
     79    }
     80  }
     81
    5082  return suite.return_value();
    5183}
     84
     85
     86void generate_data(const utility::Vector& b,
     87                   utility::Matrix& X, utility::Vector& y)
     88{
     89  size_t n = 5000;
     90  X.resize(n, b.size()-1);
     91  y.resize(n);
     92  random::Gaussian gauss;
     93  random::Poisson poisson;
     94  for (utility::Matrix::iterator it=X.begin(); it!=X.end(); ++it)
     95    *it = gauss();
     96
     97  for (size_t i=0; i<n; ++i) {
     98    double lnmu = b(0);
     99    for (size_t j=1; j<b.size(); ++j)
     100      lnmu += X(i, j-1) * b(j);
     101
     102    double mu = exp(lnmu);
     103    y(i) = poisson(mu);
     104  }
     105}
     106
     107
     108void analyse(const utility::Vector& b,
     109             std::vector<statistics::Averager>& stats,
     110             std::vector<statistics::Averager>& stats2)
     111{
     112  utility::Matrix X;
     113  utility::Vector y;
     114  generate_data(b, X, y);
     115  theplu::yat::regression::Poisson poisson;
     116  poisson.fit(X, y);
     117  for (size_t i=0; i<poisson.fit_parameters().size(); ++i) {
     118    stats[i].add(poisson.fit_parameters()(i));
     119  }
     120}
Note: See TracChangeset for help on using the changeset viewer.