]>
Commit | Line | Data |
---|---|---|
0531ce1d XL |
1 | // Copyright 2017 The Rust Project Developers. See the COPYRIGHT |
2 | // file at the top-level directory of this distribution and at | |
3 | // http://rust-lang.org/COPYRIGHT. | |
4 | // | |
5 | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | |
6 | // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license | |
7 | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your | |
8 | // option. This file may not be copied, modified, or distributed | |
9 | // except according to those terms. | |
10 | ||
11 | //! Functions for randomly accessing and sampling sequences. | |
12 | ||
13 | use super::Rng; | |
14 | ||
15 | // This crate is only enabled when either std or alloc is available. | |
16 | // BTreeMap is not as fast in tests, but better than nothing. | |
17 | #[cfg(feature="std")] use std::collections::HashMap; | |
18 | #[cfg(not(feature="std"))] use alloc::btree_map::BTreeMap; | |
19 | ||
20 | #[cfg(not(feature="std"))] use alloc::Vec; | |
21 | ||
22 | /// Randomly sample `amount` elements from a finite iterator. | |
23 | /// | |
24 | /// The following can be returned: | |
25 | /// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random. | |
26 | /// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the | |
27 | /// length of `iterable` was less than `amount`. This is considered an error since exactly | |
28 | /// `amount` elements is typically expected. | |
29 | /// | |
30 | /// This implementation uses `O(len(iterable))` time and `O(amount)` memory. | |
31 | /// | |
32 | /// # Example | |
33 | /// | |
34 | /// ```rust | |
35 | /// use rand::{thread_rng, seq}; | |
36 | /// | |
37 | /// let mut rng = thread_rng(); | |
38 | /// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap(); | |
39 | /// println!("{:?}", sample); | |
40 | /// ``` | |
41 | pub fn sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>> | |
42 | where I: IntoIterator<Item=T>, | |
43 | R: Rng, | |
44 | { | |
45 | let mut iter = iterable.into_iter(); | |
46 | let mut reservoir = Vec::with_capacity(amount); | |
47 | reservoir.extend(iter.by_ref().take(amount)); | |
48 | ||
49 | // Continue unless the iterator was exhausted | |
50 | // | |
51 | // note: this prevents iterators that "restart" from causing problems. | |
52 | // If the iterator stops once, then so do we. | |
53 | if reservoir.len() == amount { | |
54 | for (i, elem) in iter.enumerate() { | |
55 | let k = rng.gen_range(0, i + 1 + amount); | |
56 | if let Some(spot) = reservoir.get_mut(k) { | |
57 | *spot = elem; | |
58 | } | |
59 | } | |
60 | Ok(reservoir) | |
61 | } else { | |
62 | // Don't hang onto extra memory. There is a corner case where | |
63 | // `amount` was much less than `len(iterable)`. | |
64 | reservoir.shrink_to_fit(); | |
65 | Err(reservoir) | |
66 | } | |
67 | } | |
68 | ||
69 | /// Randomly sample exactly `amount` values from `slice`. | |
70 | /// | |
71 | /// The values are non-repeating and in random order. | |
72 | /// | |
73 | /// This implementation uses `O(amount)` time and memory. | |
74 | /// | |
75 | /// Panics if `amount > slice.len()` | |
76 | /// | |
77 | /// # Example | |
78 | /// | |
79 | /// ```rust | |
80 | /// use rand::{thread_rng, seq}; | |
81 | /// | |
82 | /// let mut rng = thread_rng(); | |
83 | /// let values = vec![5, 6, 1, 3, 4, 6, 7]; | |
84 | /// println!("{:?}", seq::sample_slice(&mut rng, &values, 3)); | |
85 | /// ``` | |
86 | pub fn sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T> | |
87 | where R: Rng, | |
88 | T: Clone | |
89 | { | |
90 | let indices = sample_indices(rng, slice.len(), amount); | |
91 | ||
92 | let mut out = Vec::with_capacity(amount); | |
93 | out.extend(indices.iter().map(|i| slice[*i].clone())); | |
94 | out | |
95 | } | |
96 | ||
97 | /// Randomly sample exactly `amount` references from `slice`. | |
98 | /// | |
99 | /// The references are non-repeating and in random order. | |
100 | /// | |
101 | /// This implementation uses `O(amount)` time and memory. | |
102 | /// | |
103 | /// Panics if `amount > slice.len()` | |
104 | /// | |
105 | /// # Example | |
106 | /// | |
107 | /// ```rust | |
108 | /// use rand::{thread_rng, seq}; | |
109 | /// | |
110 | /// let mut rng = thread_rng(); | |
111 | /// let values = vec![5, 6, 1, 3, 4, 6, 7]; | |
112 | /// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3)); | |
113 | /// ``` | |
114 | pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> | |
115 | where R: Rng | |
116 | { | |
117 | let indices = sample_indices(rng, slice.len(), amount); | |
118 | ||
119 | let mut out = Vec::with_capacity(amount); | |
120 | out.extend(indices.iter().map(|i| &slice[*i])); | |
121 | out | |
122 | } | |
123 | ||
124 | /// Randomly sample exactly `amount` indices from `0..length`. | |
125 | /// | |
126 | /// The values are non-repeating and in random order. | |
127 | /// | |
128 | /// This implementation uses `O(amount)` time and memory. | |
129 | /// | |
130 | /// This method is used internally by the slice sampling methods, but it can sometimes be useful to | |
131 | /// have the indices themselves so this is provided as an alternative. | |
132 | /// | |
133 | /// Panics if `amount > length` | |
134 | pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> | |
135 | where R: Rng, | |
136 | { | |
137 | if amount > length { | |
138 | panic!("`amount` must be less than or equal to `slice.len()`"); | |
139 | } | |
140 | ||
141 | // We are going to have to allocate at least `amount` for the output no matter what. However, | |
142 | // if we use the `cached` version we will have to allocate `amount` as a HashMap as well since | |
143 | // it inserts an element for every loop. | |
144 | // | |
145 | // Therefore, if `amount >= length / 2` then inplace will be both faster and use less memory. | |
146 | // In fact, benchmarks show the inplace version is faster for length up to about 20 times | |
147 | // faster than amount. | |
148 | // | |
149 | // TODO: there is probably even more fine-tuning that can be done here since | |
150 | // `HashMap::with_capacity(amount)` probably allocates more than `amount` in practice, | |
151 | // and a trade off could probably be made between memory/cpu, since hashmap operations | |
152 | // are slower than array index swapping. | |
153 | if amount >= length / 20 { | |
154 | sample_indices_inplace(rng, length, amount) | |
155 | } else { | |
156 | sample_indices_cache(rng, length, amount) | |
157 | } | |
158 | } | |
159 | ||
160 | /// Sample an amount of indices using an inplace partial fisher yates method. | |
161 | /// | |
162 | /// This allocates the entire `length` of indices and randomizes only the first `amount`. | |
163 | /// It then truncates to `amount` and returns. | |
164 | /// | |
165 | /// This is better than using a HashMap "cache" when `amount >= length / 2` since it does not | |
166 | /// require allocating an extra cache and is much faster. | |
167 | fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> | |
168 | where R: Rng, | |
169 | { | |
170 | debug_assert!(amount <= length); | |
171 | let mut indices: Vec<usize> = Vec::with_capacity(length); | |
172 | indices.extend(0..length); | |
173 | for i in 0..amount { | |
174 | let j: usize = rng.gen_range(i, length); | |
175 | let tmp = indices[i]; | |
176 | indices[i] = indices[j]; | |
177 | indices[j] = tmp; | |
178 | } | |
179 | indices.truncate(amount); | |
180 | debug_assert_eq!(indices.len(), amount); | |
181 | indices | |
182 | } | |
183 | ||
184 | ||
185 | /// This method performs a partial fisher-yates on a range of indices using a HashMap | |
186 | /// as a cache to record potential collisions. | |
187 | /// | |
188 | /// The cache avoids allocating the entire `length` of values. This is especially useful when | |
189 | /// `amount <<< length`, i.e. select 3 non-repeating from 1_000_000 | |
190 | fn sample_indices_cache<R>( | |
191 | rng: &mut R, | |
192 | length: usize, | |
193 | amount: usize, | |
194 | ) -> Vec<usize> | |
195 | where R: Rng, | |
196 | { | |
197 | debug_assert!(amount <= length); | |
198 | #[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount); | |
199 | #[cfg(not(feature="std"))] let mut cache = BTreeMap::new(); | |
200 | let mut out = Vec::with_capacity(amount); | |
201 | for i in 0..amount { | |
202 | let j: usize = rng.gen_range(i, length); | |
203 | ||
204 | // equiv: let tmp = slice[i]; | |
205 | let tmp = match cache.get(&i) { | |
206 | Some(e) => *e, | |
207 | None => i, | |
208 | }; | |
209 | ||
210 | // equiv: slice[i] = slice[j]; | |
211 | let x = match cache.get(&j) { | |
212 | Some(x) => *x, | |
213 | None => j, | |
214 | }; | |
215 | ||
216 | // equiv: slice[j] = tmp; | |
217 | cache.insert(j, tmp); | |
218 | ||
219 | // note that in the inplace version, slice[i] is automatically "returned" value | |
220 | out.push(x); | |
221 | } | |
222 | debug_assert_eq!(out.len(), amount); | |
223 | out | |
224 | } | |
225 | ||
#[cfg(test)]
mod test {
    use super::*;
    use {thread_rng, XorShiftRng, SeedableRng};

    #[test]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = thread_rng();
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap();
        let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err();

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount > len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` is the half-open range `min_val..max_val`, so `max_val` itself is excluded.
        assert!(small_sample.iter().all(|e| {
            **e >= min_val && **e < max_val
        }));
    }

    #[test]
    fn test_sample_slice_boundaries() {
        let empty: &[u8] = &[];

        let mut r = thread_rng();

        // sample 0 items
        assert_eq!(sample_slice(&mut r, empty, 0), vec![]);
        assert_eq!(sample_slice(&mut r, &[42, 2, 42], 0), vec![]);

        // sample 1 item
        assert_eq!(sample_slice(&mut r, &[42], 1), vec![42]);
        let v = sample_slice(&mut r, &[1, 42], 1)[0];
        assert!(v == 1 || v == 42);

        // sample "all" the items
        let v = sample_slice(&mut r, &[42, 133], 2);
        assert!(v == vec![42, 133] || v == vec![133, 42]);

        assert_eq!(sample_indices_inplace(&mut r, 0, 0), vec![]);
        assert_eq!(sample_indices_inplace(&mut r, 1, 0), vec![]);
        assert_eq!(sample_indices_inplace(&mut r, 1, 1), vec![0]);

        assert_eq!(sample_indices_cache(&mut r, 0, 0), vec![]);
        assert_eq!(sample_indices_cache(&mut r, 1, 0), vec![]);
        assert_eq!(sample_indices_cache(&mut r, 1, 1), vec![0]);

        // Make sure lucky 777's aren't lucky
        let slice = &[42, 777];
        let mut num_42 = 0;
        let total = 1000;
        for _ in 0..total {
            let v = sample_slice(&mut r, slice, 1);
            assert_eq!(v.len(), 1);
            let v = v[0];
            assert!(v == 42 || v == 777);
            if v == 42 {
                num_42 += 1;
            }
        }
        // Each value should be picked about half the time. With `total` trials the 0.4..=0.6
        // window is more than six standard deviations wide, yet it still catches a biased sampler.
        let ratio_42 = num_42 as f64 / total as f64;
        assert!(0.4 <= ratio_42 && ratio_42 <= 0.6, "{}", ratio_42);
    }

    #[test]
    fn test_sample_slice() {
        let xor_rng = XorShiftRng::from_seed;

        let max_range = 100;
        let mut r = thread_rng();

        for length in 1usize..max_range {
            // `gen_range` excludes its upper bound; use `length + 1` so that
            // `amount == length` is exercised too.
            let amount = r.gen_range(0, length + 1);
            let seed: [u32; 4] = [
                r.next_u32(), r.next_u32(), r.next_u32(), r.next_u32()
            ];

            println!("Selecting indices: len={}, amount={}, seed={:?}", length, amount, seed);

            // assert that the two index methods give exactly the same result
            let inplace = sample_indices_inplace(
                &mut xor_rng(seed), length, amount);
            let cache = sample_indices_cache(
                &mut xor_rng(seed), length, amount);
            assert_eq!(inplace, cache);

            // assert the basics work
            let regular = sample_indices(
                &mut xor_rng(seed), length, amount);
            assert_eq!(regular.len(), amount);
            assert!(regular.iter().all(|e| *e < length));
            assert_eq!(regular, inplace);

            // also test that sampling the slice works
            let vec: Vec<usize> = (0..length).collect();
            {
                let result = sample_slice(&mut xor_rng(seed), &vec, amount);
                assert_eq!(result, regular);
            }

            {
                let result = sample_slice_ref(&mut xor_rng(seed), &vec, amount);
                let expected = regular.iter().collect::<Vec<_>>();
                assert_eq!(result, expected);
            }
        }
    }
}