]> git.proxmox.com Git - rustc.git/blame - src/librand/distributions/mod.rs
New upstream version 1.22.1+dfsg1
[rustc.git] / src / librand / distributions / mod.rs
CommitLineData
1a4d82fc
JJ
1// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
2// file at the top-level directory of this distribution and at
3// http://rust-lang.org/COPYRIGHT.
4//
5// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8// option. This file may not be copied, modified, or distributed
9// except according to those terms.
10
11//! Sampling from random distributions.
12//!
13//! This is a generalization of `Rand` to allow parameters to control the
14//! exact properties of the generated values, e.g. the mean and standard
15//! deviation of a normal distribution. The `Sample` trait is the most
16//! general, and allows for generating values that change some state
17//! internally. The `IndependentSample` trait is for generating values
18//! that do not need to record state.
19
32a655c1
SL
20use core::fmt;
21
a7813a04 22#[cfg(not(test))] // only necessary for no_std
9346a6ac 23use core::num::Float;
a7813a04 24
85aaf69f 25use core::marker::PhantomData;
1a4d82fc 26
3157f602 27use {Rand, Rng};
1a4d82fc
JJ
28
29pub use self::range::Range;
3157f602
XL
30pub use self::gamma::{ChiSquared, FisherF, Gamma, StudentT};
31pub use self::normal::{LogNormal, Normal};
1a4d82fc
JJ
32pub use self::exponential::Exp;
33
34pub mod range;
35pub mod gamma;
36pub mod normal;
37pub mod exponential;
38
39/// Types that can be used to create a random instance of `Support`.
40pub trait Sample<Support> {
41 /// Generate a random value of `Support`, using `rng` as the
42 /// source of randomness.
43 fn sample<R: Rng>(&mut self, rng: &mut R) -> Support;
44}
45
46/// `Sample`s that do not require keeping track of state.
47///
48/// Since no state is recorded, each sample is (statistically)
49/// independent of all others, assuming the `Rng` used has this
50/// property.
51// FIXME maybe having this separate is overkill (the only reason is to
52// take &self rather than &mut self)? or maybe this should be the
53// trait called `Sample` and the other should be `DependentSample`.
54pub trait IndependentSample<Support>: Sample<Support> {
55 /// Generate a random value.
7cac9316 56 fn ind_sample<R: Rng>(&self, _: &mut R) -> Support;
1a4d82fc
JJ
57}
58
/// A wrapper for generating types that implement `Rand` via the
/// `Sample` & `IndependentSample` traits.
///
/// It stores no data: the type parameter only records which type to
/// produce, so a `RandSample<Sup>` is zero-sized.
pub struct RandSample<Sup> {
    _marker: PhantomData<Sup>,
}

impl<Sup> RandSample<Sup> {
    /// Create a sampler for values of type `Sup`.
    pub fn new() -> RandSample<Sup> {
        RandSample { _marker: PhantomData }
    }
}

// Written by hand rather than derived: `#[derive(Default)]` would add a
// `Sup: Default` bound, which is not needed to build a `PhantomData`.
impl<Sup> Default for RandSample<Sup> {
    fn default() -> RandSample<Sup> {
        RandSample::new()
    }
}
1a4d82fc
JJ
70
71impl<Sup: Rand> Sample<Sup> for RandSample<Sup> {
b039eaaf
SL
72 fn sample<R: Rng>(&mut self, rng: &mut R) -> Sup {
73 self.ind_sample(rng)
74 }
1a4d82fc
JJ
75}
76
77impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
78 fn ind_sample<R: Rng>(&self, rng: &mut R) -> Sup {
79 rng.gen()
80 }
81}
82
32a655c1
SL
83impl<Sup> fmt::Debug for RandSample<Sup> {
84 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85 f.pad("RandSample { .. }")
86 }
87}
88
1a4d82fc
JJ
/// An item paired with the weight `WeightedChoice` uses when picking it.
///
/// `Debug` prints both fields, e.g. `Weighted { weight: 2, item: 'x' }`.
#[derive(Debug)]
pub struct Weighted<T> {
    /// How heavily this item counts; larger weights are picked more often
    pub weight: usize,
    /// The value that gets returned when this entry is chosen
    pub item: T,
}
105
1a4d82fc
JJ
106/// A distribution that selects from a finite collection of weighted items.
107///
108/// Each item has an associated weight that influences how likely it
109/// is to be chosen: higher weight is more likely.
110///
111/// The `Clone` restriction is a limitation of the `Sample` and
112/// `IndependentSample` traits. Note that `&T` is (cheaply) `Clone` for
c34b1796 113/// all `T`, as is `usize`, so one can store references or indices into
1a4d82fc 114/// another vector.
b039eaaf 115pub struct WeightedChoice<'a, T: 'a> {
1a4d82fc 116 items: &'a mut [Weighted<T>],
b039eaaf 117 weight_range: Range<usize>,
1a4d82fc
JJ
118}
119
120impl<'a, T: Clone> WeightedChoice<'a, T> {
121 /// Create a new `WeightedChoice`.
122 ///
123 /// Panics if:
124 /// - `v` is empty
125 /// - the total weight is 0
c34b1796 126 /// - the total weight is larger than a `usize` can contain.
1a4d82fc
JJ
127 pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
128 // strictly speaking, this is subsumed by the total weight == 0 case
b039eaaf
SL
129 assert!(!items.is_empty(),
130 "WeightedChoice::new called with no items");
1a4d82fc 131
c34b1796 132 let mut running_total = 0_usize;
1a4d82fc
JJ
133
134 // we convert the list from individual weights to cumulative
135 // weights so we can binary search. This *could* drop elements
136 // with weight == 0 as an optimisation.
85aaf69f 137 for item in &mut *items {
1a4d82fc
JJ
138 running_total = match running_total.checked_add(item.weight) {
139 Some(n) => n,
92a42be0
SL
140 None => {
141 panic!("WeightedChoice::new called with a total weight larger than a usize \
142 can contain")
143 }
1a4d82fc
JJ
144 };
145
146 item.weight = running_total;
147 }
b039eaaf
SL
148 assert!(running_total != 0,
149 "WeightedChoice::new called with a total weight of 0");
1a4d82fc
JJ
150
151 WeightedChoice {
3b2f2976 152 items,
1a4d82fc
JJ
153 // we're likely to be generating numbers in this range
154 // relatively often, so might as well cache it
b039eaaf 155 weight_range: Range::new(0, running_total),
1a4d82fc
JJ
156 }
157 }
158}
159
160impl<'a, T: Clone> Sample<T> for WeightedChoice<'a, T> {
b039eaaf
SL
161 fn sample<R: Rng>(&mut self, rng: &mut R) -> T {
162 self.ind_sample(rng)
163 }
1a4d82fc
JJ
164}
165
impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
    /// Draw one item, with the chance of each item controlled by its weight.
    ///
    /// Relies on `self.items[..].weight` holding the cumulative totals that
    /// `WeightedChoice::new` writes: non-decreasing, with the last one equal
    /// to the upper (exclusive) bound of `self.weight_range`.
    fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
        // we want to find the first element that has cumulative
        // weight > sample_weight, which we do by binary search since the
        // cumulative weights of self.items are sorted.

        // choose a weight in [0, total_weight)
        let sample_weight = self.weight_range.ind_sample(rng);

        // short circuit when it's the first item
        if sample_weight < self.items[0].weight {
            return self.items[0].item.clone();
        }

        let mut idx = 0;
        let mut modifier = self.items.len();

        // now we know that every possibility has an element to the
        // left, so we can just search for the last element that has
        // cumulative weight <= sample_weight, then the next one will
        // be "it". (Note that this greatest element will never be the
        // last element of the vector, since sample_weight is chosen
        // in [0, total_weight) and the cumulative weight of the last
        // one is exactly the total weight.)
        //
        // Loop invariant: if `ans` is the index being searched for, then
        // `idx < ans <= idx + modifier`. Each pass keeps this true and
        // roughly halves `modifier`. The loop exits with `modifier == 1`,
        // which forces `ans == idx + 1`.
        while modifier > 1 {
            let i = idx + modifier / 2;
            if self.items[i].weight <= sample_weight {
                // we're small, so look to the right, but allow this
                // exact element still.
                idx = i;
                // we need the `/ 2` to round up otherwise we'll drop
                // the trailing elements when `modifier` is odd.
                modifier += 1;
            } else {
                // otherwise we're too big, so go left. (i.e. do
                // nothing)
            }
            modifier /= 2;
        }
        return self.items[idx + 1].item.clone();
    }
}
208
32a655c1
SL
209impl<'a, T: fmt::Debug> fmt::Debug for WeightedChoice<'a, T> {
210 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
211 f.debug_struct("WeightedChoice")
212 .field("items", &self.items)
213 .field("weight_range", &self.weight_range)
214 .finish()
215 }
216}
217
1a4d82fc
JJ
mod ziggurat_tables;

/// Sample a random number using the Ziggurat method (specifically the
/// ZIGNOR variant from Doornik 2005). Most of the arguments are
/// directly from the paper:
///
/// * `rng`: source of randomness
/// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0.
/// * `x_tab` (`X` in the paper): the $x_i$ abscissae.
/// * `f_tab` (`F` in the paper): precomputed values of the PDF at the $x_i$ (i.e. $f(x_i)$).
///   The differences $f(x_i) - f(x_{i+1})$ (`F_DIFF` in the paper) are not
///   passed in; they are computed inline from `f_tab`.
/// * `pdf`: the probability density function. It is only evaluated on the
///   slow (rejection) path.
/// * `zero_case`: manual sampling from the tail when we chose the
///   bottom box (i.e. i == 0). It is given `rng` and the uniform `u` drawn
///   for the current attempt, and its result is returned unchanged.
///
/// Every attempt draws one `u64` via `rng.gen()`. Attempts that reach the
/// rejection test also draw an `f64`, and rejected attempts loop and start
/// over. The box index comes from the low 8 bits of the `u64` (0..=255), and
/// `x_tab[i + 1]` / `f_tab[i + 1]` are read. Both tables therefore presumably
/// hold at least 257 entries (see `ziggurat_tables::ZigTable` -- confirm).
// the perf improvement (25-50%) is definitely worth the extra code
// size from force-inlining.
#[inline(always)]
fn ziggurat<R: Rng, P, Z>(rng: &mut R,
                          symmetric: bool,
                          x_tab: ziggurat_tables::ZigTable,
                          f_tab: ziggurat_tables::ZigTable,
                          mut pdf: P,
                          mut zero_case: Z)
                          -> f64
    where P: FnMut(f64) -> f64,
          Z: FnMut(&mut R, f64) -> f64
{
    // 2^53: dividing a 53-bit integer by this yields a value in [0, 1).
    const SCALE: f64 = (1u64 << 53) as f64;
    loop {
        // reimplement the f64 generation as an optimisation suggested
        // by the Doornik paper: we have a lot of precision-space
        // (i.e. there are 11 bits of the 64 of a u64 to use after
        // creating a f64), so we might as well reuse some to save
        // generating a whole extra random number. (Seems to be 15%
        // faster.)
        //
        // This unfortunately misses out on the benefits of direct
        // floating point generation if an RNG like dSFMT is
        // used. (That is, such RNGs create floats directly, highly
        // efficiently and overload next_f32/f64, so by not calling it
        // this may be slower than it would be otherwise.)
        // FIXME: investigate/optimise for the above.
        let bits: u64 = rng.gen();
        // The low 8 bits pick the box index.
        let i = (bits & 0xff) as usize;
        // The top 53 bits form f in [0, 1); bits 8..=10 go unused.
        let f = (bits >> 11) as f64 / SCALE;

        // u is either U(-1, 1) or U(0, 1) depending on if this is a
        // symmetric distribution or not.
        let u = if symmetric { 2.0 * f - 1.0 } else { f };
        let x = u * x_tab[i];

        let test_x = if symmetric { x.abs() } else { x };

        // Fast path: accept without evaluating the pdf.
        // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
        if test_x < x_tab[i + 1] {
            return x;
        }
        // Bottom box: hand off to the caller's tail sampler.
        if i == 0 {
            return zero_case(rng, u);
        }
        // Slow path: pick a height between f_tab[i + 1] and f_tab[i] using a
        // fresh f64 (presumably uniform in [0, 1)) and accept x if that height
        // lies below pdf(x).
        // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
        if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) {
            return x;
        }
        // Rejected: loop round and try again with fresh random bits.
    }
}
284
#[cfg(test)]
mod tests {
    use {Rand, Rng};
    use super::{IndependentSample, RandSample, Sample, Weighted, WeightedChoice};

    /// A `Rand` type whose "random" value is always `ConstRand(0)`.
    #[derive(PartialEq, Debug)]
    struct ConstRand(usize);

    impl Rand for ConstRand {
        fn rand<R: Rng>(_: &mut R) -> ConstRand {
            ConstRand(0)
        }
    }

    /// A deterministic "RNG" whose `next_u32` yields 0, 1, 2, 3, ...
    struct CountingRng {
        i: u32,
    }

    impl Rng for CountingRng {
        fn next_u32(&mut self) -> u32 {
            let current = self.i;
            self.i += 1;
            current
        }
        fn next_u64(&mut self) -> u64 {
            u64::from(self.next_u32())
        }
    }

    #[test]
    fn test_rand_sample() {
        let mut sampler = RandSample::<ConstRand>::new();

        assert_eq!(sampler.sample(&mut ::test::rng()), ConstRand(0));
        assert_eq!(sampler.ind_sample(&mut ::test::rng()), ConstRand(0));
    }

    /// Builds a `WeightedChoice` over `items`, then checks that successive
    /// draws driven by a fresh `CountingRng` produce exactly `expected`.
    fn assert_draws(items: &mut [Weighted<i32>], expected: &[i32]) {
        let chooser = WeightedChoice::new(items);
        let mut rng = CountingRng { i: 0 };
        for &want in expected {
            assert_eq!(chooser.ind_sample(&mut rng), want);
        }
    }

    #[test]
    #[rustfmt_skip]
    fn test_weighted_choice() {
        // this makes assumptions about the internal implementation of
        // WeightedChoice, specifically: it doesn't reorder the items,
        // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
        // 1, internally; modulo a modulo operation).

        // a single item
        assert_draws(&mut [Weighted { weight: 1, item: 10 }],
                     &[10]);

        // zero-weight items are skipped
        assert_draws(&mut [Weighted { weight: 0, item: 20 },
                           Weighted { weight: 2, item: 21 },
                           Weighted { weight: 0, item: 22 },
                           Weighted { weight: 1, item: 23 }],
                     &[21, 21, 23]);

        // different weights
        assert_draws(&mut [Weighted { weight: 4, item: 30 },
                           Weighted { weight: 3, item: 31 }],
                     &[30, 30, 30, 30, 31, 31, 31]);

        // the binary search must cope with odd-length inputs
        assert_draws(&mut [Weighted { weight: 1, item: 40 },
                           Weighted { weight: 1, item: 41 },
                           Weighted { weight: 1, item: 42 },
                           Weighted { weight: 1, item: 43 },
                           Weighted { weight: 1, item: 44 }],
                     &[40, 41, 42, 43, 44]);
        assert_draws(&mut [Weighted { weight: 1, item: 50 },
                           Weighted { weight: 1, item: 51 },
                           Weighted { weight: 1, item: 52 },
                           Weighted { weight: 1, item: 53 },
                           Weighted { weight: 1, item: 54 },
                           Weighted { weight: 1, item: 55 },
                           Weighted { weight: 1, item: 56 }],
                     &[50, 51, 52, 53, 54, 55, 56]);
    }

    #[test]
    #[should_panic]
    fn test_weighted_choice_no_items() {
        WeightedChoice::<isize>::new(&mut []);
    }

    #[test]
    #[should_panic]
    #[rustfmt_skip]
    fn test_weighted_choice_zero_weight() {
        WeightedChoice::new(&mut [Weighted { weight: 0, item: 0 },
                                  Weighted { weight: 0, item: 1 }]);
    }

    #[test]
    #[should_panic]
    #[rustfmt_skip]
    fn test_weighted_choice_weight_overflows() {
        // half + 1 + half + 1 == usize::MAX + 1, which overflows
        let half = !0usize / 2;
        WeightedChoice::new(&mut [Weighted { weight: half, item: 0 },
                                  Weighted { weight: 1, item: 1 },
                                  Weighted { weight: half, item: 2 },
                                  Weighted { weight: 1, item: 3 }]);
    }
}