]>
Commit | Line | Data |
---|---|---|
0731742a XL |
1 | // Copyright 2018 Developers of the Rand project. |
2 | // Copyright 2016-2017 The Rust Project Developers. | |
b7449926 XL |
3 | // |
4 | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | |
5 | // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | |
6 | // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | |
7 | // option. This file may not be copied, modified, or distributed | |
8 | // except according to those terms. | |
9 | ||
10 | //! The binomial distribution. | |
416331ca | 11 | #![allow(deprecated)] |
dfeec247 | 12 | #![allow(clippy::all)] |
b7449926 | 13 | |
416331ca | 14 | use crate::distributions::{Distribution, Uniform}; |
dfeec247 | 15 | use crate::Rng; |
b7449926 XL |
16 | |
/// The binomial distribution `Binomial(n, p)`: the number of successes among
/// `n` independent trials, each of which succeeds with probability `p`.
///
/// Its probability mass function is
/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `0 <= k <= n`.
#[deprecated(since = "0.7.0", note = "moved to rand_distr crate")]
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
    /// How many trials are performed (`n`).
    n: u64,
    /// Per-trial success probability (`p`); `Binomial::new` keeps it in `[0, 1]`.
    p: f64,
}
29 | ||
30 | impl Binomial { | |
31 | /// Construct a new `Binomial` with the given shape parameters `n` (number | |
32 | /// of trials) and `p` (probability of success). | |
33 | /// | |
34 | /// Panics if `p < 0` or `p > 1`. | |
35 | pub fn new(n: u64, p: f64) -> Binomial { | |
36 | assert!(p >= 0.0, "Binomial::new called with p < 0"); | |
37 | assert!(p <= 1.0, "Binomial::new called with p > 1"); | |
38 | Binomial { n, p } | |
39 | } | |
40 | } | |
41 | ||
416331ca XL |
/// Convert an `f64` to an `i64`, truncating toward zero.
///
/// # Panics
///
/// Panics if `x` is NaN or `x >= 2^63`. `i64::MAX as f64` rounds up to `2^63`,
/// so the strict `<` admits exactly the non-negative values that fit.
///
/// Values below `i64::MIN`, including negative infinity, are *not* rejected.
/// The `as` cast saturates them to `i64::MIN` (guaranteed since Rust 1.45;
/// earlier compilers left this undefined). A left-tail draw that evaluates to
/// `-inf` therefore becomes a negative `y`, which the caller rejects.
// std provides no `TryFrom<f64>` impl for integer types, so the range check is
// written out by hand.
fn f64_to_i64(x: f64) -> i64 {
    assert!(x < (::std::i64::MAX as f64));
    x as i64
}
48 | ||
b7449926 XL |
49 | impl Distribution<u64> for Binomial { |
50 | fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { | |
51 | // Handle these values directly. | |
52 | if self.p == 0.0 { | |
53 | return 0; | |
54 | } else if self.p == 1.0 { | |
55 | return self.n; | |
56 | } | |
416331ca XL |
57 | |
58 | // The binomial distribution is symmetrical with respect to p -> 1-p, | |
59 | // k -> n-k switch p so that it is less than 0.5 - this allows for lower | |
60 | // expected values we will just invert the result at the end | |
dfeec247 | 61 | let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p }; |
b7449926 | 62 | |
416331ca XL |
63 | let result; |
64 | let q = 1. - p; | |
65 | ||
66 | // For small n * min(p, 1 - p), the BINV algorithm based on the inverse | |
67 | // transformation of the binomial distribution is efficient. Otherwise, | |
68 | // the BTPE algorithm is used. | |
69 | // | |
70 | // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial | |
71 | // random variate generation. Commun. ACM 31, 2 (February 1988), | |
72 | // 216-222. http://dx.doi.org/10.1145/42372.42381 | |
73 | ||
74 | // Threshold for prefering the BINV algorithm. The paper suggests 10, | |
75 | // Ranlib uses 30, and GSL uses 14. | |
76 | const BINV_THRESHOLD: f64 = 10.; | |
77 | ||
dfeec247 | 78 | if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (::std::i32::MAX as u64) { |
416331ca XL |
79 | // Use the BINV algorithm. |
80 | let s = p / q; | |
81 | let a = ((self.n + 1) as f64) * s; | |
82 | let mut r = q.powi(self.n as i32); | |
83 | let mut u: f64 = rng.gen(); | |
84 | let mut x = 0; | |
85 | while u > r as f64 { | |
86 | u -= r; | |
87 | x += 1; | |
88 | r *= a / (x as f64) - s; | |
89 | } | |
90 | result = x; | |
91 | } else { | |
92 | // Use the BTPE algorithm. | |
93 | ||
94 | // Threshold for using the squeeze algorithm. This can be freely | |
95 | // chosen based on performance. Ranlib and GSL use 20. | |
96 | const SQUEEZE_THRESHOLD: i64 = 20; | |
97 | ||
98 | // Step 0: Calculate constants as functions of `n` and `p`. | |
99 | let n = self.n as f64; | |
100 | let np = n * p; | |
101 | let npq = np * q; | |
102 | let f_m = np + p; | |
103 | let m = f64_to_i64(f_m); | |
104 | // radius of triangle region, since height=1 also area of region | |
105 | let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; | |
106 | // tip of triangle | |
107 | let x_m = (m as f64) + 0.5; | |
108 | // left edge of triangle | |
109 | let x_l = x_m - p1; | |
110 | // right edge of triangle | |
111 | let x_r = x_m + p1; | |
112 | let c = 0.134 + 20.5 / (15.3 + (m as f64)); | |
113 | // p1 + area of parallelogram region | |
114 | let p2 = p1 * (1. + 2. * c); | |
115 | ||
116 | fn lambda(a: f64) -> f64 { | |
117 | a * (1. + 0.5 * a) | |
118 | } | |
119 | ||
120 | let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p)); | |
121 | let lambda_r = lambda((x_r - f_m) / (x_r * q)); | |
122 | // p1 + area of left tail | |
123 | let p3 = p2 + c / lambda_l; | |
124 | // p1 + area of right tail | |
125 | let p4 = p3 + c / lambda_r; | |
126 | ||
127 | // return value | |
128 | let mut y: i64; | |
129 | ||
130 | let gen_u = Uniform::new(0., p4); | |
131 | let gen_v = Uniform::new(0., 1.); | |
132 | ||
b7449926 | 133 | loop { |
416331ca XL |
134 | // Step 1: Generate `u` for selecting the region. If region 1 is |
135 | // selected, generate a triangularly distributed variate. | |
136 | let u = gen_u.sample(rng); | |
137 | let mut v = gen_v.sample(rng); | |
138 | if !(u > p1) { | |
139 | y = f64_to_i64(x_m - p1 * v + u); | |
b7449926 XL |
140 | break; |
141 | } | |
b7449926 | 142 | |
416331ca XL |
143 | if !(u > p2) { |
144 | // Step 2: Region 2, parallelograms. Check if region 2 is | |
145 | // used. If so, generate `y`. | |
146 | let x = x_l + (u - p1) / c; | |
147 | v = v * c + 1.0 - (x - x_m).abs() / p1; | |
148 | if v > 1. { | |
149 | continue; | |
150 | } else { | |
151 | y = f64_to_i64(x); | |
152 | } | |
153 | } else if !(u > p3) { | |
154 | // Step 3: Region 3, left exponential tail. | |
155 | y = f64_to_i64(x_l + v.ln() / lambda_l); | |
156 | if y < 0 { | |
157 | continue; | |
158 | } else { | |
159 | v *= (u - p2) * lambda_l; | |
160 | } | |
161 | } else { | |
162 | // Step 4: Region 4, right exponential tail. | |
163 | y = f64_to_i64(x_r - v.ln() / lambda_r); | |
164 | if y > 0 && (y as u64) > self.n { | |
165 | continue; | |
166 | } else { | |
167 | v *= (u - p3) * lambda_r; | |
168 | } | |
169 | } | |
170 | ||
171 | // Step 5: Acceptance/rejection comparison. | |
172 | ||
173 | // Step 5.0: Test for appropriate method of evaluating f(y). | |
174 | let k = (y - m).abs(); | |
175 | if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { | |
176 | // Step 5.1: Evaluate f(y) via the recursive relationship. Start the | |
177 | // search from the mode. | |
178 | let s = p / q; | |
179 | let a = s * (n + 1.); | |
180 | let mut f = 1.0; | |
181 | if m < y { | |
182 | let mut i = m; | |
183 | loop { | |
184 | i += 1; | |
185 | f *= a / (i as f64) - s; | |
186 | if i == y { | |
187 | break; | |
188 | } | |
189 | } | |
190 | } else if m > y { | |
191 | let mut i = y; | |
192 | loop { | |
193 | i += 1; | |
194 | f /= a / (i as f64) - s; | |
195 | if i == m { | |
196 | break; | |
197 | } | |
198 | } | |
199 | } | |
200 | if v > f { | |
201 | continue; | |
202 | } else { | |
203 | break; | |
204 | } | |
205 | } | |
b7449926 | 206 | |
416331ca XL |
207 | // Step 5.2: Squeezing. Check the value of ln(v) againts upper and |
208 | // lower bound of ln(f(y)). | |
209 | let k = k as f64; | |
dfeec247 XL |
210 | let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); |
211 | let t = -0.5 * k * k / npq; | |
416331ca XL |
212 | let alpha = v.ln(); |
213 | if alpha < t - rho { | |
214 | break; | |
215 | } | |
216 | if alpha > t + rho { | |
217 | continue; | |
218 | } | |
219 | ||
220 | // Step 5.3: Final acceptance/rejection test. | |
221 | let x1 = (y + 1) as f64; | |
222 | let f1 = (m + 1) as f64; | |
223 | let z = (f64_to_i64(n) + 1 - m) as f64; | |
224 | let w = (f64_to_i64(n) - y + 1) as f64; | |
225 | ||
226 | fn stirling(a: f64) -> f64 { | |
227 | let a2 = a * a; | |
228 | (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. | |
229 | } | |
230 | ||
dfeec247 XL |
231 | if alpha |
232 | > x_m * (f1 / x1).ln() | |
233 | + (n - (m as f64) + 0.5) * (z / w).ln() | |
234 | + ((y - m) as f64) * (w * p / (x1 * q)).ln() | |
235 | // We use the signs from the GSL implementation, which are | |
236 | // different than the ones in the reference. According to | |
237 | // the GSL authors, the new signs were verified to be | |
238 | // correct by one of the original designers of the | |
239 | // algorithm. | |
240 | + stirling(f1) | |
241 | + stirling(z) | |
242 | - stirling(x1) | |
243 | - stirling(w) | |
416331ca XL |
244 | { |
245 | continue; | |
246 | } | |
b7449926 | 247 | |
b7449926 XL |
248 | break; |
249 | } | |
416331ca XL |
250 | assert!(y >= 0); |
251 | result = y as u64; | |
b7449926 XL |
252 | } |
253 | ||
416331ca | 254 | // Invert the result for p < 0.5. |
b7449926 | 255 | if p != self.p { |
416331ca | 256 | self.n - result |
b7449926 | 257 | } else { |
416331ca | 258 | result |
b7449926 XL |
259 | } |
260 | } | |
261 | } | |
262 | ||
#[cfg(test)]
mod test {
    use super::Binomial;
    use crate::distributions::Distribution;
    use crate::Rng;

    /// Draw 1000 samples from `Binomial(n, p)` and check that the empirical
    /// mean is within 2% of `n*p` and the empirical variance is within 10% of
    /// `n*p*(1-p)`.
    fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) {
        let binomial = Binomial::new(n, p);

        let expected_mean = n as f64 * p;
        let expected_variance = n as f64 * p * (1.0 - p);

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = binomial.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!(
            (mean - expected_mean).abs() < expected_mean / 50.0,
            "mean: {}, expected_mean: {}",
            mean,
            expected_mean
        );

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!(
            (variance - expected_variance).abs() < expected_variance / 10.0,
            "variance: {}, expected_variance: {}",
            variance,
            expected_variance
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_binomial() {
        let mut rng = crate::test::rng(351);
        test_binomial_mean_and_variance(150, 0.1, &mut rng);
        test_binomial_mean_and_variance(70, 0.6, &mut rng);
        test_binomial_mean_and_variance(40, 0.5, &mut rng);
        test_binomial_mean_and_variance(20, 0.7, &mut rng);
        test_binomial_mean_and_variance(20, 0.5, &mut rng);
    }

    #[test]
    fn test_binomial_end_points() {
        let mut rng = crate::test::rng(352);
        assert_eq!(rng.sample(Binomial::new(20, 0.0)), 0);
        assert_eq!(rng.sample(Binomial::new(20, 1.0)), 20);
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_p_neg() {
        Binomial::new(20, -10.0);
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_p_above_one() {
        Binomial::new(20, 1.5);
    }
}