1 // Copyright 2018 Developers of the Rand project.
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
9 //! Math helper functions
11 #[cfg(feature = "std")] use crate::distributions::ziggurat_tables;
12 #[cfg(feature = "std")] use crate::Rng;
13 #[cfg(feature = "simd_support")] use packed_simd::*;
/// Trait for computing the full double-width product of two integers,
/// returned as a `(high, low)` pair of halves.
///
/// Implemented for the unsigned primitive types (and, with `simd_support`,
/// for unsigned SIMD vectors) by the `wmul_impl*` macros in this file.
pub trait WideningMultiply<RHS = Self> {
    type Output;

    fn wmul(self, x: RHS) -> Self::Output;
}
// Implements `WideningMultiply` for a primitive integer `$ty` by widening
// the operands to `$wide`, multiplying, and splitting the product at bit
// `$shift` into `(high, low)` halves.
macro_rules! wmul_impl {
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                // Widen, multiply, then split the double-width product.
                let tmp = (self as $wide) * (x as $wide);
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    ($(($ty:ident, $wide:ident),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // supported multiply & swizzle instructions (no actual
                    // scalarization).
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> $shift).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
// Widening-multiply impls for the scalar unsigned types: each widens to the
// next-larger unsigned type and splits at the original bit width.
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
// Emscripten lacks native u128 support, so u64 instead uses the
// `wmul_impl_large` split-multiply fallback defined further down.
#[cfg(not(target_os = "emscripten"))]
wmul_impl! { u64, u128, 64 }
// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the common method
// `(a + b) * (c + d) = ac + ad + bc + bd`.
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
//
// Implements `WideningMultiply` for a type with no wider type available:
// both operands are split at `$half` bits and the four partial products
// are combined with explicit carry propagation.
macro_rules! wmul_impl_large {
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                const LOWER_MASK: $ty = !0 >> $half;
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // needs wrapping multiplication
                    const LOWER_MASK: $scalar = !0 >> $half;
                    let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
                    let mut t = low >> $half;
                    low &= LOWER_MASK;
                    t += (self >> $half) * (b & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    let mut high = t >> $half;
                    t = low >> $half;
                    low &= LOWER_MASK;
                    t += (b >> $half) * (self & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    high += t >> $half;
                    high += (self >> $half) * (b >> $half);

                    (high, low)
                }
            }
        )+
    };
}
// Emscripten has no u128, so u64 takes the split-multiply path there;
// on every other target only u128 needs it.
#[cfg(target_os = "emscripten")]
wmul_impl_large! { u64, 32 }
#[cfg(not(target_os = "emscripten"))]
wmul_impl_large! { u128, 64 }
// Implements `WideningMultiply` for `usize` by delegating to the impl for
// the fixed-width unsigned type `$ty` of the same pointer width.
macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    };
}
// `usize` forwards to the fixed-width unsigned type of the same pointer width.
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }
// SIMD widening-multiply implementations (nightly + `simd_support` only).
// NOTE(review): the `mod simd_wmul { ... }` wrapper this cfg originally
// applied to is not visible in this chunk — confirm against the full file.
#[cfg(all(feature = "simd_support", feature = "nightly"))]
// x86 intrinsics used by the `mulhi`-based 16-bit-lane impls below.
#[cfg(target_arch = "x86")] use core::arch::x86::*;
#[cfg(target_arch = "x86_64")] use core::arch::x86_64::*;
// Generic cast-and-widen fallback for 16-bit-lane vectors; each cfg enables
// the fallback only when the corresponding `mulhi` specialisation is
// unavailable at the compiled target-feature level.
wmul_impl! { (u16x2, u32x2),, 16 }
#[cfg(not(target_feature = "sse2"))]
wmul_impl! { (u16x4, u32x4),, 16 }
#[cfg(not(target_feature = "sse4.2"))]
wmul_impl! { (u16x8, u32x8),, 16 }
#[cfg(not(target_feature = "avx2"))]
wmul_impl! { (u16x16, u32x16),, 16 }
// 16-bit lane widths allow use of the x86 `mulhi` instructions, which
// means `wmul` can be implemented with only two instructions.
#[allow(unused_macros)]
macro_rules! wmul_impl_16 {
    ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                // Reinterpret both vectors as the raw x86 register type,
                // then take the high and low 16 bits of each lane product.
                let b = $intrinsic::from_bits(x);
                let a = $intrinsic::from_bits(self);
                // Unsafe because the intrinsics require the matching target
                // feature; invocations of this macro are cfg-gated on it.
                let hi = $ty::from_bits(unsafe { $mulhi(a, b) });
                let lo = $ty::from_bits(unsafe { $mullo(a, b) });
                (hi, lo)
            }
        }
    };
}
// `mulhi`/`mullo` specialisations, selected by available target feature.
#[cfg(target_feature = "sse2")]
wmul_impl_16! { u16x4, __m64, _mm_mulhi_pu16, _mm_mullo_pi16 }
#[cfg(target_feature = "sse4.2")]
wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 }
#[cfg(target_feature = "avx2")]
wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
// FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::<u16x32>`
// cannot use the same implementation.

// Remaining widths fall back to the split-multiply translation.
// TODO: optimize, this seems to seriously slow things down
wmul_impl_large! { (u8x64,) u8, 4 }
wmul_impl_large! { (u16x32,) u16, 8 }
wmul_impl_large! { (u32x16,) u32, 16 }
wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
// Re-export the SIMD `WideningMultiply` impls at module level when enabled.
#[cfg(all(feature = "simd_support", feature = "nightly"))]
pub use self::simd_wmul::*;
/// Helper trait when dealing with scalar and SIMD floating point types.
pub(crate) trait FloatSIMDUtils {
    // `PartialOrd` for vectors compares lexicographically. We want to compare all
    // the individual SIMD lanes instead, and get the combined result over all
    // lanes. This is possible using something like `a.lt(b).all()`, but we
    // implement it as a trait so we can write the same code for `f32` and `f64`.
    // Only the comparison functions we need are implemented.
    fn all_lt(self, other: Self) -> bool;
    fn all_le(self, other: Self) -> bool;
    fn all_finite(self) -> bool;

    // Per-lane mask type: `bool` for scalars, a SIMD mask for vectors.
    type Mask;
    fn finite_mask(self) -> Self::Mask;
    fn gt_mask(self, other: Self) -> Self::Mask;
    fn ge_mask(self, other: Self) -> Self::Mask;

    // Decrease all lanes where the mask is `true` to the next lower value
    // representable by the floating-point type. At least one of the lanes
    // must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    // Unsigned integer type of the same (per-lane) bit width as `Self`.
    type UInt;
    // Convert from int value. Conversion is done while retaining the numerical
    // value, not by retaining the binary representation.
    fn cast_from_int(i: Self::UInt) -> Self;
}
/// Implement functions available in std builds but missing from core primitives
pub(crate) trait Float: Sized {
    fn is_nan(self) -> bool;
    fn is_infinite(self) -> bool;
    fn is_finite(self) -> bool;
}
/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD: Sized {
    // A scalar behaves like a single-lane vector.
    #[inline(always)]
    fn lanes() -> usize {
        1
    }

    // Splatting a scalar across one lane is the identity.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }

    // Only lane 0 exists for a scalar.
    #[inline(always)]
    fn extract(self, index: usize) -> Self {
        debug_assert_eq!(index, 0);
        self
    }

    // Replacing lane 0 simply yields the new value.
    #[inline(always)]
    fn replace(self, index: usize, new_value: Self) -> Self {
        debug_assert_eq!(index, 0);
        new_value
    }
}
/// Give `bool` an API similar to a single-lane SIMD mask type.
pub(crate) trait BoolAsSIMD: Sized {
    fn any(self) -> bool;
    fn all(self) -> bool;
    fn none(self) -> bool;
}

impl BoolAsSIMD for bool {
    // With exactly one lane, any/all are both just the value itself.
    #[inline(always)]
    fn any(self) -> bool {
        self
    }

    #[inline(always)]
    fn all(self) -> bool {
        self
    }

    #[inline(always)]
    fn none(self) -> bool {
        !self
    }
}
// Implements `Float`, `FloatSIMDUtils` and `FloatAsSIMD` for a scalar float
// type `$ty` whose bit representation fits the unsigned integer type `$uty`.
macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        impl Float for $ty {
            #[inline]
            fn is_nan(self) -> bool {
                // NaN is the only value that compares unequal to itself.
                self != self
            }

            #[inline]
            fn is_infinite(self) -> bool {
                self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY
            }

            #[inline]
            fn is_finite(self) -> bool {
                !(self.is_nan() || self.is_infinite())
            }
        }

        impl FloatSIMDUtils for $ty {
            type Mask = bool;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self < other
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self <= other
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                self.is_finite()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self > other
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self >= other
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask, "At least one lane must be set");
                // Subtracting one from the bit pattern yields the next
                // representable value towards negative infinity.
                <$ty>::from_bits(self.to_bits() - 1)
            }

            type UInt = $uty;

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i as $ty
            }
        }

        impl FloatAsSIMD for $ty {}
    };
}
// Instantiate the scalar float helpers for both primitive float types.
scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64);
// Implements `FloatSIMDUtils` for a packed_simd float vector `$ty` with
// scalar element type `$f_scalar`, mask type `$mty` and same-width unsigned
// integer vector `$uty`.
#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = $mty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self.lt(other).all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.le(other).all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.finite_mask().all()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                // This can possibly be done faster by checking bit patterns
                let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY);
                let pos_inf = $ty::splat(::core::$f_scalar::INFINITY);
                self.gt(neg_inf) & self.lt(pos_inf)
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.gt(other)
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self.ge(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $ty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask))
            }

            type UInt = $uty;

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i.cast()
            }
        }
    };
}
// Instantiate the SIMD float helpers for every supported vector width.
#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 }
#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 }
#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 }
#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 }
#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 }
#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 }
#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 }
439 /// Calculates ln(gamma(x)) (natural logarithm of the gamma
440 /// function) using the Lanczos approximation.
442 /// The approximation expresses the gamma function as:
443 /// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)`
444 /// `g` is an arbitrary constant; we use the approximation with `g=5`.
446 /// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides:
447 /// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)`
449 /// `Ag(z)` is an infinite series with coefficients that can be calculated
450 /// ahead of time - we use just the first 6 terms, which is good enough
451 /// for most purposes.
452 #[cfg(feature = "std")]
pub fn log_gamma(x: f64) -> f64 {
    // precalculated 6 coefficients for the first 6 terms of the series
    // (Lanczos approximation with g = 5; see the doc comment above)
    let coefficients: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    // (x+0.5)*ln(x+g+0.5)-(x+g+0.5)
    let tmp = x + 5.5;
    let log = (x + 0.5) * tmp.ln() - tmp;

    // the first few terms of the series for Ag(x): each term k divides
    // coefficient k by (x + k), k = 1..=6
    let mut a = 1.000000000190015;
    let mut denom = x;
    for coeff in &coefficients {
        denom += 1.0;
        a += coeff / denom;
    }

    // get everything together
    // 2.5066... is sqrt(2pi)
    log + (2.5066282746310005 * a / x).ln()
}
482 /// Sample a random number using the Ziggurat method (specifically the
483 /// ZIGNOR variant from Doornik 2005). Most of the arguments are
484 /// directly from the paper:
486 /// * `rng`: source of randomness
487 /// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0.
488 /// * `X`: the $x_i$ abscissae.
489 /// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$)
490 /// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$
491 /// * `pdf`: the probability density function
492 /// * `zero_case`: manual sampling from the tail when we chose the
493 /// bottom box (i.e. i == 0)
495 // the perf improvement (25-50%) is definitely worth the extra code
496 // size from force-inlining.
497 #[cfg(feature = "std")]
499 pub fn ziggurat
<R
: Rng
+ ?Sized
, P
, Z
>(
502 x_tab
: ziggurat_tables
::ZigTable
,
503 f_tab
: ziggurat_tables
::ZigTable
,
508 P
: FnMut(f64) -> f64,
509 Z
: FnMut(&mut R
, f64) -> f64,
511 use crate::distributions
::float
::IntoFloat
;
513 // As an optimisation we re-implement the conversion to a f64.
514 // From the remaining 12 most significant bits we use 8 to construct `i`.
515 // This saves us generating a whole extra random number, while the added
516 // precision of using 64 bits for f64 does not buy us much.
517 let bits
= rng
.next_u64();
518 let i
= bits
as usize & 0xff;
520 let u
= if symmetric
{
521 // Convert to a value in the range [2,4) and substract to get [-1,1)
522 // We can't convert to an open range directly, that would require
523 // substracting `3.0 - EPSILON`, which is not representable.
524 // It is possible with an extra step, but an open range does not
525 // seem neccesary for the ziggurat algorithm anyway.
526 (bits
>> 12).into_float_with_exponent(1) - 3.0
528 // Convert to a value in the range [1,2) and substract to get (0,1)
529 (bits
>> 12).into_float_with_exponent(0) - (1.0 - ::core
::f64::EPSILON
/ 2.0)
531 let x
= u
* x_tab
[i
];
533 let test_x
= if symmetric { x.abs() }
else { x }
;
535 // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
536 if test_x
< x_tab
[i
+ 1] {
540 return zero_case(rng
, u
);
542 // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
543 if f_tab
[i
+ 1] + (f_tab
[i
] - f_tab
[i
+ 1]) * rng
.gen
::<f64>() < pdf(x
) {