// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Math helper functions

#[cfg(feature="simd_support")]
use packed_simd::*;
#[cfg(feature="std")]
use distributions::ziggurat_tables;
#[cfg(feature="std")]
use Rng;


pub trait WideningMultiply<RHS = Self> {
    type Output;

    fn wmul(self, x: RHS) -> Self::Output;
}

macro_rules! wmul_impl {
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                let tmp = (self as $wide) * (x as $wide);
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    ($(($ty:ident, $wide:ident),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // supported multiply & swizzle instructions (no actual
                    // casting).
                    // TODO: optimize
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> $shift).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
wmul_impl! { u64, u128, 64 }
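
// Not part of the upstream file: a small illustrative test of the `wmul`
// contract, namely that it returns the full-width product split into the
// (high, low) halves of the original type.
#[cfg(test)]
mod wmul_scalar_example {
    use super::WideningMultiply;

    #[test]
    fn splits_product_into_halves() {
        // 0xFF * 0xFF = 0xFE01, so the high byte is 0xFE and the low byte 0x01.
        assert_eq!(0xFFu8.wmul(0xFF), (0xFE, 0x01));
        // When the product fits in the narrow type, the high half is zero.
        assert_eq!(100u32.wmul(3), (0, 300));
        // u64::MAX * u64::MAX = 2^128 - 2^65 + 1: high = u64::MAX - 1, low = 1.
        let m = u64::max_value();
        assert_eq!(m.wmul(m), (m - 1, 1));
    }
}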

// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the common method
// `(a + b) * (c + d) = ac + ad + bc + bd`.
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
macro_rules! wmul_impl_large {
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                const LOWER_MASK: $ty = !0 >> $half;
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // needs wrapping multiplication
                    const LOWER_MASK: $scalar = !0 >> $half;
                    let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
                    let mut t = low >> $half;
                    low &= LOWER_MASK;
                    t += (self >> $half) * (b & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    let mut high = t >> $half;
                    t = low >> $half;
                    low &= LOWER_MASK;
                    t += (b >> $half) * (self & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    high += t >> $half;
                    high += (self >> $half) * (b >> $half);

                    (high, low)
                }
            }
        )+
    };
}
#[cfg(not(all(rustc_1_26, not(target_os = "emscripten"))))]
wmul_impl_large! { u64, 32 }
#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
wmul_impl_large! { u128, 64 }
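
// Not part of the upstream file: an illustrative sketch of the half-word
// ("schoolbook") decomposition that `wmul_impl_large!` uses, rewritten here
// for u32 with 16-bit halves so the result can be checked against a plain
// u64 product.
#[cfg(test)]
mod wmul_large_example {
    // Accumulates the partial products a_lo*b_lo, a_hi*b_lo, a_lo*b_hi and
    // a_hi*b_hi while folding the carries between the 16-bit halves.
    fn wmul32_by_halves(a: u32, b: u32) -> (u32, u32) {
        const HALF: u32 = 16;
        const LOWER_MASK: u32 = !0 >> HALF;
        let mut low = (a & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
        let mut t = low >> HALF;
        low &= LOWER_MASK;
        t += (a >> HALF).wrapping_mul(b & LOWER_MASK);
        low += (t & LOWER_MASK) << HALF;
        let mut high = t >> HALF;
        t = low >> HALF;
        low &= LOWER_MASK;
        t += (b >> HALF).wrapping_mul(a & LOWER_MASK);
        low += (t & LOWER_MASK) << HALF;
        high += t >> HALF;
        high += (a >> HALF).wrapping_mul(b >> HALF);
        (high, low)
    }

    #[test]
    fn matches_wide_product() {
        let cases = [(0u32, 0u32), (1, 0xFFFF_FFFF),
                     (0xDEAD_BEEF, 0x1234_5678),
                     (0xFFFF_FFFF, 0xFFFF_FFFF)];
        for &(a, b) in &cases {
            let wide = (a as u64) * (b as u64);
            assert_eq!(wmul32_by_halves(a, b), ((wide >> 32) as u32, wide as u32));
        }
    }
}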

macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    }
}
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }

#[cfg(all(feature = "simd_support", feature = "nightly"))]
mod simd_wmul {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;
    use super::*;

    wmul_impl! {
        (u8x2, u16x2),
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),,
        8
    }

    wmul_impl! { (u16x2, u32x2),, 16 }
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x4, u32x4),, 16 }
    #[cfg(not(target_feature = "sse4.2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    let b = $intrinsic::from_bits(x);
                    let a = $intrinsic::from_bits(self);
                    let hi = $ty::from_bits(unsafe { $mulhi(a, b) });
                    let lo = $ty::from_bits(unsafe { $mullo(a, b) });
                    (hi, lo)
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x4, __m64, _mm_mulhi_pu16, _mm_mullo_pi16 }
    #[cfg(target_feature = "sse4.2")]
    wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::<u16x32>`
    // cannot use the same implementation.

    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),,
        32
    }

    // TODO: optimize, this seems to seriously slow things down
    wmul_impl_large! { (u8x64,) u8, 4 }
    wmul_impl_large! { (u16x32,) u16, 8 }
    wmul_impl_large! { (u32x16,) u32, 16 }
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
#[cfg(all(feature = "simd_support", feature = "nightly"))]
pub use self::simd_wmul::*;


/// Helper trait when dealing with scalar and SIMD floating point types.
pub(crate) trait FloatSIMDUtils {
    // `PartialOrd` for vectors compares lexicographically. We want to compare all
    // the individual SIMD lanes instead, and get the combined result over all
    // lanes. This is possible using something like `a.lt(b).all()`, but we
    // implement it as a trait so we can write the same code for `f32` and `f64`.
    // Only the comparison functions we need are implemented.
    fn all_lt(self, other: Self) -> bool;
    fn all_le(self, other: Self) -> bool;
    fn all_finite(self) -> bool;

    type Mask;
    fn finite_mask(self) -> Self::Mask;
    fn gt_mask(self, other: Self) -> Self::Mask;
    fn ge_mask(self, other: Self) -> Self::Mask;

    // Decrease all lanes where the mask is `true` to the next lower value
    // representable by the floating-point type. At least one of the lanes
    // must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    // Convert from int value. Conversion is done while retaining the numerical
    // value, not by retaining the binary representation.
    type UInt;
    fn cast_from_int(i: Self::UInt) -> Self;
}

/// Implement functions available in std builds but missing from core primitives
#[cfg(not(std))]
pub(crate) trait Float : Sized {
    type Bits;

    fn is_nan(self) -> bool;
    fn is_infinite(self) -> bool;
    fn is_finite(self) -> bool;
    fn to_bits(self) -> Self::Bits;
    fn from_bits(v: Self::Bits) -> Self;
}

/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD : Sized {
    #[inline(always)]
    fn lanes() -> usize { 1 }
    #[inline(always)]
    fn splat(scalar: Self) -> Self { scalar }
    #[inline(always)]
    fn extract(self, index: usize) -> Self { debug_assert_eq!(index, 0); self }
    #[inline(always)]
    fn replace(self, index: usize, new_value: Self) -> Self { debug_assert_eq!(index, 0); new_value }
}

pub(crate) trait BoolAsSIMD : Sized {
    fn any(self) -> bool;
    fn all(self) -> bool;
    fn none(self) -> bool;
}

impl BoolAsSIMD for bool {
    #[inline(always)]
    fn any(self) -> bool { self }
    #[inline(always)]
    fn all(self) -> bool { self }
    #[inline(always)]
    fn none(self) -> bool { !self }
}

macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        #[cfg(not(std))]
        impl Float for $ty {
            type Bits = $uty;

            #[inline]
            fn is_nan(self) -> bool {
                self != self
            }

            #[inline]
            fn is_infinite(self) -> bool {
                self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY
            }

            #[inline]
            fn is_finite(self) -> bool {
                !(self.is_nan() || self.is_infinite())
            }

            #[inline]
            fn to_bits(self) -> Self::Bits {
                unsafe { ::core::mem::transmute(self) }
            }

            #[inline]
            fn from_bits(v: Self::Bits) -> Self {
                // It turns out the safety issues with sNaN were overblown! Hooray!
                unsafe { ::core::mem::transmute(v) }
            }
        }

        impl FloatSIMDUtils for $ty {
            type Mask = bool;
            #[inline(always)]
            fn all_lt(self, other: Self) -> bool { self < other }
            #[inline(always)]
            fn all_le(self, other: Self) -> bool { self <= other }
            #[inline(always)]
            fn all_finite(self) -> bool { self.is_finite() }
            #[inline(always)]
            fn finite_mask(self) -> Self::Mask { self.is_finite() }
            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask { self > other }
            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask { self >= other }
            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask, "At least one lane must be set");
                <$ty>::from_bits(self.to_bits() - 1)
            }
            type UInt = $uty;
            fn cast_from_int(i: Self::UInt) -> Self { i as $ty }
        }

        impl FloatAsSIMD for $ty {}
    }
}

scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64);
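
// Not part of the upstream file: a small check of the scalar `decrease_masked`
// bit trick, which steps to the next representable value below the input by
// subtracting one from the binary representation.
#[cfg(test)]
mod decrease_masked_example {
    use super::FloatSIMDUtils;

    #[test]
    fn steps_down_to_next_lower_value() {
        // The largest f64 below 1.0 is 1.0 - EPSILON / 2 (one ulp below 1.0).
        let below_one = 1.0f64.decrease_masked(true);
        assert!(below_one < 1.0);
        assert_eq!(below_one, 1.0 - ::core::f64::EPSILON / 2.0);
    }
}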


#[cfg(feature="simd_support")]
macro_rules! simd_impl {
    ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = $mty;
            #[inline(always)]
            fn all_lt(self, other: Self) -> bool { self.lt(other).all() }
            #[inline(always)]
            fn all_le(self, other: Self) -> bool { self.le(other).all() }
            #[inline(always)]
            fn all_finite(self) -> bool { self.finite_mask().all() }
            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                // This can possibly be done faster by checking bit patterns
                let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY);
                let pos_inf = $ty::splat(::core::$f_scalar::INFINITY);
                self.gt(neg_inf) & self.lt(pos_inf)
            }
            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask { self.gt(other) }
            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask { self.ge(other) }
            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $ty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask))
            }
            type UInt = $uty;
            fn cast_from_int(i: Self::UInt) -> Self { i.cast() }
        }
    }
}

#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 }
#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 }
#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 }
#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 }
#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 }
#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 }
#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 }

/// Calculates ln(gamma(x)) (natural logarithm of the gamma
/// function) using the Lanczos approximation.
///
/// The approximation expresses the gamma function as:
/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)`
/// `g` is an arbitrary constant; we use the approximation with `g=5`.
///
/// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides:
/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)`
///
/// `Ag(z)` is an infinite series with coefficients that can be calculated
/// ahead of time - we use just the first 6 terms, which is good enough
/// for most purposes.
#[cfg(feature="std")]
pub fn log_gamma(x: f64) -> f64 {
    // precalculated 6 coefficients for the first 6 terms of the series
    let coefficients: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    // (x+0.5)*ln(x+g+0.5)-(x+g+0.5)
    let tmp = x + 5.5;
    let log = (x + 0.5) * tmp.ln() - tmp;

    // the first few terms of the series for Ag(x)
    let mut a = 1.000000000190015;
    let mut denom = x;
    for coeff in &coefficients {
        denom += 1.0;
        a += coeff / denom;
    }

    // get everything together
    // a is Ag(x)
    // 2.5066... is sqrt(2pi)
    log + (2.5066282746310005 * a / x).ln()
}
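
// Not part of the upstream file: a sanity check of the Lanczos approximation
// above, using the identity gamma(n) = (n - 1)! for small positive integers.
#[cfg(all(test, feature = "std"))]
mod log_gamma_example {
    use super::log_gamma;

    #[test]
    fn matches_small_factorials() {
        // (n, (n - 1)!)
        let cases = [(2.0, 1.0), (3.0, 2.0), (5.0, 24.0), (8.0, 5040.0)];
        for &(n, factorial) in &cases {
            assert!((log_gamma(n) - factorial.ln()).abs() < 1e-7);
        }
    }
}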

/// Sample a random number using the Ziggurat method (specifically the
/// ZIGNOR variant from Doornik 2005). Most of the arguments are
/// directly from the paper:
///
/// * `rng`: source of randomness
/// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0.
/// * `X`: the $x_i$ abscissae.
/// * `F`: precomputed values of the PDF at the $x_i$ (i.e. $f(x_i)$)
/// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$
/// * `pdf`: the probability density function
/// * `zero_case`: manual sampling from the tail when we choose the
///   bottom box (i.e. i == 0)

// the perf improvement (25-50%) is definitely worth the extra code
// size from force-inlining.
#[cfg(feature="std")]
#[inline(always)]
pub fn ziggurat<R: Rng + ?Sized, P, Z>(
            rng: &mut R,
            symmetric: bool,
            x_tab: ziggurat_tables::ZigTable,
            f_tab: ziggurat_tables::ZigTable,
            mut pdf: P,
            mut zero_case: Z)
            -> f64 where P: FnMut(f64) -> f64, Z: FnMut(&mut R, f64) -> f64 {
    use distributions::float::IntoFloat;
    loop {
        // As an optimisation we re-implement the conversion to an f64.
        // From the remaining 12 most significant bits we use 8 to construct `i`.
        // This saves us generating a whole extra random number, while the added
        // precision of using 64 bits for f64 does not buy us much.
        let bits = rng.next_u64();
        let i = bits as usize & 0xff;

        let u = if symmetric {
            // Convert to a value in the range [2,4) and subtract to get [-1,1)
            // We can't convert to an open range directly, that would require
            // subtracting `3.0 - EPSILON`, which is not representable.
            // It is possible with an extra step, but an open range does not
            // seem necessary for the ziggurat algorithm anyway.
            (bits >> 12).into_float_with_exponent(1) - 3.0
        } else {
            // Convert to a value in the range [1,2) and subtract to get (0,1)
            (bits >> 12).into_float_with_exponent(0)
            - (1.0 - ::core::f64::EPSILON / 2.0)
        };
        let x = u * x_tab[i];

        let test_x = if symmetric { x.abs() } else { x };

        // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
        if test_x < x_tab[i + 1] {
            return x;
        }
        if i == 0 {
            return zero_case(rng, u);
        }
        // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
        if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) {
            return x;
        }
    }
}
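
// Not part of the upstream file: a sketch of the bit-level float conversion
// the loop above relies on (assuming `into_float_with_exponent(0)` fills the
// 52 mantissa bits under a fixed exponent, as described in
// `distributions::float`). Placing random bits in the mantissa with the
// exponent of 1.0 yields a value in [1, 2); subtracting `1.0 - EPSILON / 2`
// then shifts it into (0, 1).
#[cfg(test)]
mod ziggurat_bits_example {
    #[test]
    fn mantissa_fill_gives_unit_range() {
        for &bits in &[0u64, 1, 0xDEAD_BEEF_DEAD_BEEF, !0u64] {
            // Exponent 0 (biased to 1023) with the top 52 random bits as mantissa.
            let f = f64::from_bits((1023u64 << 52) | (bits >> 12));
            assert!(1.0 <= f && f < 2.0);
            let u = f - (1.0 - ::core::f64::EPSILON / 2.0);
            assert!(0.0 < u && u < 1.0);
        }
    }
}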