]> git.proxmox.com Git - cargo.git/blob - vendor/rand/src/distributions/weighted/mod.rs
New upstream version 0.37.0
[cargo.git] / vendor / rand / src / distributions / weighted / mod.rs
1 // Copyright 2018 Developers of the Rand project.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8
9 //! Weighted index sampling
10 //!
11 //! This module provides two implementations for sampling indices:
12 //!
13 //! * [`WeightedIndex`] allows `O(log N)` sampling
14 //! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with
15 //! much greater set-up cost
16 //!
17 //! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html
18
19 pub mod alias_method;
20
21 use crate::Rng;
22 use crate::distributions::Distribution;
23 use crate::distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
24 use core::cmp::PartialOrd;
25 use core::fmt;
26
27 // Note that this whole module is only imported if feature="alloc" is enabled.
28 #[cfg(not(feature="std"))] use crate::alloc::vec::Vec;
29
30 /// A distribution using weighted sampling to pick a discretely selected
31 /// item.
32 ///
33 /// Sampling a `WeightedIndex` distribution returns the index of a randomly
34 /// selected element from the iterator used when the `WeightedIndex` was
35 /// created. The chance of a given element being picked is proportional to the
36 /// value of the element. The weights can use any type `X` for which an
37 /// implementation of [`Uniform<X>`] exists.
38 ///
39 /// # Performance
40 ///
41 /// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
42 /// size is the sum of the size of those objects, possibly plus some alignment.
43 ///
44 /// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
45 /// weights of type `X`, where `N` is the number of weights. However, since
46 /// `Vec` doesn't guarantee a particular growth strategy, additional memory
47 /// might be allocated but not used. Since the `WeightedIndex` object also
48 /// contains, this might cause additional allocations, though for primitive
49 /// types, ['Uniform<X>`] doesn't allocate any memory.
50 ///
51 /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
52 /// `N` is the number of weights.
53 ///
54 /// Sampling from `WeightedIndex` will result in a single call to
55 /// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
56 /// will request a single value from the underlying [`RngCore`], though the
57 /// exact number depends on the implementaiton of `Uniform<X>::sample`.
58 ///
59 /// # Example
60 ///
61 /// ```
62 /// use rand::prelude::*;
63 /// use rand::distributions::WeightedIndex;
64 ///
65 /// let choices = ['a', 'b', 'c'];
66 /// let weights = [2, 1, 1];
67 /// let dist = WeightedIndex::new(&weights).unwrap();
68 /// let mut rng = thread_rng();
69 /// for _ in 0..100 {
70 /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
71 /// println!("{}", choices[dist.sample(&mut rng)]);
72 /// }
73 ///
74 /// let items = [('a', 0), ('b', 3), ('c', 7)];
75 /// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
76 /// for _ in 0..100 {
77 /// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
78 /// println!("{}", items[dist2.sample(&mut rng)].0);
79 /// }
80 /// ```
81 ///
82 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
83 /// [`RngCore`]: crate::RngCore
84 #[derive(Debug, Clone)]
85 pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
86 cumulative_weights: Vec<X>,
87 weight_distribution: X::Sampler,
88 }
89
90 impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
91 /// Creates a new a `WeightedIndex` [`Distribution`] using the values
92 /// in `weights`. The weights can use any type `X` for which an
93 /// implementation of [`Uniform<X>`] exists.
94 ///
95 /// Returns an error if the iterator is empty, if any weight is `< 0`, or
96 /// if its total value is 0.
97 ///
98 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
99 pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
100 where I: IntoIterator,
101 I::Item: SampleBorrow<X>,
102 X: for<'a> ::core::ops::AddAssign<&'a X> +
103 Clone +
104 Default {
105 let mut iter = weights.into_iter();
106 let mut total_weight: X = iter.next()
107 .ok_or(WeightedError::NoItem)?
108 .borrow()
109 .clone();
110
111 let zero = <X as Default>::default();
112 if total_weight < zero {
113 return Err(WeightedError::InvalidWeight);
114 }
115
116 let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
117 for w in iter {
118 if *w.borrow() < zero {
119 return Err(WeightedError::InvalidWeight);
120 }
121 weights.push(total_weight.clone());
122 total_weight += w.borrow();
123 }
124
125 if total_weight == zero {
126 return Err(WeightedError::AllWeightsZero);
127 }
128 let distr = X::Sampler::new(zero, total_weight);
129
130 Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
131 }
132 }
133
134 impl<X> Distribution<usize> for WeightedIndex<X> where
135 X: SampleUniform + PartialOrd {
136 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
137 use ::core::cmp::Ordering;
138 let chosen_weight = self.weight_distribution.sample(rng);
139 // Find the first item which has a weight *higher* than the chosen weight.
140 self.cumulative_weights.binary_search_by(
141 |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
142 }
143 }
144
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that sampled frequencies track the weights regardless of how the
    /// weights are supplied, and that invalid inputs yield the matching error.
    #[test]
    #[cfg(not(miri))] // Miri is too slow
    fn test_weightedindex() {
        const N_REPS: u32 = 5000;

        let mut r = crate::test::rng(700);
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each observed count must be within 25% (relative) of its expectation.
        let verify = |result: [i32; 14]| {
            for (&weight, &count) in weights.iter().zip(result.iter()) {
                let exp = (weight * N_REPS) as f32 / total_weight;
                let abs_err = (count as f32 - exp).abs();
                // A zero absolute error is left as-is so that a zero
                // expectation (weight 0) never causes a division by zero.
                let err = if abs_err != 0.0 { abs_err / exp } else { abs_err };
                assert!(err <= 0.25);
            }
        };

        // Draws `N_REPS` samples from `distr` and counts how often each index
        // was produced.
        let mut tally = |distr: &WeightedIndex<u32>| {
            let mut chosen = [0i32; 14];
            for _ in 0..N_REPS {
                chosen[distr.sample(&mut r)] += 1;
            }
            chosen
        };

        // WeightedIndex from vec
        verify(tally(&WeightedIndex::new(weights.to_vec()).unwrap()));

        // WeightedIndex from slice
        verify(tally(&WeightedIndex::new(&weights[..]).unwrap()));

        // WeightedIndex from iterator
        verify(tally(&WeightedIndex::new(weights.iter()).unwrap()));

        // With a single non-zero weight the outcome is deterministic.
        for _ in 0..5 {
            assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
            assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
            assert_eq!(
                WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r),
                4
            );
        }

        // Invalid inputs map onto the corresponding error variants.
        assert_eq!(
            WeightedIndex::new(&[10][0..0]).unwrap_err(),
            WeightedError::NoItem
        );
        assert_eq!(
            WeightedIndex::new(&[0]).unwrap_err(),
            WeightedError::AllWeightsZero
        );
        assert_eq!(
            WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(&[-10]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }
}
205
/// Error type returned from `WeightedIndex::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightedError {
    /// The provided weight collection contains no items.
    NoItem,

    /// A weight is either less than zero, greater than the supported maximum or
    /// otherwise invalid.
    InvalidWeight,

    /// All items in the provided weight collection are zero.
    AllWeightsZero,

    /// Too many weights are provided (length greater than `u32::MAX`)
    TooMany,
}

impl WeightedError {
    /// Returns the fixed human-readable message for this error; shared by the
    /// `Display` and `Error` implementations so both report the same text.
    fn msg(&self) -> &str {
        use self::WeightedError::*;

        match *self {
            NoItem => "No weights provided.",
            InvalidWeight => "A weight is invalid.",
            AllWeightsZero => "All weights are zero.",
            TooMany => "Too many weights (hit u32::MAX)",
        }
    }
}

#[cfg(feature = "std")]
impl ::std::error::Error for WeightedError {
    // `cause` is intentionally not overridden: this error wraps no underlying
    // error, and the trait's default already reports `None`.
    fn description(&self) -> &str {
        self.msg()
    }
}

impl fmt::Display for WeightedError {
    /// Writes the error's message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.msg())
    }
}
248 }