]> git.proxmox.com Git - rustc.git/blame - vendor/rand-0.7.3/src/distributions/binomial.rs
Merge tag 'debian/1.52.1+dfsg1-1_exp2' into proxmox/buster
[rustc.git] / vendor / rand-0.7.3 / src / distributions / binomial.rs
CommitLineData
0731742a
XL
1// Copyright 2018 Developers of the Rand project.
2// Copyright 2016-2017 The Rust Project Developers.
b7449926
XL
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The binomial distribution.
416331ca 11#![allow(deprecated)]
dfeec247 12#![allow(clippy::all)]
b7449926 13
416331ca 14use crate::distributions::{Distribution, Uniform};
dfeec247 15use crate::Rng;
b7449926
XL
16
/// The binomial distribution `Binomial(n, p)`.
///
/// Models the number of successes among `n` independent trials, each of
/// which succeeds with probability `p`. The probability mass function is
/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k = 0, 1, ..., n`.
#[deprecated(since = "0.7.0", note = "moved to rand_distr crate")]
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
    /// How many independent trials are performed.
    n: u64,
    /// Chance that a single trial succeeds (validated by `Binomial::new`
    /// to lie in `[0, 1]`).
    p: f64,
}
29
30impl Binomial {
31 /// Construct a new `Binomial` with the given shape parameters `n` (number
32 /// of trials) and `p` (probability of success).
33 ///
34 /// Panics if `p < 0` or `p > 1`.
35 pub fn new(n: u64, p: f64) -> Binomial {
36 assert!(p >= 0.0, "Binomial::new called with p < 0");
37 assert!(p <= 1.0, "Binomial::new called with p > 1");
38 Binomial { n, p }
39 }
40}
41
416331ca
XL
/// Convert a `f64` to an `i64`, truncating towards zero.
///
/// # Panics
///
/// Panics if `x` is NaN or `x >= 2^63` (i.e. not below `i64::MAX as f64`).
///
/// Only the upper bound is checked: values below `i64::MIN` (including
/// negative infinity) are not rejected, and the `as` cast saturates them to
/// `i64::MIN`. Callers that need to reject such values must check for a
/// negative result themselves.
// `TryFrom<f64>` is not implemented for `i64`, so the range check is done by hand.
fn f64_to_i64(x: f64) -> i64 {
    assert!(x < (::std::i64::MAX as f64));
    x as i64
}
48
b7449926
XL
49impl Distribution<u64> for Binomial {
50 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
51 // Handle these values directly.
52 if self.p == 0.0 {
53 return 0;
54 } else if self.p == 1.0 {
55 return self.n;
56 }
416331ca
XL
57
58 // The binomial distribution is symmetrical with respect to p -> 1-p,
59 // k -> n-k switch p so that it is less than 0.5 - this allows for lower
60 // expected values we will just invert the result at the end
dfeec247 61 let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p };
b7449926 62
416331ca
XL
63 let result;
64 let q = 1. - p;
65
66 // For small n * min(p, 1 - p), the BINV algorithm based on the inverse
67 // transformation of the binomial distribution is efficient. Otherwise,
68 // the BTPE algorithm is used.
69 //
70 // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
71 // random variate generation. Commun. ACM 31, 2 (February 1988),
72 // 216-222. http://dx.doi.org/10.1145/42372.42381
73
74 // Threshold for prefering the BINV algorithm. The paper suggests 10,
75 // Ranlib uses 30, and GSL uses 14.
76 const BINV_THRESHOLD: f64 = 10.;
77
dfeec247 78 if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (::std::i32::MAX as u64) {
416331ca
XL
79 // Use the BINV algorithm.
80 let s = p / q;
81 let a = ((self.n + 1) as f64) * s;
82 let mut r = q.powi(self.n as i32);
83 let mut u: f64 = rng.gen();
84 let mut x = 0;
85 while u > r as f64 {
86 u -= r;
87 x += 1;
88 r *= a / (x as f64) - s;
89 }
90 result = x;
91 } else {
92 // Use the BTPE algorithm.
93
94 // Threshold for using the squeeze algorithm. This can be freely
95 // chosen based on performance. Ranlib and GSL use 20.
96 const SQUEEZE_THRESHOLD: i64 = 20;
97
98 // Step 0: Calculate constants as functions of `n` and `p`.
99 let n = self.n as f64;
100 let np = n * p;
101 let npq = np * q;
102 let f_m = np + p;
103 let m = f64_to_i64(f_m);
104 // radius of triangle region, since height=1 also area of region
105 let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
106 // tip of triangle
107 let x_m = (m as f64) + 0.5;
108 // left edge of triangle
109 let x_l = x_m - p1;
110 // right edge of triangle
111 let x_r = x_m + p1;
112 let c = 0.134 + 20.5 / (15.3 + (m as f64));
113 // p1 + area of parallelogram region
114 let p2 = p1 * (1. + 2. * c);
115
116 fn lambda(a: f64) -> f64 {
117 a * (1. + 0.5 * a)
118 }
119
120 let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p));
121 let lambda_r = lambda((x_r - f_m) / (x_r * q));
122 // p1 + area of left tail
123 let p3 = p2 + c / lambda_l;
124 // p1 + area of right tail
125 let p4 = p3 + c / lambda_r;
126
127 // return value
128 let mut y: i64;
129
130 let gen_u = Uniform::new(0., p4);
131 let gen_v = Uniform::new(0., 1.);
132
b7449926 133 loop {
416331ca
XL
134 // Step 1: Generate `u` for selecting the region. If region 1 is
135 // selected, generate a triangularly distributed variate.
136 let u = gen_u.sample(rng);
137 let mut v = gen_v.sample(rng);
138 if !(u > p1) {
139 y = f64_to_i64(x_m - p1 * v + u);
b7449926
XL
140 break;
141 }
b7449926 142
416331ca
XL
143 if !(u > p2) {
144 // Step 2: Region 2, parallelograms. Check if region 2 is
145 // used. If so, generate `y`.
146 let x = x_l + (u - p1) / c;
147 v = v * c + 1.0 - (x - x_m).abs() / p1;
148 if v > 1. {
149 continue;
150 } else {
151 y = f64_to_i64(x);
152 }
153 } else if !(u > p3) {
154 // Step 3: Region 3, left exponential tail.
155 y = f64_to_i64(x_l + v.ln() / lambda_l);
156 if y < 0 {
157 continue;
158 } else {
159 v *= (u - p2) * lambda_l;
160 }
161 } else {
162 // Step 4: Region 4, right exponential tail.
163 y = f64_to_i64(x_r - v.ln() / lambda_r);
164 if y > 0 && (y as u64) > self.n {
165 continue;
166 } else {
167 v *= (u - p3) * lambda_r;
168 }
169 }
170
171 // Step 5: Acceptance/rejection comparison.
172
173 // Step 5.0: Test for appropriate method of evaluating f(y).
174 let k = (y - m).abs();
175 if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
176 // Step 5.1: Evaluate f(y) via the recursive relationship. Start the
177 // search from the mode.
178 let s = p / q;
179 let a = s * (n + 1.);
180 let mut f = 1.0;
181 if m < y {
182 let mut i = m;
183 loop {
184 i += 1;
185 f *= a / (i as f64) - s;
186 if i == y {
187 break;
188 }
189 }
190 } else if m > y {
191 let mut i = y;
192 loop {
193 i += 1;
194 f /= a / (i as f64) - s;
195 if i == m {
196 break;
197 }
198 }
199 }
200 if v > f {
201 continue;
202 } else {
203 break;
204 }
205 }
b7449926 206
416331ca
XL
207 // Step 5.2: Squeezing. Check the value of ln(v) againts upper and
208 // lower bound of ln(f(y)).
209 let k = k as f64;
dfeec247
XL
210 let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
211 let t = -0.5 * k * k / npq;
416331ca
XL
212 let alpha = v.ln();
213 if alpha < t - rho {
214 break;
215 }
216 if alpha > t + rho {
217 continue;
218 }
219
220 // Step 5.3: Final acceptance/rejection test.
221 let x1 = (y + 1) as f64;
222 let f1 = (m + 1) as f64;
223 let z = (f64_to_i64(n) + 1 - m) as f64;
224 let w = (f64_to_i64(n) - y + 1) as f64;
225
226 fn stirling(a: f64) -> f64 {
227 let a2 = a * a;
228 (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
229 }
230
dfeec247
XL
231 if alpha
232 > x_m * (f1 / x1).ln()
233 + (n - (m as f64) + 0.5) * (z / w).ln()
234 + ((y - m) as f64) * (w * p / (x1 * q)).ln()
235 // We use the signs from the GSL implementation, which are
236 // different than the ones in the reference. According to
237 // the GSL authors, the new signs were verified to be
238 // correct by one of the original designers of the
239 // algorithm.
240 + stirling(f1)
241 + stirling(z)
242 - stirling(x1)
243 - stirling(w)
416331ca
XL
244 {
245 continue;
246 }
b7449926 247
b7449926
XL
248 break;
249 }
416331ca
XL
250 assert!(y >= 0);
251 result = y as u64;
b7449926
XL
252 }
253
416331ca 254 // Invert the result for p < 0.5.
b7449926 255 if p != self.p {
416331ca 256 self.n - result
b7449926 257 } else {
416331ca 258 result
b7449926
XL
259 }
260 }
261}
262
#[cfg(test)]
mod test {
    use super::Binomial;
    use crate::distributions::Distribution;
    use crate::Rng;

    /// Draw 1000 samples from `Binomial(n, p)`. Check that the empirical mean
    /// is within 2% of `n p` and the empirical variance within 10% of
    /// `n p (1 - p)`.
    fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) {
        let binomial = Binomial::new(n, p);

        let expected_mean = n as f64 * p;
        let expected_variance = n as f64 * p * (1.0 - p);

        let mut results = [0.0; 1000];
        for slot in results.iter_mut() {
            *slot = binomial.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!(
            (mean - expected_mean).abs() < expected_mean / 50.0,
            "mean: {}, expected_mean: {}",
            mean,
            expected_mean
        );

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!(
            (variance - expected_variance).abs() < expected_variance / 10.0,
            "variance: {}, expected_variance: {}",
            variance,
            expected_variance
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_binomial() {
        let mut rng = crate::test::rng(351);
        test_binomial_mean_and_variance(150, 0.1, &mut rng);
        test_binomial_mean_and_variance(70, 0.6, &mut rng);
        test_binomial_mean_and_variance(40, 0.5, &mut rng);
        test_binomial_mean_and_variance(20, 0.7, &mut rng);
        test_binomial_mean_and_variance(20, 0.5, &mut rng);
    }

    #[test]
    fn test_binomial_end_points() {
        let mut rng = crate::test::rng(352);
        assert_eq!(rng.sample(Binomial::new(20, 0.0)), 0);
        assert_eq!(rng.sample(Binomial::new(20, 1.0)), 20);
    }

    #[test]
    #[should_panic(expected = "p < 0")]
    fn test_binomial_invalid_p_neg() {
        Binomial::new(20, -10.0);
    }

    #[test]
    #[should_panic(expected = "p > 1")]
    fn test_binomial_invalid_p_gt_one() {
        Binomial::new(20, 1.5);
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_p_nan() {
        Binomial::new(20, ::std::f64::NAN);
    }
}