]>
Commit | Line | Data |
---|---|---|
1a4d82fc JJ |
1 | // Copyright 2013 The Rust Project Developers. See the COPYRIGHT |
2 | // file at the top-level directory of this distribution and at | |
3 | // http://rust-lang.org/COPYRIGHT. | |
4 | // | |
5 | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | |
6 | // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license | |
7 | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your | |
8 | // option. This file may not be copied, modified, or distributed | |
9 | // except according to those terms. | |
10 | ||
11 | //! Sampling from random distributions. | |
12 | //! | |
13 | //! This is a generalization of `Rand` to allow parameters to control the | |
14 | //! exact properties of the generated values, e.g. the mean and standard | |
15 | //! deviation of a normal distribution. The `Sample` trait is the most | |
16 | //! general, and allows for generating values that change some state | |
17 | //! internally. The `IndependentSample` trait is for generating values | |
18 | //! that do not need to record state. | |
19 | ||
32a655c1 SL |
20 | use core::fmt; |
21 | ||
a7813a04 | 22 | #[cfg(not(test))] // only necessary for no_std |
9346a6ac | 23 | use core::num::Float; |
a7813a04 | 24 | |
85aaf69f | 25 | use core::marker::PhantomData; |
1a4d82fc | 26 | |
3157f602 | 27 | use {Rand, Rng}; |
1a4d82fc JJ |
28 | |
29 | pub use self::range::Range; | |
3157f602 XL |
30 | pub use self::gamma::{ChiSquared, FisherF, Gamma, StudentT}; |
31 | pub use self::normal::{LogNormal, Normal}; | |
1a4d82fc JJ |
32 | pub use self::exponential::Exp; |
33 | ||
34 | pub mod range; | |
35 | pub mod gamma; | |
36 | pub mod normal; | |
37 | pub mod exponential; | |
38 | ||
39 | /// Types that can be used to create a random instance of `Support`. | |
40 | pub trait Sample<Support> { | |
41 | /// Generate a random value of `Support`, using `rng` as the | |
42 | /// source of randomness. | |
43 | fn sample<R: Rng>(&mut self, rng: &mut R) -> Support; | |
44 | } | |
45 | ||
46 | /// `Sample`s that do not require keeping track of state. | |
47 | /// | |
48 | /// Since no state is recorded, each sample is (statistically) | |
49 | /// independent of all others, assuming the `Rng` used has this | |
50 | /// property. | |
51 | // FIXME maybe having this separate is overkill (the only reason is to | |
52 | // take &self rather than &mut self)? or maybe this should be the | |
53 | // trait called `Sample` and the other should be `DependentSample`. | |
54 | pub trait IndependentSample<Support>: Sample<Support> { | |
55 | /// Generate a random value. | |
7cac9316 | 56 | fn ind_sample<R: Rng>(&self, _: &mut R) -> Support; |
1a4d82fc JJ |
57 | } |
58 | ||
/// A wrapper for generating types that implement `Rand` via the
/// `Sample` & `IndependentSample` traits.
///
/// The wrapper stores no data: the `PhantomData<Sup>` marker only records
/// which type is produced.
pub struct RandSample<Sup> {
    _marker: PhantomData<Sup>,
}

impl<Sup> RandSample<Sup> {
    /// Create a sampler that produces values of type `Sup`.
    pub fn new() -> RandSample<Sup> {
        RandSample { _marker: PhantomData }
    }
}

// `new` takes no arguments, so also provide the conventional `Default`
// constructor (clippy: `new_without_default`). Written by hand instead of
// derived so that no `Sup: Default` bound is imposed on callers.
impl<Sup> Default for RandSample<Sup> {
    fn default() -> RandSample<Sup> {
        RandSample::new()
    }
}
1a4d82fc JJ |
70 | |
71 | impl<Sup: Rand> Sample<Sup> for RandSample<Sup> { | |
b039eaaf SL |
72 | fn sample<R: Rng>(&mut self, rng: &mut R) -> Sup { |
73 | self.ind_sample(rng) | |
74 | } | |
1a4d82fc JJ |
75 | } |
76 | ||
77 | impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> { | |
78 | fn ind_sample<R: Rng>(&self, rng: &mut R) -> Sup { | |
79 | rng.gen() | |
80 | } | |
81 | } | |
82 | ||
32a655c1 SL |
83 | impl<Sup> fmt::Debug for RandSample<Sup> { |
84 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
85 | f.pad("RandSample { .. }") | |
86 | } | |
87 | } | |
88 | ||
1a4d82fc JJ |
/// An item paired with the relative weight that `WeightedChoice` uses when
/// deciding how often to pick it.
pub struct Weighted<T> {
    /// Relative weight of `item`; a larger weight makes it more likely.
    pub weight: usize,
    /// The value handed out when this entry is selected.
    pub item: T,
}

impl<T: fmt::Debug> fmt::Debug for Weighted<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Weighted { ref weight, ref item } = *self;
        let mut builder = f.debug_struct("Weighted");
        builder.field("weight", weight);
        builder.field("item", item);
        builder.finish()
    }
}
105 | ||
1a4d82fc JJ |
106 | /// A distribution that selects from a finite collection of weighted items. |
107 | /// | |
108 | /// Each item has an associated weight that influences how likely it | |
109 | /// is to be chosen: higher weight is more likely. | |
110 | /// | |
111 | /// The `Clone` restriction is a limitation of the `Sample` and | |
112 | /// `IndependentSample` traits. Note that `&T` is (cheaply) `Clone` for | |
c34b1796 | 113 | /// all `T`, as is `usize`, so one can store references or indices into |
1a4d82fc | 114 | /// another vector. |
b039eaaf | 115 | pub struct WeightedChoice<'a, T: 'a> { |
1a4d82fc | 116 | items: &'a mut [Weighted<T>], |
b039eaaf | 117 | weight_range: Range<usize>, |
1a4d82fc JJ |
118 | } |
119 | ||
120 | impl<'a, T: Clone> WeightedChoice<'a, T> { | |
121 | /// Create a new `WeightedChoice`. | |
122 | /// | |
123 | /// Panics if: | |
124 | /// - `v` is empty | |
125 | /// - the total weight is 0 | |
c34b1796 | 126 | /// - the total weight is larger than a `usize` can contain. |
1a4d82fc JJ |
127 | pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> { |
128 | // strictly speaking, this is subsumed by the total weight == 0 case | |
b039eaaf SL |
129 | assert!(!items.is_empty(), |
130 | "WeightedChoice::new called with no items"); | |
1a4d82fc | 131 | |
c34b1796 | 132 | let mut running_total = 0_usize; |
1a4d82fc JJ |
133 | |
134 | // we convert the list from individual weights to cumulative | |
135 | // weights so we can binary search. This *could* drop elements | |
136 | // with weight == 0 as an optimisation. | |
85aaf69f | 137 | for item in &mut *items { |
1a4d82fc JJ |
138 | running_total = match running_total.checked_add(item.weight) { |
139 | Some(n) => n, | |
92a42be0 SL |
140 | None => { |
141 | panic!("WeightedChoice::new called with a total weight larger than a usize \ | |
142 | can contain") | |
143 | } | |
1a4d82fc JJ |
144 | }; |
145 | ||
146 | item.weight = running_total; | |
147 | } | |
b039eaaf SL |
148 | assert!(running_total != 0, |
149 | "WeightedChoice::new called with a total weight of 0"); | |
1a4d82fc JJ |
150 | |
151 | WeightedChoice { | |
3b2f2976 | 152 | items, |
1a4d82fc JJ |
153 | // we're likely to be generating numbers in this range |
154 | // relatively often, so might as well cache it | |
b039eaaf | 155 | weight_range: Range::new(0, running_total), |
1a4d82fc JJ |
156 | } |
157 | } | |
158 | } | |
159 | ||
160 | impl<'a, T: Clone> Sample<T> for WeightedChoice<'a, T> { | |
b039eaaf SL |
161 | fn sample<R: Rng>(&mut self, rng: &mut R) -> T { |
162 | self.ind_sample(rng) | |
163 | } | |
1a4d82fc JJ |
164 | } |
165 | ||
166 | impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> { | |
167 | fn ind_sample<R: Rng>(&self, rng: &mut R) -> T { | |
168 | // we want to find the first element that has cumulative | |
169 | // weight > sample_weight, which we do by binary since the | |
170 | // cumulative weights of self.items are sorted. | |
171 | ||
172 | // choose a weight in [0, total_weight) | |
173 | let sample_weight = self.weight_range.ind_sample(rng); | |
174 | ||
175 | // short circuit when it's the first item | |
176 | if sample_weight < self.items[0].weight { | |
177 | return self.items[0].item.clone(); | |
178 | } | |
179 | ||
180 | let mut idx = 0; | |
181 | let mut modifier = self.items.len(); | |
182 | ||
183 | // now we know that every possibility has an element to the | |
184 | // left, so we can just search for the last element that has | |
185 | // cumulative weight <= sample_weight, then the next one will | |
186 | // be "it". (Note that this greatest element will never be the | |
187 | // last element of the vector, since sample_weight is chosen | |
188 | // in [0, total_weight) and the cumulative weight of the last | |
189 | // one is exactly the total weight.) | |
190 | while modifier > 1 { | |
191 | let i = idx + modifier / 2; | |
192 | if self.items[i].weight <= sample_weight { | |
193 | // we're small, so look to the right, but allow this | |
194 | // exact element still. | |
195 | idx = i; | |
196 | // we need the `/ 2` to round up otherwise we'll drop | |
197 | // the trailing elements when `modifier` is odd. | |
198 | modifier += 1; | |
199 | } else { | |
200 | // otherwise we're too big, so go left. (i.e. do | |
201 | // nothing) | |
202 | } | |
203 | modifier /= 2; | |
204 | } | |
205 | return self.items[idx + 1].item.clone(); | |
206 | } | |
207 | } | |
208 | ||
32a655c1 SL |
209 | impl<'a, T: fmt::Debug> fmt::Debug for WeightedChoice<'a, T> { |
210 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
211 | f.debug_struct("WeightedChoice") | |
212 | .field("items", &self.items) | |
213 | .field("weight_range", &self.weight_range) | |
214 | .finish() | |
215 | } | |
216 | } | |
217 | ||
1a4d82fc JJ |
218 | mod ziggurat_tables; |
219 | ||
220 | /// Sample a random number using the Ziggurat method (specifically the | |
221 | /// ZIGNOR variant from Doornik 2005). Most of the arguments are | |
222 | /// directly from the paper: | |
223 | /// | |
224 | /// * `rng`: source of randomness | |
225 | /// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0. | |
226 | /// * `X`: the $x_i$ abscissae. | |
227 | /// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$) | |
228 | /// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$ | |
229 | /// * `pdf`: the probability density function | |
230 | /// * `zero_case`: manual sampling from the tail when we chose the | |
231 | /// bottom box (i.e. i == 0) | |
1a4d82fc JJ |
232 | // the perf improvement (25-50%) is definitely worth the extra code |
233 | // size from force-inlining. | |
234 | #[inline(always)] | |
b039eaaf SL |
235 | fn ziggurat<R: Rng, P, Z>(rng: &mut R, |
236 | symmetric: bool, | |
237 | x_tab: ziggurat_tables::ZigTable, | |
238 | f_tab: ziggurat_tables::ZigTable, | |
239 | mut pdf: P, | |
240 | mut zero_case: Z) | |
241 | -> f64 | |
242 | where P: FnMut(f64) -> f64, | |
243 | Z: FnMut(&mut R, f64) -> f64 | |
244 | { | |
c34b1796 | 245 | const SCALE: f64 = (1u64 << 53) as f64; |
1a4d82fc JJ |
246 | loop { |
247 | // reimplement the f64 generation as an optimisation suggested | |
248 | // by the Doornik paper: we have a lot of precision-space | |
249 | // (i.e. there are 11 bits of the 64 of a u64 to use after | |
250 | // creating a f64), so we might as well reuse some to save | |
251 | // generating a whole extra random number. (Seems to be 15% | |
252 | // faster.) | |
253 | // | |
254 | // This unfortunately misses out on the benefits of direct | |
255 | // floating point generation if an RNG like dSMFT is | |
256 | // used. (That is, such RNGs create floats directly, highly | |
257 | // efficiently and overload next_f32/f64, so by not calling it | |
258 | // this may be slower than it would be otherwise.) | |
259 | // FIXME: investigate/optimise for the above. | |
260 | let bits: u64 = rng.gen(); | |
c34b1796 | 261 | let i = (bits & 0xff) as usize; |
1a4d82fc JJ |
262 | let f = (bits >> 11) as f64 / SCALE; |
263 | ||
264 | // u is either U(-1, 1) or U(0, 1) depending on if this is a | |
265 | // symmetric distribution or not. | |
c30ab7b3 | 266 | let u = if symmetric { 2.0 * f - 1.0 } else { f }; |
1a4d82fc JJ |
267 | let x = u * x_tab[i]; |
268 | ||
c30ab7b3 | 269 | let test_x = if symmetric { x.abs() } else { x }; |
1a4d82fc JJ |
270 | |
271 | // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i]) | |
272 | if test_x < x_tab[i + 1] { | |
273 | return x; | |
274 | } | |
275 | if i == 0 { | |
276 | return zero_case(rng, u); | |
277 | } | |
278 | // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 | |
c34b1796 | 279 | if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) { |
1a4d82fc JJ |
280 | return x; |
281 | } | |
282 | } | |
283 | } | |
284 | ||
#[cfg(test)]
mod tests {
    use {Rand, Rng};
    use super::{IndependentSample, RandSample, Sample, Weighted, WeightedChoice};

    /// A `Rand` type whose "random" value is always `ConstRand(0)`.
    #[derive(PartialEq, Debug)]
    struct ConstRand(usize);

    impl Rand for ConstRand {
        fn rand<R: Rng>(_: &mut R) -> ConstRand {
            ConstRand(0)
        }
    }

    /// A fake generator yielding 0, 1, 2, 3, ... so the tests can predict
    /// exactly which weight `WeightedChoice` will draw.
    struct CountingRng {
        i: u32,
    }

    impl Rng for CountingRng {
        fn next_u32(&mut self) -> u32 {
            let current = self.i;
            self.i += 1;
            current
        }
        fn next_u64(&mut self) -> u64 {
            u64::from(self.next_u32())
        }
    }

    #[test]
    fn test_rand_sample() {
        let mut sampler = RandSample::<ConstRand>::new();

        let via_sample = sampler.sample(&mut ::test::rng());
        assert_eq!(via_sample, ConstRand(0));

        let via_ind_sample = sampler.ind_sample(&mut ::test::rng());
        assert_eq!(via_ind_sample, ConstRand(0));
    }

    #[test]
    #[rustfmt_skip]
    fn test_weighted_choice() {
        // this makes assumptions about the internal implementation of
        // WeightedChoice, specifically: it doesn't reorder the items,
        // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
        // 1, internally; modulo a modulo operation).

        // Builds a distribution from `items` and checks that successive
        // draws from a fresh counting RNG produce exactly `expected`.
        fn check(mut items: Vec<Weighted<i32>>, expected: &[i32]) {
            let wc = WeightedChoice::new(&mut items);
            let mut rng = CountingRng { i: 0 };
            for &want in expected {
                assert_eq!(wc.ind_sample(&mut rng), want)
            }
        }

        check(vec![Weighted { weight: 1, item: 10 }],
              &[10]);

        // skip some
        check(vec![Weighted { weight: 0, item: 20 },
                   Weighted { weight: 2, item: 21 },
                   Weighted { weight: 0, item: 22 },
                   Weighted { weight: 1, item: 23 }],
              &[21, 21, 23]);

        // different weights
        check(vec![Weighted { weight: 4, item: 30 },
                   Weighted { weight: 3, item: 31 }],
              &[30, 30, 30, 30, 31, 31, 31]);

        // check that we're binary searching
        // correctly with some vectors of odd
        // length.
        check(vec![Weighted { weight: 1, item: 40 },
                   Weighted { weight: 1, item: 41 },
                   Weighted { weight: 1, item: 42 },
                   Weighted { weight: 1, item: 43 },
                   Weighted { weight: 1, item: 44 }],
              &[40, 41, 42, 43, 44]);
        check(vec![Weighted { weight: 1, item: 50 },
                   Weighted { weight: 1, item: 51 },
                   Weighted { weight: 1, item: 52 },
                   Weighted { weight: 1, item: 53 },
                   Weighted { weight: 1, item: 54 },
                   Weighted { weight: 1, item: 55 },
                   Weighted { weight: 1, item: 56 }],
              &[50, 51, 52, 53, 54, 55, 56]);
    }

    #[test]
    #[should_panic]
    fn test_weighted_choice_no_items() {
        WeightedChoice::<isize>::new(&mut []);
    }

    #[test]
    #[should_panic]
    #[rustfmt_skip]
    fn test_weighted_choice_zero_weight() {
        let mut items = [Weighted { weight: 0, item: 0 },
                         Weighted { weight: 0, item: 1 }];
        WeightedChoice::new(&mut items);
    }

    #[test]
    #[should_panic]
    #[rustfmt_skip]
    fn test_weighted_choice_weight_overflows() {
        let half = !0_usize / 2; // half + half + 2 overflows a usize
        let mut items = [Weighted { weight: half, item: 0 },
                         Weighted { weight: 1, item: 1 },
                         Weighted { weight: half, item: 2 },
                         Weighted { weight: 1, item: 3 }];
        WeightedChoice::new(&mut items);
    }
}