source: trunk/yat/random/random.cc @ 4009

Last change on this file since 4009 was 4009, checked in by Peter, 12 months ago

closes #965; Default constructor for DiscreteGeneral

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 9.9 KB
Line 
1// $Id: random.cc 4009 2020-10-19 00:54:47Z peter $
2
3/*
4  Copyright (C) 2005, 2006, 2007, 2008 Jari Häkkinen, Peter Johansson
5  Copyright (C) 2009, 2011, 2012, 2013, 2015, 2016, 2017, 2018, 2020 Peter Johansson
6
7  This file is part of the yat library, http://dev.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 3 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with yat. If not, see <http://www.gnu.org/licenses/>.
21*/
22
#include <config.h>

#include "random.h"
#include "yat/statistics/Histogram.h"
#include "yat/utility/Exception.h"
#include "yat/utility/utility.h"

#include <cassert>
#include <cstddef>
#include <cstring>
#include <fstream>
#include <memory>
#include <mutex>
#include <sstream>
#include <stdexcept>
36
37namespace theplu {
38namespace yat {
39namespace random {
40
41  std::atomic<RNG*> RNG::instance_ { nullptr };
42  std::mutex RNG::init_mutex_;
43  thread_local std::unique_ptr<gsl_rng, void(*)(gsl_rng*)> RNG::rng_
44  { nullptr, gsl_rng_free};
45
46  RNG::RNG(void)
47  {
48      // support rng/seed changes through environment vars
49    if (!gsl_rng_env_setup())
50      throw utility::GSL_error("RNG::RNG unknown generator");
51    seed_ = gsl_rng_default_seed;
52    // let's allocate already here, just to behave as yat 0.8
53    rng_alloc();
54  }
55
56
57  RNG::~RNG(void)
58  {
59  }
60
61
62  RNG* RNG::instance(void)
63  {
64    if (instance_==NULL) {
65      std::unique_lock<std::mutex> lock(init_mutex_);
66      if (instance_==NULL)
67        instance_ = new RNG;
68    }
69    return instance_;
70  }
71
72
73  unsigned long RNG::max(void) const
74  {
75    return gsl_rng_max(rng());
76  }
77
78
79  unsigned long RNG::min(void) const
80  {
81    return gsl_rng_min(rng());
82  }
83
84
85  std::string RNG::name(void) const
86  {
87    return gsl_rng_name(rng());
88  }
89
90
91  const gsl_rng* RNG::rng(void) const
92  {
93    if (rng_.get()==NULL)
94      rng_alloc();
95    return rng_.get();
96  }
97
98
99  void RNG::rng_alloc(void) const
100  {
101    assert(rng_.get()==NULL);
102    gsl_rng* rng = gsl_rng_alloc(gsl_rng_default);
103    if (!rng)
104      throw utility::GSL_error("RNG failed to allocate memory");
105    std::unique_lock<std::mutex> lock(mutex_);
106    gsl_rng_set(rng, seed_);
107    // bump seed to avoid subsequent gsl_rng to be identical
108    ++seed_;
109    // rng_ owns rng and takes care of deallocation
110    rng_.reset(rng);
111  } // lock is released here
112
113
114  void RNG::seed(unsigned long s) const
115  {
116    std::unique_lock<std::mutex> lock(mutex_);
117    gsl_rng_set(rng(),s);
118    seed_ = s+1;
119  } // lock is released here
120
121
122  unsigned long RNG::seed_from_devurandom(void)
123  {
124    std::ifstream is("/dev/urandom", std::ios::binary);
125    unsigned long s=0;
126    utility::binary_read(is, s);
127    is.close();
128    seed(s);
129    return s;
130  }
131
132
133  int RNG::set_state(const RNG_state& state)
134  {
135    if (rng_.get()==NULL)
136      rng_alloc();
137    if (gsl_rng_memcpy(rng_.get(), state.rng()))
138      throw utility::GSL_error("yat::random::RNG::set_state failed");
139    return 0;
140  }
141
142  // --------------------- RNG_state ----------------------------------
143
144  RNG_state::RNG_state(const RNG* rng)
145  {
146    clone(*rng->rng());
147  }
148
149
150  RNG_state::RNG_state(const RNG_state& state)
151  {
152    clone(*state.rng());
153  }
154
155
156  RNG_state::~RNG_state(void)
157  {
158    gsl_rng_free(rng_);
159    rng_=NULL;
160  }
161
162  const gsl_rng* RNG_state::rng(void) const
163  {
164    return rng_;
165  }
166
167
168  void RNG_state::clone(const gsl_rng& rng)
169  {
170    assert(rng_!=&rng);
171    if (!(rng_ = gsl_rng_clone(&rng)))
172      throw utility::GSL_error("RNG_state::clone failed to allocate memory");
173  }
174
175  RNG_state& RNG_state::operator=(const RNG_state& rhs)
176  {
177    if (this != &rhs) {
178      gsl_rng_free(rng_);
179      clone(*rhs.rng());
180    }
181    return *this;
182  }
183
// --------------------- Discrete distributions ---------------------
185
186  Discrete::Discrete(void)
187    : rng_(RNG::instance())
188  {
189  }
190
191
192  Discrete::~Discrete(void)
193  {
194  }
195
196
197  void Discrete::seed(unsigned long s) const
198  {
199    rng_->seed(s);
200  }
201
202
203  unsigned long Discrete::seed_from_devurandom(void)
204  {
205    return rng_->seed_from_devurandom();
206  }
207
208
209  Binomial::Binomial(double p, unsigned int n)
210    : Discrete(), p_(p), n_(n)
211  {
212  }
213
214
215  unsigned long Binomial::operator()(void) const
216  {
217    return gsl_ran_binomial(rng_->rng(), p_, n_);
218  }
219
220
221  DiscreteGeneral::DiscreteGeneral(void)
222    : gen_(nullptr)
223  {}
224
225
226  DiscreteGeneral::DiscreteGeneral(const statistics::Histogram& hist)
227    : gen_(NULL)
228  {
229    p_.reserve(hist.nof_bins());
230    for (size_t i=0; i<hist.nof_bins(); i++)
231      p_.push_back(hist[i]);
232    preproc();
233  }
234
235
236  DiscreteGeneral::DiscreteGeneral(const DiscreteGeneral& other)
237    : Discrete(other), gen_(NULL), p_(other.p_)
238  {
239    preproc();
240  }
241
242
243  DiscreteGeneral::~DiscreteGeneral(void)
244  {
245    free();
246  }
247
248
249  void DiscreteGeneral::free(void)
250  {
251    if (gen_)
252      gsl_ran_discrete_free( gen_ );
253    gen_ = NULL;
254  }
255
256
257  void DiscreteGeneral::preproc(void)
258  {
259    assert(!gen_);
260    assert(p_.size());
261    gen_ = gsl_ran_discrete_preproc( p_.size(), &p_.front() );
262    if (!gen_)
263      throw utility::GSL_error("DiscreteGeneral failed to setup generator.");
264  }
265
266
267  DiscreteGeneral& DiscreteGeneral::operator=(const DiscreteGeneral& rhs)
268  {
269    free();
270    p_ = rhs.p_;
271    preproc();
272    return *this;
273  }
274
275
276  unsigned long DiscreteGeneral::operator()(void) const
277  {
278    return gsl_ran_discrete(rng_->rng(), gen_);
279  }
280
281
282  DiscreteUniform::DiscreteUniform(unsigned long n)
283    : range_(n)
284  {
285    if (range_>rng_->max()) {
286      std::ostringstream ss;
287      ss << "DiscreteUniform::DiscreteUniform: ";
288      ss << n << " is too large for RNG " << rng_->name();
289      ss << "; maximal argument is " << rng_->max();
290      throw utility::GSL_error(ss.str());
291    }
292  }
293
294
295  unsigned long DiscreteUniform::operator()(void) const
296  {
297    return (range_ ?
298            gsl_rng_uniform_int(rng_->rng(),range_) : gsl_rng_get(rng_->rng()));
299  }
300
301
302  unsigned long DiscreteUniform::operator()(unsigned long n) const
303  {
304    // making sure that n is not larger than the range of the
305    // underlying RNG
306    if (n>rng_->max()) {
307      std::ostringstream ss;
308      ss << "DiscreteUniform::operator(unsigned long): ";
309      ss << n << " is too large for RNG " << rng_->name();
310      ss << "; maximal argument is " << rng_->max();
311      throw utility::GSL_error(ss.str());
312    }
313    return gsl_rng_uniform_int(rng_->rng(),n);
314  }
315
316
317  unsigned long int DiscreteUniform::min(void) const
318  {
319    return range_ ? 0 : rng_->min();
320  }
321
322
323  unsigned long int DiscreteUniform::max(void) const
324  {
325    return range_ ? (range_-1) : rng_->max();
326  }
327
328
329  Geometric::Geometric(double p)
330    : p_(p)
331  {}
332
333
334  unsigned long int Geometric::operator()(void) const
335  {
336    return gsl_ran_geometric (rng_->rng(), p_); 
337  }
338
339
340  unsigned long int Geometric::operator()(double p) const
341  {
342    return gsl_ran_geometric (rng_->rng(), p);
343  }
344
345
346  HyperGeometric::HyperGeometric(void)
347  {}
348
349
350  HyperGeometric::HyperGeometric(unsigned int n1, unsigned int n2,
351                                 unsigned int t)
352    : n1_(n1), n2_(n2), t_(t)
353  {}
354
355
356  unsigned long int HyperGeometric::operator()(void) const
357  {
358    return (*this)(n1_, n2_, t_);
359  }
360
361
362  unsigned long int HyperGeometric::operator()(unsigned int n1,
363                                               unsigned int n2,
364                                               unsigned int t) const
365  {
366    return gsl_ran_hypergeometric(rng_->rng(), n1, n2, t);
367  }
368
369
370  NegativeHyperGeometric::NegativeHyperGeometric(void)
371  {}
372
373
374  NegativeHyperGeometric::NegativeHyperGeometric(unsigned int n1,
375                                                 unsigned int n2, unsigned int t)
376    : n1_(n1), n2_(n2), t_(t)
377  {}
378
379
380  unsigned long int NegativeHyperGeometric::operator()(void) const
381  {
382    return (*this)(n1_, n2_, t_);
383  }
384
385
386  unsigned long int NegativeHyperGeometric::operator()(unsigned int n1,
387                                                       unsigned int n2,
388                                                       unsigned int t) const
389  {
390    assert(t <= n2);
391
392    // NHG can be described as an array with n1 true and n2 false, and
393    // NHG(n1, n2, t) is the number of true left of the t:th false. By
394    // symmetry number of true right of the t:th false is NHG(n1, n2,
395    // n2-t+1) since t:th false counting from left is the (n2-t+1):th
396    // false counting from right.
397
398    // When t is larger than midpoint (2*t > n2+1) we use this
399    // symmetry to speed things up.
400    if (t > (n2+1)/2)
401      return n1 - (*this)(n1, n2, n2-t+1);
402
403    ContinuousUniform uniform;
404    unsigned long int k = 0;
405    while (t) {
406      assert(n1 + n2);
407      double x = uniform();
408      if (x * (n1+n2) < n1) {
409        --n1;
410        ++k;
411        if (!n1)
412          return k;
413      }
414      else {
415        --t;
416        --n2;
417      }
418    }
419    return k;
420  }
421
422
423  Poisson::Poisson(const double m)
424    : m_(m)
425  {
426  }
427
428  unsigned long Poisson::operator()(void) const
429  {
430    return gsl_ran_poisson(rng_->rng(), m_);
431  }
432
433
434  unsigned long Poisson::operator()(const double m) const
435  {
436    return gsl_ran_poisson(rng_->rng(), m);
437  }
438
// --------------------- Continuous distributions ---------------------
440
441  Continuous::Continuous(void)
442    : rng_(RNG::instance())
443  {
444  }
445
446
447  Continuous::~Continuous(void)
448  {
449  }
450
451
452  void Continuous::seed(unsigned long s) const
453  {
454    rng_->seed(s);
455  }
456
457
458  unsigned long Continuous::seed_from_devurandom(void)
459  {
460    return rng_->seed_from_devurandom();
461  }
462
463
464  ContinuousGeneral::ContinuousGeneral(const statistics::Histogram& hist)
465    : discrete_(DiscreteGeneral(hist)), hist_(hist)
466  {
467  }
468
469
470  double ContinuousGeneral::operator()(void) const
471  {
472    return hist_.observation_value(discrete_())+(u_()-0.5)*hist_.spacing();
473  }
474
475  double ContinuousUniform::operator()(void) const
476  {
477    return gsl_rng_uniform(rng_->rng());
478  }
479
480
481  Exponential::Exponential(const double m)
482    : m_(m)
483  {
484  }
485
486
487  double Exponential::operator()(void) const
488  {
489    return gsl_ran_exponential(rng_->rng(), m_);
490  }
491
492
493  double Exponential::operator()(const double m) const
494  {
495    return gsl_ran_exponential(rng_->rng(), m);
496  }
497
498
499  Gaussian::Gaussian(const double s, const double m)
500    : m_(m), s_(s)
501  {
502  }
503
504
505  double Gaussian::operator()(void) const
506  {
507    return gsl_ran_gaussian(rng_->rng(), s_)+m_;
508  }
509
510
511  double Gaussian::operator()(const double s) const
512  {
513    return gsl_ran_gaussian(rng_->rng(), s);
514  }
515
516
517  double Gaussian::operator()(const double s, const double m) const
518  {
519    return gsl_ran_gaussian(rng_->rng(), s)+m;
520  }
521
522}}} // of namespace random, yat, and theplu
Note: See TracBrowser for help on using the repository browser.