1 // Copyright 2018 Developers of the Rand project.
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
9 //! Weighted index sampling
11 //! This module provides two implementations for sampling indices:
13 //! * [`WeightedIndex`] allows `O(log N)` sampling
14 //! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with
15 //! much greater set-up cost
17 //! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html
22 use crate::distributions
::Distribution
;
23 use crate::distributions
::uniform
::{UniformSampler, SampleUniform, SampleBorrow}
;
24 use core
::cmp
::PartialOrd
;
27 // Note that this whole module is only imported if feature="alloc" is enabled.
28 #[cfg(not(feature="std"))] use crate::alloc::vec::Vec;
30 /// A distribution using weighted sampling to pick a discretely selected
33 /// Sampling a `WeightedIndex` distribution returns the index of a randomly
34 /// selected element from the iterator used when the `WeightedIndex` was
35 /// created. The chance of a given element being picked is proportional to the
36 /// value of the element. The weights can use any type `X` for which an
37 /// implementation of [`Uniform<X>`] exists.
41 /// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
42 /// size is the sum of the size of those objects, possibly plus some alignment.
44 /// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
45 /// weights of type `X`, where `N` is the number of weights. However, since
46 /// `Vec` doesn't guarantee a particular growth strategy, additional memory
/// might be allocated but not used. Since the `WeightedIndex` object also
/// contains a [`Uniform<X>`], this might cause additional allocations, though
/// for primitive types, [`Uniform<X>`] doesn't allocate any memory.
51 /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
52 /// `N` is the number of weights.
54 /// Sampling from `WeightedIndex` will result in a single call to
55 /// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
56 /// will request a single value from the underlying [`RngCore`], though the
/// exact number depends on the implementation of `Uniform<X>::sample`.
62 /// use rand::prelude::*;
63 /// use rand::distributions::WeightedIndex;
65 /// let choices = ['a', 'b', 'c'];
66 /// let weights = [2, 1, 1];
67 /// let dist = WeightedIndex::new(&weights).unwrap();
68 /// let mut rng = thread_rng();
70 /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
71 /// println!("{}", choices[dist.sample(&mut rng)]);
74 /// let items = [('a', 0), ('b', 3), ('c', 7)];
75 /// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
77 /// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
78 /// println!("{}", items[dist2.sample(&mut rng)].0);
82 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
83 /// [`RngCore`]: crate::RngCore
84 #[derive(Debug, Clone)]
85 pub struct WeightedIndex
<X
: SampleUniform
+ PartialOrd
> {
86 cumulative_weights
: Vec
<X
>,
87 weight_distribution
: X
::Sampler
,
90 impl<X
: SampleUniform
+ PartialOrd
> WeightedIndex
<X
> {
91 /// Creates a new a `WeightedIndex` [`Distribution`] using the values
92 /// in `weights`. The weights can use any type `X` for which an
93 /// implementation of [`Uniform<X>`] exists.
95 /// Returns an error if the iterator is empty, if any weight is `< 0`, or
96 /// if its total value is 0.
98 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
99 pub fn new
<I
>(weights
: I
) -> Result
<WeightedIndex
<X
>, WeightedError
>
100 where I
: IntoIterator
,
101 I
::Item
: SampleBorrow
<X
>,
102 X
: for<'a
> ::core
::ops
::AddAssign
<&'a X
> +
105 let mut iter
= weights
.into_iter();
106 let mut total_weight
: X
= iter
.next()
107 .ok_or(WeightedError
::NoItem
)?
111 let zero
= <X
as Default
>::default();
112 if total_weight
< zero
{
113 return Err(WeightedError
::InvalidWeight
);
116 let mut weights
= Vec
::<X
>::with_capacity(iter
.size_hint().0);
118 if *w
.borrow() < zero
{
119 return Err(WeightedError
::InvalidWeight
);
121 weights
.push(total_weight
.clone());
122 total_weight
+= w
.borrow();
125 if total_weight
== zero
{
126 return Err(WeightedError
::AllWeightsZero
);
128 let distr
= X
::Sampler
::new(zero
, total_weight
);
130 Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr }
)
134 impl<X
> Distribution
<usize> for WeightedIndex
<X
> where
135 X
: SampleUniform
+ PartialOrd
{
136 fn sample
<R
: Rng
+ ?Sized
>(&self, rng
: &mut R
) -> usize {
137 use ::core
::cmp
::Ordering
;
138 let chosen_weight
= self.weight_distribution
.sample(rng
);
139 // Find the first item which has a weight *higher* than the chosen weight.
140 self.cumulative_weights
.binary_search_by(
141 |w
| if *w
<= chosen_weight { Ordering::Less }
else { Ordering::Greater }
).unwrap_err()
// Checks that `WeightedIndex` picks each index roughly in proportion to its
// weight (built from a `Vec`, a slice and an iterator), that weight sets with a
// single non-zero entry always yield that entry, and that `WeightedIndex::new`
// reports the expected `WeightedError` for invalid input.
// NOTE(review): this chunk looks incomplete. No `#[test]` attribute is visible,
// `verify` is never called, `chosen` is incremented only once per distribution
// (no visible loop over `N_REPS`), and `err` is declared `mut` but never
// modified. Lines appear to be missing; confirm against the full file.
150 #[cfg(not(miri))] // Miri is too slow
151 fn test_weightedindex() {
152 let mut r
= crate::test
::rng(700);
153 const N_REPS
: u32 = 5000;
154 let weights
= [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
155 let total_weight
= weights
.iter().sum
::<u32>() as f32;
// Compares each observed count against its expected value,
// `weights[i] * N_REPS / total_weight`.
157 let verify
= |result
: [i32; 14]| {
158 for (i
, count
) in result
.iter().enumerate() {
159 let exp
= (weights
[i
] * N_REPS
) as f32 / total_weight
;
160 let mut err
= (*count
as f32 - exp
).abs();
164 assert
!(err
<= 0.25);
168 // WeightedIndex from vec
169 let mut chosen
= [0i32; 14];
170 let distr
= WeightedIndex
::new(weights
.to_vec()).unwrap();
172 chosen
[distr
.sample(&mut r
)] += 1;
176 // WeightedIndex from slice
178 let distr
= WeightedIndex
::new(&weights
[..]).unwrap();
180 chosen
[distr
.sample(&mut r
)] += 1;
184 // WeightedIndex from iterator
186 let distr
= WeightedIndex
::new(weights
.iter()).unwrap();
188 chosen
[distr
.sample(&mut r
)] += 1;
// Only one index has a non-zero weight, so it must always be the one chosen.
193 assert_eq
!(WeightedIndex
::new(&[0, 1]).unwrap().sample(&mut r
), 1);
194 assert_eq
!(WeightedIndex
::new(&[1, 0]).unwrap().sample(&mut r
), 0);
195 assert_eq
!(WeightedIndex
::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r
), 4);
// Error cases: no weights, all weights zero, and a negative weight as the
// first, a middle, or the only element.
198 assert_eq
!(WeightedIndex
::new(&[10][0..0]).unwrap_err(), WeightedError
::NoItem
);
199 assert_eq
!(WeightedIndex
::new(&[0]).unwrap_err(), WeightedError
::AllWeightsZero
);
200 assert_eq
!(WeightedIndex
::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError
::InvalidWeight
);
201 assert_eq
!(WeightedIndex
::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError
::InvalidWeight
);
202 assert_eq
!(WeightedIndex
::new(&[-10]).unwrap_err(), WeightedError
::InvalidWeight
);
/// Reason why `WeightedIndex::new` could not build a distribution from the
/// weights it was given.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightedError {
    /// The provided weight collection contains no items.
    NoItem,

    /// A weight is either less than zero, greater than the supported maximum or
    /// otherwise invalid.
    InvalidWeight,

    /// All items in the provided weight collection are zero.
    AllWeightsZero,

    /// Too many weights are provided (length greater than `u32::MAX`)
    TooMany,
}
224 fn msg(&self) -> &str {
226 WeightedError
::NoItem
=> "No weights provided.",
227 WeightedError
::InvalidWeight
=> "A weight is invalid.",
228 WeightedError
::AllWeightsZero
=> "All weights are zero.",
229 WeightedError
::TooMany
=> "Too many weights (hit u32::MAX)",
#[cfg(feature="std")]
impl ::std::error::Error for WeightedError {
    fn description(&self) -> &str {
        self.msg()
    }

    // `source` supersedes the deprecated `cause`; the standard library's
    // default `cause` forwards to `source`, so callers of either get `None`.
    fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
        // A fieldless error enum carries no underlying error to report.
        None
    }
}
244 impl fmt
::Display
for WeightedError
{
245 fn fmt(&self, f
: &mut fmt
::Formatter
) -> fmt
::Result
{
246 write
!(f
, "{}", self.msg())