]> git.proxmox.com Git - rustc.git/blame - vendor/rand/src/distributions/bernoulli.rs
New upstream version 1.51.0+dfsg1
[rustc.git] / vendor / rand / src / distributions / bernoulli.rs
CommitLineData
// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! The Bernoulli distribution.

416331ca 11use crate::distributions::Distribution;
dfeec247
XL
12use crate::Rng;
13use core::{fmt, u64};

/// The Bernoulli distribution.
///
/// This is a special case of the Binomial distribution where `n = 1`.
///
/// # Example
///
/// ```rust
/// use rand::distributions::{Bernoulli, Distribution};
///
/// let d = Bernoulli::new(0.3).unwrap();
/// let v = d.sample(&mut rand::thread_rng());
/// println!("{} is from a Bernoulli distribution", v);
/// ```
///
/// # Precision
///
/// Each sample consumes one `u64` (64 bits) from the RNG, so the only
/// probabilities that can be represented exactly are multiples of
/// 2<sup>-64</sup>.
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli {
    /// Success threshold: the probability `p` scaled by 2<sup>64</sup>.
    /// A sample is `true` when a uniformly random `u64` falls below it;
    /// the value `u64::MAX` is reserved to mean "always `true`".
    p_int: u64,
}

// To sample from the Bernoulli distribution we draw a random `u64` value `v`
// and return whether `v < (p * 2^64)`.
//
// If `p == 1.0`, the integer to compare against cannot be represented as a
// `u64`. We manually set it to `u64::MAX` instead (2^64 - 1 instead of 2^64).
// Note that a value of `p < 1.0` can never result in `u64::MAX`, because an
// `f64` only has 53 bits of precision: the largest `p` below `1.0` is
// `1 - 2^-53`, which results in `2^64 - 2048`.
//
// There is also a purely theoretical concern: if someone consistently wants to
// generate `true` using the Bernoulli distribution (i.e. by using a probability
// of `1.0`), just using `u64::MAX` is not enough. On average it would return
// `false` once every 2^64 iterations. Some people apparently care about this
// case.
//
// That is why we special-case `u64::MAX` to always return `true`, without using
// the RNG, and pay the performance price for all uses that *are* reasonable.
// Luckily, if `new()` and `sample` are close, the compiler can optimize out the
// extra check.
const ALWAYS_TRUE: u64 = u64::MAX;

// This is just `2.0.powi(64)`, but written this way because `powi` is not
// available in `no_std` mode (and is not usable in a `const` initializer).
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;

/// Error type returned from `Bernoulli::new` and `Bernoulli::from_ratio`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BernoulliError {
    /// The requested probability is not in `[0, 1]`.
    ///
    /// For `Bernoulli::new` this means `p < 0`, `p > 1`, or `p` is NaN; for
    /// `Bernoulli::from_ratio` it means `numerator > denominator` or
    /// `denominator == 0`.
    InvalidProbability,
}

impl fmt::Display for BernoulliError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            BernoulliError::InvalidProbability => "p is outside [0, 1] in Bernoulli distribution",
        })
    }
}

#[cfg(feature = "std")]
impl ::std::error::Error for BernoulliError {}
82
b7449926
XL
83impl Bernoulli {
84 /// Construct a new `Bernoulli` with the given probability of success `p`.
85 ///
b7449926
XL
86 /// # Precision
87 ///
88 /// For `p = 1.0`, the resulting distribution will always generate true.
89 /// For `p = 0.0`, the resulting distribution will always generate false.
90 ///
91 /// This method is accurate for any input `p` in the range `[0, 1]` which is
92 /// a multiple of 2<sup>-64</sup>. (Note that not all multiples of
93 /// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
94 #[inline]
416331ca 95 pub fn new(p: f64) -> Result<Bernoulli, BernoulliError> {
dfeec247
XL
96 if !(p >= 0.0 && p < 1.0) {
97 if p == 1.0 {
98 return Ok(Bernoulli { p_int: ALWAYS_TRUE });
99 }
416331ca 100 return Err(BernoulliError::InvalidProbability);
0731742a 101 }
dfeec247
XL
102 Ok(Bernoulli {
103 p_int: (p * SCALE) as u64,
104 })
0731742a
XL
105 }
106
107 /// Construct a new `Bernoulli` with the probability of success of
108 /// `numerator`-in-`denominator`. I.e. `new_ratio(2, 3)` will return
109 /// a `Bernoulli` with a 2-in-3 chance, or about 67%, of returning `true`.
110 ///
0731742a 111 /// return `true`. If `numerator == 0` it will always return `false`.
dfeec247
XL
112 /// For `numerator > denominator` and `denominator == 0`, this returns an
113 /// error. Otherwise, for `numerator == denominator`, samples are always
114 /// true; for `numerator == 0` samples are always false.
0731742a 115 #[inline]
416331ca 116 pub fn from_ratio(numerator: u32, denominator: u32) -> Result<Bernoulli, BernoulliError> {
dfeec247 117 if numerator > denominator || denominator == 0 {
416331ca
XL
118 return Err(BernoulliError::InvalidProbability);
119 }
0731742a 120 if numerator == denominator {
dfeec247 121 return Ok(Bernoulli { p_int: ALWAYS_TRUE });
0731742a 122 }
dfeec247 123 let p_int = ((f64::from(numerator) / f64::from(denominator)) * SCALE) as u64;
416331ca 124 Ok(Bernoulli { p_int })
b7449926
XL
125 }
126}
127
128impl Distribution<bool> for Bernoulli {
129 #[inline]
130 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
131 // Make sure to always return true for p = 1.0.
dfeec247
XL
132 if self.p_int == ALWAYS_TRUE {
133 return true;
134 }
0731742a
XL
135 let v: u64 = rng.gen();
136 v < self.p_int
b7449926
XL
137 }
138}

#[cfg(test)]
mod test {
    use super::Bernoulli;
    use crate::distributions::Distribution;
    use crate::Rng;

    /// `p = 0.0` must never yield `true`, and `p = 1.0` must never yield
    /// `false`, via both the `Rng::sample` and `Distribution::sample` paths.
    #[test]
    fn test_trivial() {
        let mut rng = crate::test::rng(1);
        let never = Bernoulli::new(0.0).unwrap();
        let always = Bernoulli::new(1.0).unwrap();
        for _ in 0..5 {
            assert_eq!(rng.sample::<bool, _>(&never), false);
            assert_eq!(rng.sample::<bool, _>(&always), true);
            assert_eq!(Distribution::<bool>::sample(&never, &mut rng), false);
            assert_eq!(Distribution::<bool>::sample(&always, &mut rng), true);
        }
    }

    /// The empirical success rate of `new(P)` and `from_ratio(NUM, DENOM)`
    /// should both be close to the requested probability.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_average() {
        const P: f64 = 0.3;
        const NUM: u32 = 3;
        const DENOM: u32 = 10;
        const N: u32 = 100_000;

        let by_prob = Bernoulli::new(P).unwrap();
        let by_ratio = Bernoulli::from_ratio(NUM, DENOM).unwrap();

        let mut rng = crate::test::rng(2);
        let (mut hits_prob, mut hits_ratio) = (0u32, 0u32);
        for _ in 0..N {
            // Both distributions share one RNG; draw from them alternately.
            hits_prob += u32::from(by_prob.sample(&mut rng));
            hits_ratio += u32::from(by_ratio.sample(&mut rng));
        }

        let avg1 = f64::from(hits_prob) / f64::from(N);
        assert!((avg1 - P).abs() < 5e-3);

        let avg2 = f64::from(hits_ratio) / f64::from(N);
        assert!((avg2 - (NUM as f64) / (DENOM as f64)).abs() < 5e-3);
    }

    /// Pins the exact output sequence for a fixed seed so that accidental
    /// changes to the sampling algorithm are detected.
    #[test]
    fn value_stability() {
        let mut rng = crate::test::rng(3);
        let distr = Bernoulli::new(0.4532).unwrap();
        let mut buf = [false; 10];
        buf.iter_mut().for_each(|slot| *slot = rng.sample(&distr));
        assert_eq!(buf, [
            true, false, false, true, false, false, true, true, true, true
        ]);
    }
}