]> git.proxmox.com Git - rustc.git/blob - vendor/rand/src/distributions/weighted/mod.rs
New upstream version 1.51.0+dfsg1
[rustc.git] / vendor / rand / src / distributions / weighted / mod.rs
1 // Copyright 2018 Developers of the Rand project.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8
9 //! Weighted index sampling
10 //!
11 //! This module provides two implementations for sampling indices:
12 //!
13 //! * [`WeightedIndex`] allows `O(log N)` sampling
14 //! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with
15 //! much greater set-up cost
16 //!
17 //! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html
18
19 pub mod alias_method;
20
21 use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
22 use crate::distributions::Distribution;
23 use crate::Rng;
24 use core::cmp::PartialOrd;
25 use core::fmt;
26
27 // Note that this whole module is only imported if feature="alloc" is enabled.
28 #[cfg(not(feature = "std"))] use crate::alloc::vec::Vec;
29
30 /// A distribution using weighted sampling to pick a discretely selected
31 /// item.
32 ///
33 /// Sampling a `WeightedIndex` distribution returns the index of a randomly
34 /// selected element from the iterator used when the `WeightedIndex` was
35 /// created. The chance of a given element being picked is proportional to the
36 /// value of the element. The weights can use any type `X` for which an
37 /// implementation of [`Uniform<X>`] exists.
38 ///
39 /// # Performance
40 ///
41 /// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
42 /// size is the sum of the size of those objects, possibly plus some alignment.
43 ///
44 /// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
45 /// weights of type `X`, where `N` is the number of weights. However, since
46 /// `Vec` doesn't guarantee a particular growth strategy, additional memory
47 /// might be allocated but not used. Since the `WeightedIndex` object also
48 /// contains an instance of `X::Sampler`, this might cause additional
49 /// allocations, though for primitive types, [`Uniform<X>`] doesn't allocate any memory.
50 ///
51 /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
52 /// `N` is the number of weights.
53 ///
54 /// Sampling from `WeightedIndex` will result in a single call to
55 /// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
56 /// will request a single value from the underlying [`RngCore`], though the
57 /// exact number depends on the implementation of `Uniform<X>::sample`.
58 ///
59 /// # Example
60 ///
61 /// ```
62 /// use rand::prelude::*;
63 /// use rand::distributions::WeightedIndex;
64 ///
65 /// let choices = ['a', 'b', 'c'];
66 /// let weights = [2, 1, 1];
67 /// let dist = WeightedIndex::new(&weights).unwrap();
68 /// let mut rng = thread_rng();
69 /// for _ in 0..100 {
70 /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
71 /// println!("{}", choices[dist.sample(&mut rng)]);
72 /// }
73 ///
74 /// let items = [('a', 0), ('b', 3), ('c', 7)];
75 /// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
76 /// for _ in 0..100 {
77 /// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
78 /// println!("{}", items[dist2.sample(&mut rng)].0);
79 /// }
80 /// ```
81 ///
82 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
83 /// [`RngCore`]: crate::RngCore
84 #[derive(Debug, Clone)]
85 pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
86 cumulative_weights: Vec<X>,
87 total_weight: X,
88 weight_distribution: X::Sampler,
89 }
90
91 impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
92 /// Creates a new a `WeightedIndex` [`Distribution`] using the values
93 /// in `weights`. The weights can use any type `X` for which an
94 /// implementation of [`Uniform<X>`] exists.
95 ///
96 /// Returns an error if the iterator is empty, if any weight is `< 0`, or
97 /// if its total value is 0.
98 ///
99 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
100 pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
101 where
102 I: IntoIterator,
103 I::Item: SampleBorrow<X>,
104 X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
105 {
106 let mut iter = weights.into_iter();
107 let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
108
109 let zero = <X as Default>::default();
110 if total_weight < zero {
111 return Err(WeightedError::InvalidWeight);
112 }
113
114 let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
115 for w in iter {
116 if *w.borrow() < zero {
117 return Err(WeightedError::InvalidWeight);
118 }
119 weights.push(total_weight.clone());
120 total_weight += w.borrow();
121 }
122
123 if total_weight == zero {
124 return Err(WeightedError::AllWeightsZero);
125 }
126 let distr = X::Sampler::new(zero, total_weight.clone());
127
128 Ok(WeightedIndex {
129 cumulative_weights: weights,
130 total_weight,
131 weight_distribution: distr,
132 })
133 }
134
135 /// Update a subset of weights, without changing the number of weights.
136 ///
137 /// `new_weights` must be sorted by the index.
138 ///
139 /// Using this method instead of `new` might be more efficient if only a small number of
140 /// weights is modified. No allocations are performed, unless the weight type `X` uses
141 /// allocation internally.
142 ///
143 /// In case of error, `self` is not modified.
144 pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
145 where X: for<'a> ::core::ops::AddAssign<&'a X>
146 + for<'a> ::core::ops::SubAssign<&'a X>
147 + Clone
148 + Default {
149 if new_weights.is_empty() {
150 return Ok(());
151 }
152
153 let zero = <X as Default>::default();
154
155 let mut total_weight = self.total_weight.clone();
156
157 // Check for errors first, so we don't modify `self` in case something
158 // goes wrong.
159 let mut prev_i = None;
160 for &(i, w) in new_weights {
161 if let Some(old_i) = prev_i {
162 if old_i >= i {
163 return Err(WeightedError::InvalidWeight);
164 }
165 }
166 if *w < zero {
167 return Err(WeightedError::InvalidWeight);
168 }
169 if i >= self.cumulative_weights.len() + 1 {
170 return Err(WeightedError::TooMany);
171 }
172
173 let mut old_w = if i < self.cumulative_weights.len() {
174 self.cumulative_weights[i].clone()
175 } else {
176 self.total_weight.clone()
177 };
178 if i > 0 {
179 old_w -= &self.cumulative_weights[i - 1];
180 }
181
182 total_weight -= &old_w;
183 total_weight += w;
184 prev_i = Some(i);
185 }
186 if total_weight == zero {
187 return Err(WeightedError::AllWeightsZero);
188 }
189
190 // Update the weights. Because we checked all the preconditions in the
191 // previous loop, this should never panic.
192 let mut iter = new_weights.iter();
193
194 let mut prev_weight = zero.clone();
195 let mut next_new_weight = iter.next();
196 let &(first_new_index, _) = next_new_weight.unwrap();
197 let mut cumulative_weight = if first_new_index > 0 {
198 self.cumulative_weights[first_new_index - 1].clone()
199 } else {
200 zero.clone()
201 };
202 for i in first_new_index..self.cumulative_weights.len() {
203 match next_new_weight {
204 Some(&(j, w)) if i == j => {
205 cumulative_weight += w;
206 next_new_weight = iter.next();
207 }
208 _ => {
209 let mut tmp = self.cumulative_weights[i].clone();
210 tmp -= &prev_weight; // We know this is positive.
211 cumulative_weight += &tmp;
212 }
213 }
214 prev_weight = cumulative_weight.clone();
215 core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
216 }
217
218 self.total_weight = total_weight;
219 self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());
220
221 Ok(())
222 }
223 }
224
225 impl<X> Distribution<usize> for WeightedIndex<X>
226 where X: SampleUniform + PartialOrd
227 {
228 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
229 use ::core::cmp::Ordering;
230 let chosen_weight = self.weight_distribution.sample(rng);
231 // Find the first item which has a weight *higher* than the chosen weight.
232 self.cumulative_weights
233 .binary_search_by(|w| {
234 if *w <= chosen_weight {
235 Ordering::Less
236 } else {
237 Ordering::Greater
238 }
239 })
240 .unwrap_err()
241 }
242 }
243
#[cfg(test)]
mod test {
    use super::*;

    // Samples repeatedly from distributions built from a vec, a slice and an
    // iterator, checks the observed frequencies against the weights, then
    // checks deterministic outcomes and constructor errors.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weightedindex() {
        const N_REPS: u32 = 5000;

        // Draws `N_REPS` indices from `distr` and counts the hits per index.
        fn tally<R: Rng>(distr: &WeightedIndex<u32>, rng: &mut R) -> [i32; 14] {
            let mut counts = [0i32; 14];
            for _ in 0..N_REPS {
                counts[distr.sample(rng)] += 1;
            }
            counts
        }

        let mut r = crate::test::rng(700);
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each observed count must be within 25% (relative) of its expectation.
        let verify = |result: [i32; 14]| {
            for (count, weight) in result.iter().zip(weights.iter()) {
                let exp = (*weight * N_REPS) as f32 / total_weight;
                let abs_err = (*count as f32 - exp).abs();
                // Skip the division when the error is exactly zero.
                let err = if abs_err != 0.0 { abs_err / exp } else { abs_err };
                assert!(err <= 0.25);
            }
        };

        // WeightedIndex from vec
        let from_vec = WeightedIndex::new(weights.to_vec()).unwrap();
        verify(tally(&from_vec, &mut r));

        // WeightedIndex from slice
        let from_slice = WeightedIndex::new(&weights[..]).unwrap();
        verify(tally(&from_slice, &mut r));

        // WeightedIndex from iterator
        let from_iter = WeightedIndex::new(weights.iter()).unwrap();
        verify(tally(&from_iter, &mut r));

        // With a single non-zero weight the outcome is fixed.
        let certain: [(&[i32], usize); 3] = [
            (&[0, 1][..], 1),
            (&[1, 0][..], 0),
            (&[0, 0, 0, 0, 10, 0][..], 4),
        ];
        for _ in 0..5 {
            for &(input, expected) in certain.iter() {
                assert_eq!(WeightedIndex::new(input).unwrap().sample(&mut r), expected);
            }
        }

        // Inputs that must be rejected, with the error each one produces.
        let rejected: [(&[i32], WeightedError); 5] = [
            (&[10][0..0], WeightedError::NoItem),
            (&[0][..], WeightedError::AllWeightsZero),
            (&[10, 20, -1, 30][..], WeightedError::InvalidWeight),
            (&[-10, 20, 1, 30][..], WeightedError::InvalidWeight),
            (&[-10][..], WeightedError::InvalidWeight),
        ];
        for &(input, expected) in rejected.iter() {
            assert_eq!(WeightedIndex::new(input).unwrap_err(), expected);
        }
    }

    // After `update_weights`, the distribution must match one built from
    // scratch with the expected weights.
    #[test]
    fn test_update_weights() {
        // (initial weights, updates, expected weights)
        let data = [
            (
                &[10u32, 2, 3, 4][..],
                &[(1, &100), (2, &4)][..], // positive change
                &[10, 100, 4, 4][..],
            ),
            (
                &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
                &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
                &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
            ),
        ];

        for &(weights, update, expected_weights) in &data {
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, weights.iter().sum::<u32>());

            distr.update_weights(update).unwrap();

            let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected_weights.iter().sum::<u32>());
            assert_eq!(distr.total_weight, expected_distr.total_weight);
            assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
        }
    }

    // Pins the exact sample sequence produced from a fixed seed.
    #[test]
    fn value_stability() {
        // Fills `buf` with samples drawn from `weights` using a freshly seeded
        // generator and compares the result with `expected`.
        fn test_samples<X: SampleUniform + PartialOrd, I>(
            weights: I, buf: &mut [usize], expected: &[usize],
        ) where
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
            X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
        {
            assert_eq!(buf.len(), expected.len());
            let mut rng = crate::test::rng(701);
            let distr = WeightedIndex::new(weights).unwrap();
            buf.iter_mut().for_each(|slot| *slot = rng.sample(&distr));
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
            0, 6, 2, 6, 3, 4, 7, 8, 2, 5,
        ]);
        test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
            0, 0, 0, 1, 0, 0, 2, 3, 0, 0,
        ]);
        test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
            2, 2, 1, 3, 2, 1, 3, 3, 2, 1,
        ]);
    }
}
383
/// Errors reported when creating a `WeightedIndex` or updating its weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightedError {
    /// No weights were supplied at all.
    NoItem,

    /// A weight was less than zero, greater than the supported maximum or
    /// otherwise invalid. `WeightedIndex::update_weights` also reports this
    /// when its indices are not in strictly increasing order.
    InvalidWeight,

    /// Every supplied weight was zero.
    AllWeightsZero,

    /// Too many weights were supplied (length greater than `u32::MAX`).
    /// `WeightedIndex::update_weights` also reports this for an index that is
    /// out of range.
    TooMany,
}
400
// Only available with the `std` feature. The impl body is empty, so every
// `Error` method keeps its default.
#[cfg(feature = "std")]
impl std::error::Error for WeightedError {}
403
404 impl fmt::Display for WeightedError {
405 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
406 match *self {
407 WeightedError::NoItem => write!(f, "No weights provided."),
408 WeightedError::InvalidWeight => write!(f, "A weight is invalid."),
409 WeightedError::AllWeightsZero => write!(f, "All weights are zero."),
410 WeightedError::TooMany => write!(f, "Too many weights (hit u32::MAX)"),
411 }
412 }
413 }