]>
Commit | Line | Data |
---|---|---|
0731742a XL |
1 | // Copyright 2018 Developers of the Rand project. |
2 | // | |
3 | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | |
4 | // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | |
5 | // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | |
6 | // option. This file may not be copied, modified, or distributed | |
7 | // except according to those terms. | |
8 | ||
416331ca | 9 | //! Low-level API for sampling indices |
0731742a | 10 | |
dfeec247 | 11 | #[cfg(feature = "alloc")] use core::slice; |
0731742a | 12 | |
dfeec247 XL |
13 | #[cfg(all(feature = "alloc", not(feature = "std")))] |
14 | use crate::alloc::vec::{self, Vec}; | |
15 | #[cfg(feature = "std")] use std::vec; | |
0731742a | 16 | // BTreeSet (the no-std fallback) is not as fast in tests, but better than nothing. |
dfeec247 XL |
17 | #[cfg(all(feature = "alloc", not(feature = "std")))] |
18 | use crate::alloc::collections::BTreeSet; | |
19 | #[cfg(feature = "std")] use std::collections::HashSet; | |
0731742a | 20 | |
dfeec247 XL |
21 | #[cfg(feature = "alloc")] |
22 | use crate::distributions::{uniform::SampleUniform, Distribution, Uniform}; | |
416331ca | 23 | use crate::Rng; |
0731742a XL |
24 | |
/// A vector of sampled indices.
///
/// Several internal representations exist: indices are held either as `u32`
/// (which `sample` produces whenever `length` fits in a `u32`) or as `usize`.
/// The accessors on this type always hand indices out as `usize`.
#[derive(Debug, Clone)]
pub enum IndexVec {
    #[doc(hidden)]
    U32(Vec<u32>),
    #[doc(hidden)]
    USize(Vec<usize>),
}
35 | ||
36 | impl IndexVec { | |
37 | /// Returns the number of indices | |
dfeec247 | 38 | #[inline] |
0731742a | 39 | pub fn len(&self) -> usize { |
dfeec247 XL |
40 | match *self { |
41 | IndexVec::U32(ref v) => v.len(), | |
42 | IndexVec::USize(ref v) => v.len(), | |
43 | } | |
44 | } | |
45 | ||
46 | /// Returns `true` if the length is 0. | |
47 | #[inline] | |
48 | pub fn is_empty(&self) -> bool { | |
49 | match *self { | |
50 | IndexVec::U32(ref v) => v.is_empty(), | |
51 | IndexVec::USize(ref v) => v.is_empty(), | |
0731742a XL |
52 | } |
53 | } | |
54 | ||
55 | /// Return the value at the given `index`. | |
56 | /// | |
416331ca | 57 | /// (Note: we cannot implement [`std::ops::Index`] because of lifetime |
0731742a | 58 | /// restrictions.) |
dfeec247 | 59 | #[inline] |
0731742a | 60 | pub fn index(&self, index: usize) -> usize { |
dfeec247 XL |
61 | match *self { |
62 | IndexVec::U32(ref v) => v[index] as usize, | |
63 | IndexVec::USize(ref v) => v[index], | |
0731742a XL |
64 | } |
65 | } | |
66 | ||
67 | /// Return result as a `Vec<usize>`. Conversion may or may not be trivial. | |
dfeec247 | 68 | #[inline] |
0731742a XL |
69 | pub fn into_vec(self) -> Vec<usize> { |
70 | match self { | |
71 | IndexVec::U32(v) => v.into_iter().map(|i| i as usize).collect(), | |
72 | IndexVec::USize(v) => v, | |
73 | } | |
74 | } | |
75 | ||
76 | /// Iterate over the indices as a sequence of `usize` values | |
dfeec247 XL |
77 | #[inline] |
78 | pub fn iter(&self) -> IndexVecIter<'_> { | |
79 | match *self { | |
80 | IndexVec::U32(ref v) => IndexVecIter::U32(v.iter()), | |
81 | IndexVec::USize(ref v) => IndexVecIter::USize(v.iter()), | |
0731742a XL |
82 | } |
83 | } | |
84 | ||
85 | /// Convert into an iterator over the indices as a sequence of `usize` values | |
dfeec247 | 86 | #[inline] |
0731742a XL |
87 | pub fn into_iter(self) -> IndexVecIntoIter { |
88 | match self { | |
89 | IndexVec::U32(v) => IndexVecIntoIter::U32(v.into_iter()), | |
90 | IndexVec::USize(v) => IndexVecIntoIter::USize(v.into_iter()), | |
91 | } | |
92 | } | |
93 | } | |
94 | ||
95 | impl PartialEq for IndexVec { | |
96 | fn eq(&self, other: &IndexVec) -> bool { | |
97 | use self::IndexVec::*; | |
98 | match (self, other) { | |
99 | (&U32(ref v1), &U32(ref v2)) => v1 == v2, | |
100 | (&USize(ref v1), &USize(ref v2)) => v1 == v2, | |
dfeec247 XL |
101 | (&U32(ref v1), &USize(ref v2)) => { |
102 | (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as usize == *y)) | |
103 | } | |
104 | (&USize(ref v1), &U32(ref v2)) => { | |
105 | (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as usize)) | |
106 | } | |
0731742a XL |
107 | } |
108 | } | |
109 | } | |
110 | ||
111 | impl From<Vec<u32>> for IndexVec { | |
dfeec247 | 112 | #[inline] |
0731742a XL |
113 | fn from(v: Vec<u32>) -> Self { |
114 | IndexVec::U32(v) | |
115 | } | |
116 | } | |
117 | ||
118 | impl From<Vec<usize>> for IndexVec { | |
dfeec247 | 119 | #[inline] |
0731742a XL |
120 | fn from(v: Vec<usize>) -> Self { |
121 | IndexVec::USize(v) | |
122 | } | |
123 | } | |
124 | ||
/// Borrowing iterator over an `IndexVec`, yielding every index as `usize`.
///
/// This is the type returned by `IndexVec::iter`.
#[derive(Debug)]
pub enum IndexVecIter<'a> {
    #[doc(hidden)]
    U32(slice::Iter<'a, u32>),
    #[doc(hidden)]
    USize(slice::Iter<'a, usize>),
}
133 | ||
134 | impl<'a> Iterator for IndexVecIter<'a> { | |
135 | type Item = usize; | |
dfeec247 XL |
136 | |
137 | #[inline] | |
0731742a XL |
138 | fn next(&mut self) -> Option<usize> { |
139 | use self::IndexVecIter::*; | |
dfeec247 XL |
140 | match *self { |
141 | U32(ref mut iter) => iter.next().map(|i| *i as usize), | |
142 | USize(ref mut iter) => iter.next().cloned(), | |
0731742a XL |
143 | } |
144 | } | |
145 | ||
dfeec247 | 146 | #[inline] |
0731742a | 147 | fn size_hint(&self) -> (usize, Option<usize>) { |
dfeec247 XL |
148 | match *self { |
149 | IndexVecIter::U32(ref v) => v.size_hint(), | |
150 | IndexVecIter::USize(ref v) => v.size_hint(), | |
0731742a XL |
151 | } |
152 | } | |
153 | } | |
154 | ||
155 | impl<'a> ExactSizeIterator for IndexVecIter<'a> {} | |
156 | ||
/// Owning iterator over an `IndexVec`, yielding every index as `usize`.
///
/// This is the type returned by `IndexVec::into_iter`.
#[derive(Debug, Clone)]
pub enum IndexVecIntoIter {
    #[doc(hidden)]
    U32(vec::IntoIter<u32>),
    #[doc(hidden)]
    USize(vec::IntoIter<usize>),
}
165 | ||
166 | impl Iterator for IndexVecIntoIter { | |
167 | type Item = usize; | |
168 | ||
dfeec247 | 169 | #[inline] |
0731742a XL |
170 | fn next(&mut self) -> Option<Self::Item> { |
171 | use self::IndexVecIntoIter::*; | |
dfeec247 XL |
172 | match *self { |
173 | U32(ref mut v) => v.next().map(|i| i as usize), | |
174 | USize(ref mut v) => v.next(), | |
0731742a XL |
175 | } |
176 | } | |
177 | ||
dfeec247 | 178 | #[inline] |
0731742a XL |
179 | fn size_hint(&self) -> (usize, Option<usize>) { |
180 | use self::IndexVecIntoIter::*; | |
dfeec247 XL |
181 | match *self { |
182 | U32(ref v) => v.size_hint(), | |
183 | USize(ref v) => v.size_hint(), | |
0731742a XL |
184 | } |
185 | } | |
186 | } | |
187 | ||
188 | impl ExactSizeIterator for IndexVecIntoIter {} | |
189 | ||
190 | ||
191 | /// Randomly sample exactly `amount` distinct indices from `0..length`, and | |
192 | /// return them in random order (fully shuffled). | |
193 | /// | |
194 | /// This method is used internally by the slice sampling methods, but it can | |
195 | /// sometimes be useful to have the indices themselves so this is provided as | |
196 | /// an alternative. | |
197 | /// | |
198 | /// The implementation used is not specified; we automatically select the | |
199 | /// fastest available algorithm for the `length` and `amount` parameters | |
200 | /// (based on detailed profiling on an Intel Haswell CPU). Roughly speaking, | |
201 | /// complexity is `O(amount)`, except that when `amount` is small, performance | |
202 | /// is closer to `O(amount^2)`, and when `length` is close to `amount` then | |
203 | /// `O(length)`. | |
204 | /// | |
205 | /// Note that performance is significantly better over `u32` indices than over | |
206 | /// `u64` indices. Because of this we hide the underlying type behind an | |
207 | /// abstraction, `IndexVec`. | |
416331ca | 208 | /// |
0731742a XL |
209 | /// If an allocation-free `no_std` function is required, it is suggested |
210 | /// to adapt the internal `sample_floyd` implementation. | |
211 | /// | |
212 | /// Panics if `amount > length`. | |
213 | pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec | |
416331ca | 214 | where R: Rng + ?Sized { |
0731742a XL |
215 | if amount > length { |
216 | panic!("`amount` of samples must be less than or equal to `length`"); | |
217 | } | |
218 | if length > (::core::u32::MAX as usize) { | |
219 | // We never want to use inplace here, but could use floyd's alg | |
220 | // Lazy version: always use the cache alg. | |
221 | return sample_rejection(rng, length, amount); | |
222 | } | |
223 | let amount = amount as u32; | |
224 | let length = length as u32; | |
225 | ||
226 | // Choice of algorithm here depends on both length and amount. See: | |
227 | // https://github.com/rust-random/rand/pull/479 | |
228 | // We do some calculations with f32. Accuracy is not very important. | |
229 | ||
230 | if amount < 163 { | |
dfeec247 | 231 | const C: [[f32; 2]; 2] = [[1.6, 8.0 / 45.0], [10.0, 70.0 / 9.0]]; |
0731742a XL |
232 | let j = if length < 500_000 { 0 } else { 1 }; |
233 | let amount_fp = amount as f32; | |
234 | let m4 = C[0][j] * amount_fp; | |
235 | // Short-cut: when amount < 12, floyd's is always faster | |
236 | if amount > 11 && (length as f32) < (C[1][j] + m4) * amount_fp { | |
237 | sample_inplace(rng, length, amount) | |
238 | } else { | |
239 | sample_floyd(rng, length, amount) | |
240 | } | |
241 | } else { | |
dfeec247 | 242 | const C: [f32; 2] = [270.0, 330.0 / 9.0]; |
0731742a XL |
243 | let j = if length < 500_000 { 0 } else { 1 }; |
244 | if (length as f32) < C[j] * (amount as f32) { | |
245 | sample_inplace(rng, length, amount) | |
246 | } else { | |
416331ca | 247 | sample_rejection(rng, length, amount) |
0731742a XL |
248 | } |
249 | } | |
250 | } | |
251 | ||
252 | /// Randomly sample exactly `amount` indices from `0..length`, using Floyd's | |
253 | /// combination algorithm. | |
254 | /// | |
255 | /// The output values are fully shuffled. (Overhead is under 50%.) | |
256 | /// | |
257 | /// This implementation uses `O(amount)` memory and `O(amount^2)` time. | |
258 | fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec | |
416331ca | 259 | where R: Rng + ?Sized { |
0731742a XL |
260 | // For small amount we use Floyd's fully-shuffled variant. For larger |
261 | // amounts this is slow due to Vec::insert performance, so we shuffle | |
262 | // afterwards. Benchmarks show little overhead from extra logic. | |
263 | let floyd_shuffle = amount < 50; | |
264 | ||
265 | debug_assert!(amount <= length); | |
266 | let mut indices = Vec::with_capacity(amount as usize); | |
dfeec247 | 267 | for j in length - amount..length { |
0731742a XL |
268 | let t = rng.gen_range(0, j + 1); |
269 | if floyd_shuffle { | |
270 | if let Some(pos) = indices.iter().position(|&x| x == t) { | |
271 | indices.insert(pos, j); | |
272 | continue; | |
273 | } | |
dfeec247 XL |
274 | } else if indices.contains(&t) { |
275 | indices.push(j); | |
276 | continue; | |
0731742a XL |
277 | } |
278 | indices.push(t); | |
279 | } | |
280 | if !floyd_shuffle { | |
281 | // Reimplement SliceRandom::shuffle with smaller indices | |
282 | for i in (1..amount).rev() { | |
283 | // invariant: elements with index > i have been locked in place. | |
284 | indices.swap(i as usize, rng.gen_range(0, i + 1) as usize); | |
285 | } | |
286 | } | |
287 | IndexVec::from(indices) | |
288 | } | |
289 | ||
290 | /// Randomly sample exactly `amount` indices from `0..length`, using an inplace | |
291 | /// partial Fisher-Yates method. | |
292 | /// Sample an amount of indices using an inplace partial fisher yates method. | |
293 | /// | |
294 | /// This allocates the entire `length` of indices and randomizes only the first `amount`. | |
295 | /// It then truncates to `amount` and returns. | |
296 | /// | |
297 | /// This method is not appropriate for large `length` and potentially uses a lot | |
298 | /// of memory; because of this we only implement for `u32` index (which improves | |
299 | /// performance in all cases). | |
300 | /// | |
301 | /// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time. | |
302 | fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec | |
416331ca | 303 | where R: Rng + ?Sized { |
0731742a XL |
304 | debug_assert!(amount <= length); |
305 | let mut indices: Vec<u32> = Vec::with_capacity(length as usize); | |
306 | indices.extend(0..length); | |
307 | for i in 0..amount { | |
308 | let j: u32 = rng.gen_range(i, length); | |
309 | indices.swap(i as usize, j as usize); | |
310 | } | |
311 | indices.truncate(amount as usize); | |
312 | debug_assert_eq!(indices.len(), amount as usize); | |
313 | IndexVec::from(indices) | |
314 | } | |
315 | ||
416331ca XL |
316 | trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash { |
317 | fn zero() -> Self; | |
318 | fn as_usize(self) -> usize; | |
319 | } | |
320 | impl UInt for u32 { | |
dfeec247 XL |
321 | #[inline] |
322 | fn zero() -> Self { | |
323 | 0 | |
324 | } | |
325 | ||
326 | #[inline] | |
327 | fn as_usize(self) -> usize { | |
328 | self as usize | |
329 | } | |
416331ca XL |
330 | } |
331 | impl UInt for usize { | |
dfeec247 XL |
332 | #[inline] |
333 | fn zero() -> Self { | |
334 | 0 | |
335 | } | |
336 | ||
337 | #[inline] | |
338 | fn as_usize(self) -> usize { | |
339 | self | |
340 | } | |
416331ca XL |
341 | } |
342 | ||
0731742a XL |
343 | /// Randomly sample exactly `amount` indices from `0..length`, using rejection |
344 | /// sampling. | |
416331ca | 345 | /// |
0731742a XL |
346 | /// Since `amount <<< length` there is a low chance of a random sample in |
347 | /// `0..length` being a duplicate. We test for duplicates and resample where | |
348 | /// necessary. The algorithm is `O(amount)` time and memory. | |
dfeec247 | 349 | /// |
416331ca XL |
350 | /// This function is generic over X primarily so that results are value-stable |
351 | /// over 32-bit and 64-bit platforms. | |
352 | fn sample_rejection<X: UInt, R>(rng: &mut R, length: X, amount: X) -> IndexVec | |
dfeec247 XL |
353 | where |
354 | R: Rng + ?Sized, | |
355 | IndexVec: From<Vec<X>>, | |
356 | { | |
0731742a | 357 | debug_assert!(amount < length); |
dfeec247 XL |
358 | #[cfg(feature = "std")] |
359 | let mut cache = HashSet::with_capacity(amount.as_usize()); | |
360 | #[cfg(not(feature = "std"))] | |
361 | let mut cache = BTreeSet::new(); | |
416331ca XL |
362 | let distr = Uniform::new(X::zero(), length); |
363 | let mut indices = Vec::with_capacity(amount.as_usize()); | |
364 | for _ in 0..amount.as_usize() { | |
0731742a XL |
365 | let mut pos = distr.sample(rng); |
366 | while !cache.insert(pos) { | |
367 | pos = distr.sample(rng); | |
368 | } | |
369 | indices.push(pos); | |
370 | } | |
371 | ||
416331ca | 372 | debug_assert_eq!(indices.len(), amount.as_usize()); |
0731742a XL |
373 | IndexVec::from(indices) |
374 | } | |
375 | ||
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use crate::alloc::vec;
    #[cfg(feature = "std")]
    use std::vec;

    /// Degenerate sizes for each algorithm, plus a loose sanity check on the
    /// sum of ten indices drawn from a large range.
    #[test]
    fn test_sample_boundaries() {
        // Every call below draws from the same RNG, so the call order matters.
        let mut rng = crate::test::rng(404);

        assert_eq!(sample_inplace(&mut rng, 0, 0).len(), 0);
        assert_eq!(sample_inplace(&mut rng, 1, 0).len(), 0);
        assert_eq!(sample_inplace(&mut rng, 1, 1).into_vec(), vec![0]);

        assert_eq!(sample_rejection(&mut rng, 1u32, 0).len(), 0);

        assert_eq!(sample_floyd(&mut rng, 0, 0).len(), 0);
        assert_eq!(sample_floyd(&mut rng, 1, 0).len(), 0);
        assert_eq!(sample_floyd(&mut rng, 1, 1).into_vec(), vec![0]);

        // These algorithms should be fast even for a huge range. Check that
        // the total of ten samples lands in a plausible window.
        let range: usize = 1 << 25;
        let total: usize = sample_rejection(&mut rng, 1 << 25, 10u32).into_iter().sum();
        assert!(range < total && total < range * 25);

        let total: usize = sample_floyd(&mut rng, 1 << 25, 10).into_iter().sum();
        assert!(range < total && total < range * 25);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_sample_alg() {
        let seed_rng = crate::test::rng;

        // The algorithm `sample` picks cannot be observed directly, so each
        // case is compared against an internal routine run from the same
        // seed. Floyd's output differs from the others. `inplace` and
        // `cached` currently use different sizes, so they differ too.

        // Short range, comparatively many samples: expect inplace.
        let (length, amount): (usize, usize) = (100, 50);
        let picked = sample(&mut seed_rng(420), length, amount);
        let expected = sample_inplace(&mut seed_rng(420), length as u32, amount as u32);
        assert!(picked.iter().all(|e| e < length));
        assert_eq!(picked, expected);

        // Floyd's algorithm really does give a different result.
        let floyd = sample_floyd(&mut seed_rng(420), length as u32, amount as u32);
        assert!(picked != floyd);

        // Long range, few samples: expect Floyd.
        let (length, amount): (usize, usize) = (1 << 20, 50);
        let picked = sample(&mut seed_rng(421), length, amount);
        let expected = sample_floyd(&mut seed_rng(421), length as u32, amount as u32);
        assert!(picked.iter().all(|e| e < length));
        assert_eq!(picked, expected);

        // Long range, more samples: expect the cache (rejection) method.
        let (length, amount): (usize, usize) = (1 << 20, 600);
        let picked = sample(&mut seed_rng(422), length, amount);
        let expected = sample_rejection(&mut seed_rng(422), length as u32, amount as u32);
        assert!(picked.iter().all(|e| e < length));
        assert_eq!(picked, expected);
    }
}