]> git.proxmox.com Git - rustc.git/blame - library/portable-simd/crates/core_simd/src/masks.rs
New upstream version 1.58.1+dfsg1
[rustc.git] / library / portable-simd / crates / core_simd / src / masks.rs
CommitLineData
3c0e092e
XL
//! Types and traits associated with masking lanes of vectors.
//! Masks are represented by the [`Mask`] type and its `mask*` type aliases.
// The mask type aliases in this module (`mask8x8`, `masksizex2`, ...) use lowercase names,
// so the camel-case naming lint is silenced module-wide.
#![allow(non_camel_case_types)]

// Select the backing representation for `Mask` at compile time. The two conditions below are
// complementary, so exactly one `path` applies: x86_64 targets compiled with AVX-512F use
// `masks/bitmask.rs`, every other target uses `masks/full_masks.rs`.
#[cfg_attr(
    not(all(target_arch = "x86_64", target_feature = "avx512f")),
    path = "masks/full_masks.rs"
)]
#[cfg_attr(
    all(target_arch = "x86_64", target_feature = "avx512f"),
    path = "masks/bitmask.rs"
)]
mod mask_impl;
14
15use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
16use core::cmp::Ordering;
17use core::fmt;
18
19mod sealed {
20 use super::*;
21
22 /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits
23 /// from bleeding into the parent bounds.
24 ///
25 /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
26 /// prevent us from ever removing that bound, or from implementing `MaskElement` on
27 /// non-`PartialEq` types in the future.
28 pub trait Sealed {
29 fn valid<const LANES: usize>(values: Simd<Self, LANES>) -> bool
30 where
31 LaneCount<LANES>: SupportedLaneCount,
32 Self: SimdElement;
33
34 fn eq(self, other: Self) -> bool;
35
36 const TRUE: Self;
37
38 const FALSE: Self;
39 }
40}
41use sealed::Sealed;
42
43/// Marker trait for types that may be used as SIMD mask elements.
44pub unsafe trait MaskElement: SimdElement + Sealed {}
45
46macro_rules! impl_element {
47 { $ty:ty } => {
48 impl Sealed for $ty {
49 fn valid<const LANES: usize>(value: Simd<Self, LANES>) -> bool
50 where
51 LaneCount<LANES>: SupportedLaneCount,
52 {
53 (value.lanes_eq(Simd::splat(0)) | value.lanes_eq(Simd::splat(-1))).all()
54 }
55
56 fn eq(self, other: Self) -> bool { self == other }
57
58 const TRUE: Self = -1;
59 const FALSE: Self = 0;
60 }
61
62 unsafe impl MaskElement for $ty {}
63 }
64}
65
66impl_element! { i8 }
67impl_element! { i16 }
68impl_element! { i32 }
69impl_element! { i64 }
70impl_element! { isize }
71
72/// A SIMD vector mask for `LANES` elements of width specified by `Element`.
73///
74/// The layout of this type is unspecified.
75#[repr(transparent)]
76pub struct Mask<T, const LANES: usize>(mask_impl::Mask<T, LANES>)
77where
78 T: MaskElement,
79 LaneCount<LANES>: SupportedLaneCount;
80
81impl<T, const LANES: usize> Copy for Mask<T, LANES>
82where
83 T: MaskElement,
84 LaneCount<LANES>: SupportedLaneCount,
85{
86}
87
88impl<T, const LANES: usize> Clone for Mask<T, LANES>
89where
90 T: MaskElement,
91 LaneCount<LANES>: SupportedLaneCount,
92{
93 fn clone(&self) -> Self {
94 *self
95 }
96}
97
98impl<T, const LANES: usize> Mask<T, LANES>
99where
100 T: MaskElement,
101 LaneCount<LANES>: SupportedLaneCount,
102{
103 /// Construct a mask by setting all lanes to the given value.
104 pub fn splat(value: bool) -> Self {
105 Self(mask_impl::Mask::splat(value))
106 }
107
108 /// Converts an array to a SIMD vector.
109 pub fn from_array(array: [bool; LANES]) -> Self {
110 let mut vector = Self::splat(false);
111 for (i, v) in array.iter().enumerate() {
112 vector.set(i, *v);
113 }
114 vector
115 }
116
117 /// Converts a SIMD vector to an array.
118 pub fn to_array(self) -> [bool; LANES] {
119 let mut array = [false; LANES];
120 for (i, v) in array.iter_mut().enumerate() {
121 *v = self.test(i);
122 }
123 array
124 }
125
126 /// Converts a vector of integers to a mask, where 0 represents `false` and -1
127 /// represents `true`.
128 ///
129 /// # Safety
130 /// All lanes must be either 0 or -1.
131 #[inline]
132 pub unsafe fn from_int_unchecked(value: Simd<T, LANES>) -> Self {
133 unsafe { Self(mask_impl::Mask::from_int_unchecked(value)) }
134 }
135
136 /// Converts a vector of integers to a mask, where 0 represents `false` and -1
137 /// represents `true`.
138 ///
139 /// # Panics
140 /// Panics if any lane is not 0 or -1.
141 #[inline]
142 pub fn from_int(value: Simd<T, LANES>) -> Self {
143 assert!(T::valid(value), "all values must be either 0 or -1",);
144 unsafe { Self::from_int_unchecked(value) }
145 }
146
147 /// Converts the mask to a vector of integers, where 0 represents `false` and -1
148 /// represents `true`.
149 #[inline]
150 pub fn to_int(self) -> Simd<T, LANES> {
151 self.0.to_int()
152 }
153
154 /// Tests the value of the specified lane.
155 ///
156 /// # Safety
157 /// `lane` must be less than `LANES`.
158 #[inline]
159 pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
160 unsafe { self.0.test_unchecked(lane) }
161 }
162
163 /// Tests the value of the specified lane.
164 ///
165 /// # Panics
166 /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
167 #[inline]
168 pub fn test(&self, lane: usize) -> bool {
169 assert!(lane < LANES, "lane index out of range");
170 unsafe { self.test_unchecked(lane) }
171 }
172
173 /// Sets the value of the specified lane.
174 ///
175 /// # Safety
176 /// `lane` must be less than `LANES`.
177 #[inline]
178 pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
179 unsafe {
180 self.0.set_unchecked(lane, value);
181 }
182 }
183
184 /// Sets the value of the specified lane.
185 ///
186 /// # Panics
187 /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
188 #[inline]
189 pub fn set(&mut self, lane: usize, value: bool) {
190 assert!(lane < LANES, "lane index out of range");
191 unsafe {
192 self.set_unchecked(lane, value);
193 }
194 }
195
196 /// Convert this mask to a bitmask, with one bit set per lane.
197 #[cfg(feature = "generic_const_exprs")]
198 pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
199 self.0.to_bitmask()
200 }
201
202 /// Convert a bitmask to a mask.
203 #[cfg(feature = "generic_const_exprs")]
204 pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
205 Self(mask_impl::Mask::from_bitmask(bitmask))
206 }
207
208 /// Returns true if any lane is set, or false otherwise.
209 #[inline]
210 pub fn any(self) -> bool {
211 self.0.any()
212 }
213
214 /// Returns true if all lanes are set, or false otherwise.
215 #[inline]
216 pub fn all(self) -> bool {
217 self.0.all()
218 }
219}
220
221// vector/array conversion
222impl<T, const LANES: usize> From<[bool; LANES]> for Mask<T, LANES>
223where
224 T: MaskElement,
225 LaneCount<LANES>: SupportedLaneCount,
226{
227 fn from(array: [bool; LANES]) -> Self {
228 Self::from_array(array)
229 }
230}
231
232impl<T, const LANES: usize> From<Mask<T, LANES>> for [bool; LANES]
233where
234 T: MaskElement,
235 LaneCount<LANES>: SupportedLaneCount,
236{
237 fn from(vector: Mask<T, LANES>) -> Self {
238 vector.to_array()
239 }
240}
241
242impl<T, const LANES: usize> Default for Mask<T, LANES>
243where
244 T: MaskElement,
245 LaneCount<LANES>: SupportedLaneCount,
246{
247 #[inline]
248 fn default() -> Self {
249 Self::splat(false)
250 }
251}
252
253impl<T, const LANES: usize> PartialEq for Mask<T, LANES>
254where
255 T: MaskElement + PartialEq,
256 LaneCount<LANES>: SupportedLaneCount,
257{
258 #[inline]
259 fn eq(&self, other: &Self) -> bool {
260 self.0 == other.0
261 }
262}
263
264impl<T, const LANES: usize> PartialOrd for Mask<T, LANES>
265where
266 T: MaskElement + PartialOrd,
267 LaneCount<LANES>: SupportedLaneCount,
268{
269 #[inline]
270 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
271 self.0.partial_cmp(&other.0)
272 }
273}
274
275impl<T, const LANES: usize> fmt::Debug for Mask<T, LANES>
276where
277 T: MaskElement + fmt::Debug,
278 LaneCount<LANES>: SupportedLaneCount,
279{
280 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281 f.debug_list()
282 .entries((0..LANES).map(|lane| self.test(lane)))
283 .finish()
284 }
285}
286
287impl<T, const LANES: usize> core::ops::BitAnd for Mask<T, LANES>
288where
289 T: MaskElement,
290 LaneCount<LANES>: SupportedLaneCount,
291{
292 type Output = Self;
293 #[inline]
294 fn bitand(self, rhs: Self) -> Self {
295 Self(self.0 & rhs.0)
296 }
297}
298
299impl<T, const LANES: usize> core::ops::BitAnd<bool> for Mask<T, LANES>
300where
301 T: MaskElement,
302 LaneCount<LANES>: SupportedLaneCount,
303{
304 type Output = Self;
305 #[inline]
306 fn bitand(self, rhs: bool) -> Self {
307 self & Self::splat(rhs)
308 }
309}
310
311impl<T, const LANES: usize> core::ops::BitAnd<Mask<T, LANES>> for bool
312where
313 T: MaskElement,
314 LaneCount<LANES>: SupportedLaneCount,
315{
316 type Output = Mask<T, LANES>;
317 #[inline]
318 fn bitand(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> {
319 Mask::splat(self) & rhs
320 }
321}
322
323impl<T, const LANES: usize> core::ops::BitOr for Mask<T, LANES>
324where
325 T: MaskElement,
326 LaneCount<LANES>: SupportedLaneCount,
327{
328 type Output = Self;
329 #[inline]
330 fn bitor(self, rhs: Self) -> Self {
331 Self(self.0 | rhs.0)
332 }
333}
334
335impl<T, const LANES: usize> core::ops::BitOr<bool> for Mask<T, LANES>
336where
337 T: MaskElement,
338 LaneCount<LANES>: SupportedLaneCount,
339{
340 type Output = Self;
341 #[inline]
342 fn bitor(self, rhs: bool) -> Self {
343 self | Self::splat(rhs)
344 }
345}
346
347impl<T, const LANES: usize> core::ops::BitOr<Mask<T, LANES>> for bool
348where
349 T: MaskElement,
350 LaneCount<LANES>: SupportedLaneCount,
351{
352 type Output = Mask<T, LANES>;
353 #[inline]
354 fn bitor(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> {
355 Mask::splat(self) | rhs
356 }
357}
358
359impl<T, const LANES: usize> core::ops::BitXor for Mask<T, LANES>
360where
361 T: MaskElement,
362 LaneCount<LANES>: SupportedLaneCount,
363{
364 type Output = Self;
365 #[inline]
366 fn bitxor(self, rhs: Self) -> Self::Output {
367 Self(self.0 ^ rhs.0)
368 }
369}
370
371impl<T, const LANES: usize> core::ops::BitXor<bool> for Mask<T, LANES>
372where
373 T: MaskElement,
374 LaneCount<LANES>: SupportedLaneCount,
375{
376 type Output = Self;
377 #[inline]
378 fn bitxor(self, rhs: bool) -> Self::Output {
379 self ^ Self::splat(rhs)
380 }
381}
382
383impl<T, const LANES: usize> core::ops::BitXor<Mask<T, LANES>> for bool
384where
385 T: MaskElement,
386 LaneCount<LANES>: SupportedLaneCount,
387{
388 type Output = Mask<T, LANES>;
389 #[inline]
390 fn bitxor(self, rhs: Mask<T, LANES>) -> Self::Output {
391 Mask::splat(self) ^ rhs
392 }
393}
394
395impl<T, const LANES: usize> core::ops::Not for Mask<T, LANES>
396where
397 T: MaskElement,
398 LaneCount<LANES>: SupportedLaneCount,
399{
400 type Output = Mask<T, LANES>;
401 #[inline]
402 fn not(self) -> Self::Output {
403 Self(!self.0)
404 }
405}
406
407impl<T, const LANES: usize> core::ops::BitAndAssign for Mask<T, LANES>
408where
409 T: MaskElement,
410 LaneCount<LANES>: SupportedLaneCount,
411{
412 #[inline]
413 fn bitand_assign(&mut self, rhs: Self) {
414 self.0 = self.0 & rhs.0;
415 }
416}
417
418impl<T, const LANES: usize> core::ops::BitAndAssign<bool> for Mask<T, LANES>
419where
420 T: MaskElement,
421 LaneCount<LANES>: SupportedLaneCount,
422{
423 #[inline]
424 fn bitand_assign(&mut self, rhs: bool) {
425 *self &= Self::splat(rhs);
426 }
427}
428
429impl<T, const LANES: usize> core::ops::BitOrAssign for Mask<T, LANES>
430where
431 T: MaskElement,
432 LaneCount<LANES>: SupportedLaneCount,
433{
434 #[inline]
435 fn bitor_assign(&mut self, rhs: Self) {
436 self.0 = self.0 | rhs.0;
437 }
438}
439
440impl<T, const LANES: usize> core::ops::BitOrAssign<bool> for Mask<T, LANES>
441where
442 T: MaskElement,
443 LaneCount<LANES>: SupportedLaneCount,
444{
445 #[inline]
446 fn bitor_assign(&mut self, rhs: bool) {
447 *self |= Self::splat(rhs);
448 }
449}
450
451impl<T, const LANES: usize> core::ops::BitXorAssign for Mask<T, LANES>
452where
453 T: MaskElement,
454 LaneCount<LANES>: SupportedLaneCount,
455{
456 #[inline]
457 fn bitxor_assign(&mut self, rhs: Self) {
458 self.0 = self.0 ^ rhs.0;
459 }
460}
461
462impl<T, const LANES: usize> core::ops::BitXorAssign<bool> for Mask<T, LANES>
463where
464 T: MaskElement,
465 LaneCount<LANES>: SupportedLaneCount,
466{
467 #[inline]
468 fn bitxor_assign(&mut self, rhs: bool) {
469 *self ^= Self::splat(rhs);
470 }
471}
472
473/// Vector of eight 8-bit masks
474pub type mask8x8 = Mask<i8, 8>;
475
476/// Vector of 16 8-bit masks
477pub type mask8x16 = Mask<i8, 16>;
478
479/// Vector of 32 8-bit masks
480pub type mask8x32 = Mask<i8, 32>;
481
482/// Vector of 16 8-bit masks
483pub type mask8x64 = Mask<i8, 64>;
484
485/// Vector of four 16-bit masks
486pub type mask16x4 = Mask<i16, 4>;
487
488/// Vector of eight 16-bit masks
489pub type mask16x8 = Mask<i16, 8>;
490
491/// Vector of 16 16-bit masks
492pub type mask16x16 = Mask<i16, 16>;
493
494/// Vector of 32 16-bit masks
495pub type mask16x32 = Mask<i32, 32>;
496
497/// Vector of two 32-bit masks
498pub type mask32x2 = Mask<i32, 2>;
499
500/// Vector of four 32-bit masks
501pub type mask32x4 = Mask<i32, 4>;
502
503/// Vector of eight 32-bit masks
504pub type mask32x8 = Mask<i32, 8>;
505
506/// Vector of 16 32-bit masks
507pub type mask32x16 = Mask<i32, 16>;
508
509/// Vector of two 64-bit masks
510pub type mask64x2 = Mask<i64, 2>;
511
512/// Vector of four 64-bit masks
513pub type mask64x4 = Mask<i64, 4>;
514
515/// Vector of eight 64-bit masks
516pub type mask64x8 = Mask<i64, 8>;
517
518/// Vector of two pointer-width masks
519pub type masksizex2 = Mask<isize, 2>;
520
521/// Vector of four pointer-width masks
522pub type masksizex4 = Mask<isize, 4>;
523
524/// Vector of eight pointer-width masks
525pub type masksizex8 = Mask<isize, 8>;
526
527macro_rules! impl_from {
528 { $from:ty => $($to:ty),* } => {
529 $(
530 impl<const LANES: usize> From<Mask<$from, LANES>> for Mask<$to, LANES>
531 where
532 LaneCount<LANES>: SupportedLaneCount,
533 {
534 fn from(value: Mask<$from, LANES>) -> Self {
535 Self(value.0.convert())
536 }
537 }
538 )*
539 }
540}
541impl_from! { i8 => i16, i32, i64, isize }
542impl_from! { i16 => i32, i64, isize, i8 }
543impl_from! { i32 => i64, isize, i8, i16 }
544impl_from! { i64 => isize, i8, i16, i32 }
545impl_from! { isize => i8, i16, i32, i64 }