]> git.proxmox.com Git - rustc.git/blame - vendor/rand/src/distributions/utils.rs
New upstream version 1.51.0+dfsg1
[rustc.git] / vendor / rand / src / distributions / utils.rs
CommitLineData
0731742a
XL
1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Math helper functions
10
dfeec247
XL
11#[cfg(feature = "std")] use crate::distributions::ziggurat_tables;
12#[cfg(feature = "std")] use crate::Rng;
13#[cfg(feature = "simd_support")] use packed_simd::*;
0731742a
XL
14
15
/// Widening multiplication: multiplies two values of the same width and
/// returns the full double-width product as a `(high, low)` pair of the
/// original type.
pub trait WideningMultiply<RHS = Self> {
    type Output;

    fn wmul(self, x: RHS) -> Self::Output;
}
21
/// Implements `WideningMultiply` by computing in a type twice as wide and
/// splitting the product.
///
/// * scalar arm: `($ty, $wide, $shift)` — multiply as `$wide`, shift by
///   `$shift` bits for the high half.
/// * SIMD bulk arm: a list of `($ty, $wide)` vector pairs followed by the
///   shared `$shift` (note the double comma separating list and shift).
macro_rules! wmul_impl {
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                // Full product fits in $wide; the truncating `as $ty` casts
                // pick out the high and low halves.
                let tmp = (self as $wide) * (x as $wide);
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    ($(($ty:ident, $wide:ident),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // supported multiply & swizzle instructions (no actual
                    // casting).
                    // TODO: optimize
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> $shift).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
// Scalar widening multiplies via the next-larger integer type.
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
// On Emscripten the u128-based path is not used; `u64::wmul` is provided by
// `wmul_impl_large!` below instead (see the matching cfg there).
#[cfg(not(target_os = "emscripten"))]
wmul_impl! { u64, u128, 64 }
63
// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the common method
// `(a + b) * (c + d) = ac + ad + bc + bd`.
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
/// Implements `WideningMultiply` for types with no wider counterpart by
/// splitting each operand into `$half`-bit halves and combining the four
/// partial products with carry propagation.
macro_rules! wmul_impl_large {
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                // Mask selecting the low $half bits of a value.
                const LOWER_MASK: $ty = !0 >> $half;
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // needs wrapping multiplication
                    // (packed_simd's `*` wraps, unlike scalar `*` in debug builds)
                    const LOWER_MASK: $scalar = !0 >> $half;
                    let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
                    let mut t = low >> $half;
                    low &= LOWER_MASK;
                    t += (self >> $half) * (b & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    let mut high = t >> $half;
                    t = low >> $half;
                    low &= LOWER_MASK;
                    t += (b >> $half) * (self & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    high += t >> $half;
                    high += (self >> $half) * (b >> $half);

                    (high, low)
                }
            }
        )+
    };
}
// On Emscripten `u64` cannot use the u128 fast path defined above, so use the
// half-word decomposition; everywhere else only `u128` needs it.
#[cfg(target_os = "emscripten")]
wmul_impl_large! { u64, 32 }
#[cfg(not(target_os = "emscripten"))]
wmul_impl_large! { u128, 64 }
129
/// Implements `WideningMultiply` for `usize` by delegating to the
/// fixed-width unsigned type of the same pointer width.
macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                // Same-width cast is lossless; reuse the concrete impl.
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    };
}
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }
147
/// SIMD `WideningMultiply` implementations (nightly + `simd_support` only).
/// Prefers x86 `mulhi`/`mullo` intrinsics for 16-bit lanes where the target
/// features allow, and falls back to cast-based or split-word strategies.
#[cfg(all(feature = "simd_support", feature = "nightly"))]
mod simd_wmul {
    use super::*;
    #[cfg(target_arch = "x86")] use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*;

    // 8-bit lanes: widen to 16-bit vectors.
    wmul_impl! {
        (u8x2, u16x2),
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),,
        8
    }

    // 16-bit lanes: use the generic widening impl only where no intrinsic
    // version (defined below) is available for the target features.
    wmul_impl! { (u16x2, u32x2),, 16 }
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x4, u32x4),, 16 }
    #[cfg(not(target_feature = "sse4.2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // Reinterpret (not convert) the vectors as the intrinsic
                    // type, then take high and low product halves separately.
                    let b = $intrinsic::from_bits(x);
                    let a = $intrinsic::from_bits(self);
                    let hi = $ty::from_bits(unsafe { $mulhi(a, b) });
                    let lo = $ty::from_bits(unsafe { $mullo(a, b) });
                    (hi, lo)
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x4, __m64, _mm_mulhi_pu16, _mm_mullo_pi16 }
    #[cfg(target_feature = "sse4.2")]
    wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::<u16x32>`
    // cannot use the same implementation.

    // 32-bit lanes: widen to 64-bit vectors.
    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),,
        32
    }

    // TODO: optimize, this seems to seriously slow things down
    // Vectors with no wider counterpart use the split-word algorithm.
    wmul_impl_large! { (u8x64,) u8, 4 }
    wmul_impl_large! { (u16x32,) u16, 8 }
    wmul_impl_large! { (u32x16,) u32, 16 }
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
#[cfg(all(feature = "simd_support", feature = "nightly"))]
pub use self::simd_wmul::*;
215
216
/// Helper trait when dealing with scalar and SIMD floating point types.
pub(crate) trait FloatSIMDUtils {
    // `PartialOrd` for vectors compares lexicographically. We want to compare all
    // the individual SIMD lanes instead, and get the combined result over all
    // lanes. This is possible using something like `a.lt(b).all()`, but we
    // implement it as a trait so we can write the same code for `f32` and `f64`.
    // Only the comparison functions we need are implemented.

    /// True iff every lane of `self` is strictly less than the matching lane of `other`.
    fn all_lt(self, other: Self) -> bool;
    /// True iff every lane of `self` is less than or equal to the matching lane of `other`.
    fn all_le(self, other: Self) -> bool;
    /// True iff every lane is finite.
    fn all_finite(self) -> bool;

    /// Per-lane boolean result type: `bool` for scalars, a SIMD mask for vectors.
    type Mask;
    fn finite_mask(self) -> Self::Mask;
    fn gt_mask(self, other: Self) -> Self::Mask;
    fn ge_mask(self, other: Self) -> Self::Mask;

    // Decrease all lanes where the mask is `true` to the next lower value
    // representable by the floating-point type. At least one of the lanes
    // must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    // Convert from int value. Conversion is done while retaining the numerical
    // value, not by retaining the binary representation.
    /// Unsigned integer type (or vector) with the same width/lane layout as `Self`.
    type UInt;
    fn cast_from_int(i: Self::UInt) -> Self;
}
243
/// Implement functions available in std builds but missing from core primitives
// NOTE(review): `cfg(not(std))` tests a compiler `--cfg std` flag, not the
// crate's `"std"` *feature*; since `--cfg std` is normally never passed, this
// trait is defined in all builds in practice. Presumably intentional (or a
// benign upstream quirk) — confirm before changing to
// `cfg(not(feature = "std"))`, and note `scalar_float_impl!` below uses the
// identical gate and must stay in sync.
#[cfg(not(std))]
pub(crate) trait Float: Sized {
    /// IEEE-754 NaN test.
    fn is_nan(self) -> bool;
    /// True for positive or negative infinity.
    fn is_infinite(self) -> bool;
    /// True when neither NaN nor infinite.
    fn is_finite(self) -> bool;
}
251
/// Implement functions on f32/f64 to give them APIs similar to SIMD types
/// (a scalar is treated as a one-lane vector).
pub(crate) trait FloatAsSIMD: Sized {
    /// A scalar always has exactly one lane.
    #[inline(always)]
    fn lanes() -> usize {
        1
    }

    /// "Broadcast" to all (one) lanes — the identity for a scalar.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }

    /// Read lane `index`; only lane 0 exists for a scalar.
    #[inline(always)]
    fn extract(self, index: usize) -> Self {
        debug_assert_eq!(index, 0);
        self
    }

    /// Replace lane `index` with `new_value`; only lane 0 exists for a scalar.
    #[inline(always)]
    fn replace(self, index: usize, new_value: Self) -> Self {
        debug_assert_eq!(index, 0);
        new_value
    }
}
273
/// Gives plain `bool` the mask-reduction API of SIMD mask types, so generic
/// code can call `.any()` / `.all()` / `.none()` on scalars and vectors alike.
pub(crate) trait BoolAsSIMD: Sized {
    fn any(self) -> bool;
    fn all(self) -> bool;
    fn none(self) -> bool;
}

impl BoolAsSIMD for bool {
    /// With a single "lane", at-least-one-set is the value itself.
    #[inline(always)]
    fn any(self) -> bool {
        self
    }

    /// With a single "lane", all-set is likewise the value itself.
    #[inline(always)]
    fn all(self) -> bool {
        self
    }

    /// No lane is set exactly when the value is `false`.
    #[inline(always)]
    fn none(self) -> bool {
        !self
    }
}
296
/// Implements `Float` (core-only builds) and `FloatSIMDUtils` for a scalar
/// float type `$ty`, with `$uty` its same-width unsigned integer type.
/// A scalar is treated as a one-lane vector whose mask is a single `bool`.
macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        #[cfg(not(std))]
        impl Float for $ty {
            #[inline]
            fn is_nan(self) -> bool {
                // IEEE-754: NaN is the only value unequal to itself.
                self != self
            }

            #[inline]
            fn is_infinite(self) -> bool {
                self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY
            }

            #[inline]
            fn is_finite(self) -> bool {
                !(self.is_nan() || self.is_infinite())
            }
        }

        impl FloatSIMDUtils for $ty {
            type Mask = bool;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self < other
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self <= other
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                self.is_finite()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self > other
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self >= other
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask, "At least one lane must be set");
                // Subtracting one from the bit pattern yields the next-lower
                // representable value for positive finite floats.
                // NOTE(review): for zero or negative inputs this would not
                // decrease the numeric value — callers presumably only pass
                // positive values; confirm against call sites.
                <$ty>::from_bits(self.to_bits() - 1)
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                // Numeric conversion (not a bit reinterpretation).
                i as $ty
            }
        }

        impl FloatAsSIMD for $ty {}
    };
}

scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64);
369
370
/// Implements `FloatSIMDUtils` for a packed_simd float vector `$ty`, where
/// `$f_scalar` is its lane scalar type, `$mty` the matching mask vector, and
/// `$uty` the unsigned integer vector with the same lane layout.
#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = $mty;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                // Lane-wise compare, then reduce over all lanes.
                self.lt(other).all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.le(other).all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.finite_mask().all()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                // This can possibly be done faster by checking bit patterns
                // (strict comparisons also exclude NaN lanes).
                let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY);
                let pos_inf = $ty::splat(::core::$f_scalar::INFINITY);
                self.gt(neg_inf) & self.lt(pos_inf)
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.gt(other)
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self.ge(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $ty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask))
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                // Numeric lane-wise conversion (not a bit reinterpretation).
                i.cast()
            }
        }
    };
}

#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 }
#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 }
#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 }
#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 }
#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 }
#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 }
#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 }
438
/// Calculates ln(gamma(x)) (natural logarithm of the gamma
/// function) using the Lanczos approximation.
///
/// The approximation expresses the gamma function as:
/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)`
/// `g` is an arbitrary constant; we use the approximation with `g=5`.
///
/// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides:
/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)`
///
/// `Ag(z)` is an infinite series with coefficients that can be calculated
/// ahead of time - we use just the first 6 terms, which is good enough
/// for most purposes.
#[cfg(feature = "std")]
pub fn log_gamma(x: f64) -> f64 {
    // Precomputed coefficients of the first six terms of the Lanczos
    // series Ag(x), for g = 5.
    const COEFFICIENTS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    // (x+0.5)*ln(x+g+0.5) - (x+g+0.5), where g + 0.5 = 5.5.
    let tmp = x + 5.5;
    let log = (x + 0.5) * tmp.ln() - tmp;

    // Accumulate Ag(x): the leading constant plus coeff_k / (x + k),
    // threading the running denominator through the fold.
    let (a, _) = COEFFICIENTS
        .iter()
        .fold((1.000000000190015, x), |(a, denom), coeff| {
            let denom = denom + 1.0;
            (a + coeff / denom, denom)
        });

    // Combine the pieces; 2.5066282746310005 is sqrt(2*pi).
    log + (2.5066282746310005 * a / x).ln()
}
481
/// Sample a random number using the Ziggurat method (specifically the
/// ZIGNOR variant from Doornik 2005). Most of the arguments are
/// directly from the paper:
///
/// * `rng`: source of randomness
/// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0.
/// * `X`: the $x_i$ abscissae.
/// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$)
/// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$
/// * `pdf`: the probability density function
/// * `zero_case`: manual sampling from the tail when we chose the
/// bottom box (i.e. i == 0)

// the perf improvement (25-50%) is definitely worth the extra code
// size from force-inlining.
#[cfg(feature = "std")]
#[inline(always)]
pub fn ziggurat<R: Rng + ?Sized, P, Z>(
    rng: &mut R,
    symmetric: bool,
    x_tab: ziggurat_tables::ZigTable,
    f_tab: ziggurat_tables::ZigTable,
    mut pdf: P,
    mut zero_case: Z
) -> f64
where
    P: FnMut(f64) -> f64,
    Z: FnMut(&mut R, f64) -> f64,
{
    use crate::distributions::float::IntoFloat;
    loop {
        // As an optimisation we re-implement the conversion to a f64.
        // From the remaining 12 most significant bits we use 8 to construct `i`.
        // This saves us generating a whole extra random number, while the added
        // precision of using 64 bits for f64 does not buy us much.
        let bits = rng.next_u64();
        // The low 8 bits select one of the 256 ziggurat layers.
        let i = bits as usize & 0xff;

        let u = if symmetric {
            // Convert to a value in the range [2,4) and subtract to get [-1,1)
            // We can't convert to an open range directly, that would require
            // subtracting `3.0 - EPSILON`, which is not representable.
            // It is possible with an extra step, but an open range does not
            // seem necessary for the ziggurat algorithm anyway.
            (bits >> 12).into_float_with_exponent(1) - 3.0
        } else {
            // Convert to a value in the range [1,2) and subtract to get (0,1)
            (bits >> 12).into_float_with_exponent(0) - (1.0 - ::core::f64::EPSILON / 2.0)
        };
        let x = u * x_tab[i];

        // For symmetric distributions the layer test is done on |x|.
        let test_x = if symmetric { x.abs() } else { x };

        // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
        // Fast path: point is inside the inner rectangle of this layer.
        if test_x < x_tab[i + 1] {
            return x;
        }
        if i == 0 {
            // Bottom layer: delegate to the caller-supplied tail sampler.
            return zero_case(rng, u);
        }
        // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
        // Wedge region: accept with probability proportional to the PDF.
        if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) {
            return x;
        }
        // Otherwise reject and retry with fresh randomness.
    }
}