1 // Copyright 2018 Developers of the Rand project.
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
9 //! Weighted index sampling
11 //! This module provides two implementations for sampling indices:
13 //! * [`WeightedIndex`] allows `O(log N)` sampling
14 //! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with
15 //! much greater set-up cost
17 //! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html
21 use crate::distributions
::uniform
::{SampleBorrow, SampleUniform, UniformSampler}
;
22 use crate::distributions
::Distribution
;
24 use core
::cmp
::PartialOrd
;
27 // Note that this whole module is only imported if feature="alloc" is enabled.
28 #[cfg(not(feature = "std"))] use crate::alloc::vec::Vec;
30 /// A distribution using weighted sampling to pick a discretely selected
/// item.
33 /// Sampling a `WeightedIndex` distribution returns the index of a randomly
34 /// selected element from the iterator used when the `WeightedIndex` was
35 /// created. The chance of a given element being picked is proportional to the
36 /// value of the element. The weights can use any type `X` for which an
37 /// implementation of [`Uniform<X>`] exists.
41 /// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
42 /// size is the sum of the size of those objects, possibly plus some alignment.
44 /// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
45 /// weights of type `X`, where `N` is the number of weights. However, since
46 /// `Vec` doesn't guarantee a particular growth strategy, additional memory
47 /// might be allocated but not used. Since the `WeightedIndex` object also
48 /// contains a [`Uniform<X>`], this might cause additional allocations, though for
49 /// primitive types, [`Uniform<X>`] doesn't allocate any memory.
51 /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
52 /// `N` is the number of weights.
54 /// Sampling from `WeightedIndex` will result in a single call to
55 /// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
56 /// will request a single value from the underlying [`RngCore`], though the
57 /// exact number depends on the implementation of `Uniform<X>::sample`.
62 /// use rand::prelude::*;
63 /// use rand::distributions::WeightedIndex;
65 /// let choices = ['a', 'b', 'c'];
66 /// let weights = [2, 1, 1];
67 /// let dist = WeightedIndex::new(&weights).unwrap();
68 /// let mut rng = thread_rng();
70 /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
71 /// println!("{}", choices[dist.sample(&mut rng)]);
74 /// let items = [('a', 0), ('b', 3), ('c', 7)];
75 /// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
77 /// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
78 /// println!("{}", items[dist2.sample(&mut rng)].0);
82 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
83 /// [`RngCore`]: crate::RngCore
84 #[derive(Debug, Clone)]
85 pub struct WeightedIndex
<X
: SampleUniform
+ PartialOrd
> {
86 cumulative_weights
: Vec
<X
>,
88 weight_distribution
: X
::Sampler
,
// Construction and in-place updates. Both methods maintain the same state:
// `cumulative_weights[i]` is the sum of weights `0..=i` for every weight but
// the last, and `total_weight` is the sum of all weights.
91 impl<X
: SampleUniform
+ PartialOrd
> WeightedIndex
<X
> {
92 /// Creates a new `WeightedIndex` [`Distribution`] using the values
93 /// in `weights`. The weights can use any type `X` for which an
94 /// implementation of [`Uniform<X>`] exists.
96 /// Returns an error if the iterator is empty, if any weight is `< 0`, or
97 /// if its total value is 0.
///
/// The specific errors are `WeightedError::NoItem` for an empty iterator,
/// `WeightedError::InvalidWeight` for a negative weight, and
/// `WeightedError::AllWeightsZero` when every weight is zero.
99 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
100 pub fn new
<I
>(weights
: I
) -> Result
<WeightedIndex
<X
>, WeightedError
>
103 I
::Item
: SampleBorrow
<X
>,
104 X
: for<'a
> ::core
::ops
::AddAssign
<&'a X
> + Clone
+ Default
,
// The first weight seeds the running total; an empty iterator is rejected
// with `NoItem` before anything else happens.
106 let mut iter
= weights
.into_iter();
107 let mut total_weight
: X
= iter
.next().ok_or(WeightedError
::NoItem
)?
.borrow().clone();
// `X::default()` serves as the zero weight for all comparisons below.
109 let zero
= <X
as Default
>::default();
110 if total_weight
< zero
{
111 return Err(WeightedError
::InvalidWeight
);
// Sized from the iterator's lower size hint, i.e. the remaining weights:
// one prefix-sum entry per weight after the first.
114 let mut weights
= Vec
::<X
>::with_capacity(iter
.size_hint().0);
116 if *w
.borrow() < zero
{
117 return Err(WeightedError
::InvalidWeight
);
// Push the sum of all *previous* weights before adding the current one, so
// the final weight's prefix sum ends up only in `total_weight`.
119 weights
.push(total_weight
.clone());
120 total_weight
+= w
.borrow();
// A distribution where every weight is zero cannot be sampled.
123 if total_weight
== zero
{
124 return Err(WeightedError
::AllWeightsZero
);
// Sampler over the total weight; `sample` searches `cumulative_weights`
// for the value it draws.
126 let distr
= X
::Sampler
::new(zero
, total_weight
.clone());
129 cumulative_weights
: weights
,
131 weight_distribution
: distr
,
135 /// Update a subset of weights, without changing the number of weights.
137 /// `new_weights` must be sorted by the index.
139 /// Using this method instead of `new` might be more efficient if only a small number of
140 /// weights is modified. No allocations are performed, unless the weight type `X` uses
141 /// allocation internally.
143 /// In case of error, `self` is not modified.
///
/// # Errors
///
/// * `WeightedError::InvalidWeight` if the indices are not sorted or a new
///   weight is invalid.
/// * `WeightedError::TooMany` if an index is not smaller than the number of
///   weights.
/// * `WeightedError::AllWeightsZero` if the update would make every weight zero.
144 pub fn update_weights(&mut self, new_weights
: &[(usize, &X
)]) -> Result
<(), WeightedError
>
145 where X
: for<'a
> ::core
::ops
::AddAssign
<&'a X
>
146 + for<'a
> ::core
::ops
::SubAssign
<&'a X
>
// An empty update has to be handled up front: the second pass below
// unwraps the first entry of `new_weights`.
149 if new_weights
.is_empty() {
153 let zero
= <X
as Default
>::default();
// Work on a copy of the total so `self` stays untouched until every
// entry has been validated.
155 let mut total_weight
= self.total_weight
.clone();
157 // Check for errors first, so we don't modify `self` in case something
// goes wrong.
159 let mut prev_i
= None
;
160 for &(i
, w
) in new_weights
{
// Indices must arrive sorted (see the method docs).
161 if let Some(old_i
) = prev_i
{
163 return Err(WeightedError
::InvalidWeight
);
167 return Err(WeightedError
::InvalidWeight
);
// Index `cumulative_weights.len()` addresses the last weight, which has
// no prefix-sum entry of its own; anything beyond that is out of range.
169 if i
>= self.cumulative_weights
.len() + 1 {
170 return Err(WeightedError
::TooMany
);
// Recover the old weight at `i` as the difference of adjacent prefix
// sums; the last weight's prefix sum is `total_weight`.
173 let mut old_w
= if i
< self.cumulative_weights
.len() {
174 self.cumulative_weights
[i
].clone()
176 self.total_weight
.clone()
179 old_w
-= &self.cumulative_weights
[i
- 1];
// Remove the old weight from the running total.
182 total_weight
-= &old_w
;
// Reject an update that would leave every weight at zero; `self` has not
// been modified yet at this point.
186 if total_weight
== zero
{
187 return Err(WeightedError
::AllWeightsZero
);
190 // Update the weights. Because we checked all the preconditions in the
191 // previous loop, this should never panic.
192 let mut iter
= new_weights
.iter();
// Holds the *old* prefix sum of the previous index; refreshed by the swap
// at the bottom of the loop.
194 let mut prev_weight
= zero
.clone();
195 let mut next_new_weight
= iter
.next();
196 let &(first_new_index
, _
) = next_new_weight
.unwrap();
// Prefix sums before the first updated index are unchanged, so start
// from the one just before it.
197 let mut cumulative_weight
= if first_new_index
> 0 {
198 self.cumulative_weights
[first_new_index
- 1].clone()
// Recompute every prefix sum from the first updated index to the end.
202 for i
in first_new_index
..self.cumulative_weights
.len() {
203 match next_new_weight
{
// Updated index: add its new weight.
204 Some(&(j
, w
)) if i
== j
=> {
205 cumulative_weight
+= w
;
206 next_new_weight
= iter
.next();
// Untouched index: re-add its old weight, recovered as the difference
// between its old prefix sum and the previous old prefix sum.
209 let mut tmp
= self.cumulative_weights
[i
].clone();
210 tmp
-= &prev_weight
; // We know this is non-negative.
211 cumulative_weight
+= &tmp
;
// Store the new prefix sum and keep the old one in `prev_weight` for the
// next iteration.
214 prev_weight
= cumulative_weight
.clone();
215 core
::mem
::swap(&mut prev_weight
, &mut self.cumulative_weights
[i
]);
// Commit the new total and rebuild the sampler over it.
218 self.total_weight
= total_weight
;
219 self.weight_distribution
= X
::Sampler
::new(zero
, self.total_weight
.clone());
// Sampling: draw a value from the uniform sampler over the total weight, then
// binary-search the prefix sums for the first entry greater than that value.
225 impl<X
> Distribution
<usize> for WeightedIndex
<X
>
226 where X
: SampleUniform
+ PartialOrd
228 fn sample
<R
: Rng
+ ?Sized
>(&self, rng
: &mut R
) -> usize {
229 use ::core
::cmp
::Ordering
;
// One draw from the sampler built over `total_weight`.
230 let chosen_weight
= self.weight_distribution
.sample(rng
);
231 // Find the first item which has a weight *higher* than the chosen weight.
232 self.cumulative_weights
233 .binary_search_by(|w
| {
// Entries not above the drawn value lie before the target position.
234 if *w
<= chosen_weight
{
// Statistical check of `WeightedIndex` built from a Vec, a slice and an
// iterator, followed by zero-weight edge cases and the constructor's errors.
249 #[cfg_attr(miri, ignore)] // Miri is too slow
250 fn test_weightedindex() {
251 let mut r
= crate::test
::rng(700);
252 const N_REPS
: u32 = 5000;
253 let weights
= [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
254 let total_weight
= weights
.iter().sum
::<u32>() as f32;
// Compares each observed count with its expected count
// `weights[i] * N_REPS / total_weight`.
256 let verify
= |result
: [i32; 14]| {
257 for (i
, count
) in result
.iter().enumerate() {
258 let exp
= (weights
[i
] * N_REPS
) as f32 / total_weight
;
259 let mut err
= (*count
as f32 - exp
).abs();
263 assert
!(err
<= 0.25);
267 // WeightedIndex from vec
268 let mut chosen
= [0i32; 14];
269 let distr
= WeightedIndex
::new(weights
.to_vec()).unwrap();
271 chosen
[distr
.sample(&mut r
)] += 1;
275 // WeightedIndex from slice
277 let distr
= WeightedIndex
::new(&weights
[..]).unwrap();
279 chosen
[distr
.sample(&mut r
)] += 1;
283 // WeightedIndex from iterator
285 let distr
= WeightedIndex
::new(weights
.iter()).unwrap();
287 chosen
[distr
.sample(&mut r
)] += 1;
// A zero-weight index must never be selected, whether it comes first or last.
292 assert_eq
!(WeightedIndex
::new(&[0, 1]).unwrap().sample(&mut r
), 1);
293 assert_eq
!(WeightedIndex
::new(&[1, 0]).unwrap().sample(&mut r
), 0);
295 WeightedIndex
::new(&[0, 0, 0, 0, 10, 0])
// Constructor errors: empty input, all-zero weights, negative weights.
303 WeightedIndex
::new(&[10][0..0]).unwrap_err(),
304 WeightedError
::NoItem
307 WeightedIndex
::new(&[0]).unwrap_err(),
308 WeightedError
::AllWeightsZero
311 WeightedIndex
::new(&[10, 20, -1, 30]).unwrap_err(),
312 WeightedError
::InvalidWeight
315 WeightedIndex
::new(&[-10, 20, 1, 30]).unwrap_err(),
316 WeightedError
::InvalidWeight
319 WeightedIndex
::new(&[-10]).unwrap_err(),
320 WeightedError
::InvalidWeight
// Checks that `update_weights` yields exactly the same state as constructing a
// fresh `WeightedIndex` from the expected weights.
325 fn test_update_weights() {
// Each case: (initial weights, sparse update, expected weights afterwards).
328 &[10u32, 2, 3, 4][..],
329 &[(1, &100), (2, &4)][..], // positive change
330 &[10, 100, 4, 4][..],
333 &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
334 &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
335 &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
339 for (weights
, update
, expected_weights
) in data
.iter() {
340 let total_weight
= weights
.iter().sum
::<u32>();
341 let mut distr
= WeightedIndex
::new(weights
.to_vec()).unwrap();
342 assert_eq
!(distr
.total_weight
, total_weight
);
344 distr
.update_weights(update
).unwrap();
// The updated distribution must match one built directly from the
// expected weights: same total and same prefix sums.
345 let expected_total_weight
= expected_weights
.iter().sum
::<u32>();
346 let expected_distr
= WeightedIndex
::new(expected_weights
.to_vec()).unwrap();
347 assert_eq
!(distr
.total_weight
, expected_total_weight
);
348 assert_eq
!(distr
.total_weight
, expected_distr
.total_weight
);
349 assert_eq
!(distr
.cumulative_weights
, expected_distr
.cumulative_weights
);
// Pins the exact index sequences produced from a fixed seed for integer and
// float weights, so any change to the sampling algorithm is detected.
354 fn value_stability() {
// Builds a `WeightedIndex` from `weights`, fills `buf` with samples from
// seed 701 and asserts the result equals `expected`.
355 fn test_samples
<X
: SampleUniform
+ PartialOrd
, I
>(
356 weights
: I
, buf
: &mut [usize], expected
: &[usize],
359 I
::Item
: SampleBorrow
<X
>,
360 X
: for<'a
> ::core
::ops
::AddAssign
<&'a X
> + Clone
+ Default
,
362 assert_eq
!(buf
.len(), expected
.len());
363 let distr
= WeightedIndex
::new(weights
).unwrap();
364 let mut rng
= crate::test
::rng(701);
365 for r
in buf
.iter_mut() {
366 *r
= rng
.sample(&distr
);
368 assert_eq
!(buf
, expected
);
371 let mut buf
= [0; 10];
372 test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf
, &[
373 0, 6, 2, 6, 3, 4, 7, 8, 2, 5,
375 test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf
, &[
376 0, 0, 0, 1, 0, 0, 2, 3, 0, 0,
378 test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf
, &[
379 2, 2, 1, 3, 2, 1, 3, 3, 2, 1,
/// Error type returned from `WeightedIndex::new` and
/// `WeightedIndex::update_weights`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightedError {
    /// The provided weight collection contains no items.
    NoItem,

    /// A weight is either less than zero, greater than the supported maximum or
    /// otherwise invalid. `update_weights` also reports indices that are not
    /// sorted with this variant.
    InvalidWeight,

    /// All items in the provided weight collection are zero.
    AllWeightsZero,

    /// Too many weights are provided (length greater than `u32::MAX`), or an
    /// index passed to `update_weights` is not smaller than the number of
    /// weights.
    TooMany,
}
// `std::error::Error` only exists when the standard library is available, so
// this impl is gated on the `std` feature; every trait method keeps its
// provided default.
#[cfg(feature = "std")]
impl ::std::error::Error for WeightedError {}
// Human-readable message for each `WeightedError` variant; each arm writes a
// single fixed string.
404 impl fmt
::Display
for WeightedError
{
405 fn fmt(&self, f
: &mut fmt
::Formatter
) -> fmt
::Result
{
407 WeightedError
::NoItem
=> write
!(f
, "No weights provided."),
408 WeightedError
::InvalidWeight
=> write
!(f
, "A weight is invalid."),
409 WeightedError
::AllWeightsZero
=> write
!(f
, "All weights are zero."),
410 WeightedError
::TooMany
=> write
!(f
, "Too many weights (hit u32::MAX)"),