source: trunk/yat/random/random.cc @ 3469

Last change on this file since 3469 was 3469, checked in by Peter, 6 years ago

New class for NegativeHyperGeometric. Closes #857. The implementation
is quite naive and calls the RNG several times within one call, which
might be improved. This code was inspired by
gsl_ran_hypergeometric.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 9.6 KB
Line 
1// $Id: random.cc 3469 2016-02-29 01:55:17Z peter $
2
3/*
4  Copyright (C) 2005, 2006, 2007, 2008 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2009, 2011, 2012, 2013, 2015, 2016 Peter Johansson
6
7  This file is part of the yat library, http://dev.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 3 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with yat. If not, see <http://www.gnu.org/licenses/>.
21*/
22
#include <config.h>

#include "random.h"
#include "yat/statistics/Histogram.h"
#include "yat/utility/Exception.h"

#include <boost/thread/locks.hpp>

#include <cassert>
#include <cstring>
#include <fstream>
#include <sstream>
#include <stdexcept>
35
36namespace theplu {
37namespace yat {
38namespace random {
39
40  RNG* RNG::instance_=NULL;
41
42  RNG::RNG(void)
43    : rng_(gsl_rng_free)
44  {
45      // support rng/seed changes through environment vars
46    if (!gsl_rng_env_setup())
47      throw utility::GSL_error("RNG::RNG unknown generator");
48    seed_ = gsl_rng_default_seed;
49    // let's allocate already here, just to behave as yat 0.8
50    rng_alloc();
51  }
52
53
54  RNG::~RNG(void)
55  {
56  }
57
58
59  RNG* RNG::instance(void)
60  {
61    if (instance_==NULL)
62      instance_ = new RNG;
63    return instance_;
64  }
65
66
67  unsigned long RNG::max(void) const
68  {
69    return gsl_rng_max(rng());
70  }
71
72
73  unsigned long RNG::min(void) const
74  {
75    return gsl_rng_min(rng());
76  }
77
78
79  std::string RNG::name(void) const
80  {
81    return gsl_rng_name(rng());
82  }
83
84
85  const gsl_rng* RNG::rng(void) const
86  {
87    if (rng_.get()==NULL)
88      rng_alloc();
89    return rng_.get();
90  }
91
92
93  void RNG::rng_alloc(void) const
94  {
95    assert(rng_.get()==NULL);
96    gsl_rng* rng = gsl_rng_alloc(gsl_rng_default);
97    if (!rng)
98      throw utility::GSL_error("RNG failed to allocate memory");
99    boost::unique_lock<boost::mutex> lock(mutex_);
100    gsl_rng_set(rng, seed_);
101    // bump seed to avoid subsequent gsl_rng to be identical
102    ++seed_;
103    // rng_ owns rng and takes care of deallocation
104    rng_.reset(rng);
105  } // lock is released here
106
107
108  void RNG::seed(unsigned long s) const
109  {
110    boost::unique_lock<boost::mutex> lock(mutex_);
111    gsl_rng_set(rng(),s);
112    seed_ = s+1;
113  } // lock is released here
114
115
116  unsigned long RNG::seed_from_devurandom(void)
117  {
118    unsigned char ulongsize=sizeof(unsigned long);
119    char* buffer=new char[ulongsize];
120    std::ifstream is("/dev/urandom", std::ios::binary);
121    is.read(buffer,ulongsize);
122    is.close();
123    unsigned long s=0;
124    memcpy(&s, buffer, ulongsize);
125    delete[] buffer;
126    seed(s);
127    return s;
128  }
129
130
131  int RNG::set_state(const RNG_state& state)
132  {
133    if (rng_.get()==NULL)
134      rng_alloc();
135    if (gsl_rng_memcpy(rng_.get(), state.rng()))
136      throw utility::GSL_error("yat::random::RNG::set_state failed");
137    return 0;
138  }
139
140  // --------------------- RNG_state ----------------------------------
141
142  RNG_state::RNG_state(const RNG* rng)
143  {
144    clone(*rng->rng());
145  }
146
147
148  RNG_state::RNG_state(const RNG_state& state)
149  {
150    clone(*state.rng());
151  }
152
153
154  RNG_state::~RNG_state(void)
155  {
156    gsl_rng_free(rng_);
157    rng_=NULL;
158  }
159
160  const gsl_rng* RNG_state::rng(void) const
161  {
162    return rng_;
163  }
164
165
166  void RNG_state::clone(const gsl_rng& rng)
167  {
168    assert(rng_!=&rng);
169    if (!(rng_ = gsl_rng_clone(&rng)))
170      throw utility::GSL_error("RNG_state::clone failed to allocate memory");
171  }
172
173  RNG_state& RNG_state::operator=(const RNG_state& rhs)
174  {
175    if (this != &rhs) {
176      gsl_rng_free(rng_);
177      clone(*rhs.rng());
178    }
179    return *this;
180  }
181
// --------------------- Discrete distributions ---------------------
183
184  Discrete::Discrete(void)
185    : rng_(RNG::instance())
186  {
187  }
188
189
190  Discrete::~Discrete(void)
191  {
192  }
193
194
195  void Discrete::seed(unsigned long s) const
196  {
197    rng_->seed(s);
198  }
199
200
201  unsigned long Discrete::seed_from_devurandom(void)
202  {
203    return rng_->seed_from_devurandom();
204  }
205
206
207  Binomial::Binomial(double p, unsigned int n)
208    : Discrete(), p_(p), n_(n)
209  {
210  }
211
212
213  unsigned long Binomial::operator()(void) const
214  {
215    return gsl_ran_binomial(rng_->rng(), p_, n_);
216  }
217
218
219  DiscreteGeneral::DiscreteGeneral(const statistics::Histogram& hist)
220    : gen_(NULL)
221  {
222    p_.reserve(hist.nof_bins());
223    for (size_t i=0; i<hist.nof_bins(); i++)
224      p_.push_back(hist[i]);
225    preproc();
226  }
227
228
229  DiscreteGeneral::DiscreteGeneral(const DiscreteGeneral& other)
230    : Discrete(other), gen_(NULL), p_(other.p_)
231  {
232    preproc();
233  }
234
235
236  DiscreteGeneral::~DiscreteGeneral(void)
237  {
238    free();
239  }
240
241
242  void DiscreteGeneral::free(void)
243  {
244    if (gen_)
245      gsl_ran_discrete_free( gen_ );
246    gen_ = NULL;
247  }
248
249
250  void DiscreteGeneral::preproc(void)
251  {
252    assert(!gen_);
253    assert(p_.size());
254    gen_ = gsl_ran_discrete_preproc( p_.size(), &p_.front() );
255    if (!gen_)
256      throw utility::GSL_error("DiscreteGeneral failed to setup generator.");
257  }
258
259
260  DiscreteGeneral& DiscreteGeneral::operator=(const DiscreteGeneral& rhs)
261  {
262    free();
263    p_ = rhs.p_;
264    preproc();
265    return *this;
266  }
267
268
269  unsigned long DiscreteGeneral::operator()(void) const
270  {
271    return gsl_ran_discrete(rng_->rng(), gen_);
272  }
273
274
275  DiscreteUniform::DiscreteUniform(unsigned long n)
276    : range_(n)
277  {
278    if (range_>rng_->max()) {
279      std::ostringstream ss;
280      ss << "DiscreteUniform::DiscreteUniform: ";
281      ss << n << " is too large for RNG " << rng_->name();
282      ss << "; maximal argument is " << rng_->max();
283      throw utility::GSL_error(ss.str());
284    }
285  }
286
287
288  unsigned long DiscreteUniform::operator()(void) const
289  {
290    return (range_ ?
291            gsl_rng_uniform_int(rng_->rng(),range_) : gsl_rng_get(rng_->rng()));
292  }
293
294
295  unsigned long DiscreteUniform::operator()(unsigned long n) const
296  {
297    // making sure that n is not larger than the range of the
298    // underlying RNG
299    if (n>rng_->max()) {
300      std::ostringstream ss;
301      ss << "DiscreteUniform::operator(unsigned long): ";
302      ss << n << " is too large for RNG " << rng_->name();
303      ss << "; maximal argument is " << rng_->max();
304      throw utility::GSL_error(ss.str());
305    }
306    return gsl_rng_uniform_int(rng_->rng(),n);
307  }
308
309
310  Geometric::Geometric(double p)
311    : p_(p)
312  {}
313
314
315  unsigned long int Geometric::operator()(void) const
316  {
317    return gsl_ran_geometric (rng_->rng(), p_); 
318  }
319
320
321  unsigned long int Geometric::operator()(double p) const
322  {
323    return gsl_ran_geometric (rng_->rng(), p);
324  }
325
326
327  HyperGeometric::HyperGeometric(void)
328  {}
329
330
331  HyperGeometric::HyperGeometric(unsigned int n1, unsigned int n2,
332                                 unsigned int t)
333    : n1_(n1), n2_(n2), t_(t)
334  {}
335
336
337  unsigned long int HyperGeometric::operator()(void) const
338  {
339    return (*this)(n1_, n2_, t_);
340  }
341
342
343  unsigned long int HyperGeometric::operator()(unsigned int n1,
344                                               unsigned int n2,
345                                               unsigned int t) const
346  {
347    return gsl_ran_hypergeometric(rng_->rng(), n1, n2, t);
348  }
349
350
351  NegativeHyperGeometric::NegativeHyperGeometric(void)
352  {}
353
354
355  NegativeHyperGeometric::NegativeHyperGeometric(unsigned int n1,
356                                                 unsigned int n2, unsigned int t)
357    : n1_(n1), n2_(n2), t_(t)
358  {}
359
360
361  unsigned long int NegativeHyperGeometric::operator()(void) const
362  {
363    return (*this)(n1_, n2_, t_);
364  }
365
366
367  unsigned long int NegativeHyperGeometric::operator()(unsigned int n1,
368                                                       unsigned int n2,
369                                                       unsigned int t) const
370  {
371    assert(t <= n2);
372
373    // NHG can be described as an array with n1 true and n2 false, and
374    // NHG(n1, n2, t) is the number of true left of the t:th false. By
375    // symmetry number of true right of the t:th false is NHG(n1, n2,
376    // n2-t+1) since t:th false counting from left is the (n2-t+1):th
377    // false counting from right.
378
379    // When t is larger than midpoint (2*t > n2+1) we use this
380    // symmetry to speed things up.
381    if (t > (n2+1)/2)
382      return n1 - (*this)(n1, n2, n2-t+1);
383
384    ContinuousUniform uniform;
385    unsigned long int k = 0;
386    while (t) {
387      assert(n1 + n2);
388      double x = uniform();
389      if (x * (n1+n2) < n1) {
390        --n1;
391        ++k;
392        if (!n1)
393          return k;
394      }
395      else {
396        --t;
397        --n2;
398      }
399    }
400    return k;
401  }
402
403
404  Poisson::Poisson(const double m)
405    : m_(m)
406  {
407  }
408
409  unsigned long Poisson::operator()(void) const
410  {
411    return gsl_ran_poisson(rng_->rng(), m_);
412  }
413
414
415  unsigned long Poisson::operator()(const double m) const
416  {
417    return gsl_ran_poisson(rng_->rng(), m);
418  }
419
// --------------------- Continuous distributions ---------------------
421
422  Continuous::Continuous(void)
423    : rng_(RNG::instance())
424  {
425  }
426
427
428  Continuous::~Continuous(void)
429  {
430  }
431
432
433  void Continuous::seed(unsigned long s) const
434  {
435    rng_->seed(s);
436  }
437
438
439  unsigned long Continuous::seed_from_devurandom(void)
440  {
441    return rng_->seed_from_devurandom();
442  }
443
444
445  ContinuousGeneral::ContinuousGeneral(const statistics::Histogram& hist)
446    : discrete_(DiscreteGeneral(hist)), hist_(hist)
447  {
448  }
449
450
451  double ContinuousGeneral::operator()(void) const
452  {
453    return hist_.observation_value(discrete_())+(u_()-0.5)*hist_.spacing();
454  }
455
456  double ContinuousUniform::operator()(void) const
457  {
458    return gsl_rng_uniform(rng_->rng());
459  }
460
461
462  Exponential::Exponential(const double m)
463    : m_(m)
464  {
465  }
466
467
468  double Exponential::operator()(void) const
469  {
470    return gsl_ran_exponential(rng_->rng(), m_);
471  }
472
473
474  double Exponential::operator()(const double m) const
475  {
476    return gsl_ran_exponential(rng_->rng(), m);
477  }
478
479
480  Gaussian::Gaussian(const double s, const double m)
481    : m_(m), s_(s)
482  {
483  }
484
485
486  double Gaussian::operator()(void) const
487  {
488    return gsl_ran_gaussian(rng_->rng(), s_)+m_;
489  }
490
491
492  double Gaussian::operator()(const double s) const
493  {
494    return gsl_ran_gaussian(rng_->rng(), s);
495  }
496
497
498  double Gaussian::operator()(const double s, const double m) const
499  {
500    return gsl_ran_gaussian(rng_->rng(), s)+m;
501  }
502
503}}} // of namespace random, yat, and theplu
Note: See TracBrowser for help on using the repository browser.