]>
Commit | Line | Data |
---|---|---|
3c0e092e XL |
1 | //! Types and traits associated with masking lanes of vectors. |
2 | //! A mask holds one boolean per vector lane; see [`Mask`] for the main type. | |
3 | #![allow(non_camel_case_types)] | |
4 | ||
5 | #[cfg_attr( | |
6 | not(all(target_arch = "x86_64", target_feature = "avx512f")), | |
7 | path = "masks/full_masks.rs" | |
8 | )] | |
9 | #[cfg_attr( | |
10 | all(target_arch = "x86_64", target_feature = "avx512f"), | |
11 | path = "masks/bitmask.rs" | |
12 | )] | |
13 | mod mask_impl; | |
14 | ||
15 | use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount}; | |
16 | use core::cmp::Ordering; | |
17 | use core::fmt; | |
18 | ||
19 | mod sealed { | |
20 | use super::*; | |
21 | ||
22 | /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits | |
23 | /// from bleeding into the parent bounds. | |
24 | /// | |
25 | /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would | |
26 | /// prevent us from ever removing that bound, or from implementing `MaskElement` on | |
27 | /// non-`PartialEq` types in the future. | |
28 | pub trait Sealed { | |
29 | fn valid<const LANES: usize>(values: Simd<Self, LANES>) -> bool | |
30 | where | |
31 | LaneCount<LANES>: SupportedLaneCount, | |
32 | Self: SimdElement; | |
33 | ||
34 | fn eq(self, other: Self) -> bool; | |
35 | ||
36 | const TRUE: Self; | |
37 | ||
38 | const FALSE: Self; | |
39 | } | |
40 | } | |
41 | use sealed::Sealed; | |
42 | ||
43 | /// Marker trait for types that may be used as SIMD mask elements. | |
44 | pub unsafe trait MaskElement: SimdElement + Sealed {} | |
45 | ||
46 | macro_rules! impl_element { | |
47 | { $ty:ty } => { | |
48 | impl Sealed for $ty { | |
49 | fn valid<const LANES: usize>(value: Simd<Self, LANES>) -> bool | |
50 | where | |
51 | LaneCount<LANES>: SupportedLaneCount, | |
52 | { | |
53 | (value.lanes_eq(Simd::splat(0)) | value.lanes_eq(Simd::splat(-1))).all() | |
54 | } | |
55 | ||
56 | fn eq(self, other: Self) -> bool { self == other } | |
57 | ||
58 | const TRUE: Self = -1; | |
59 | const FALSE: Self = 0; | |
60 | } | |
61 | ||
62 | unsafe impl MaskElement for $ty {} | |
63 | } | |
64 | } | |
65 | ||
66 | impl_element! { i8 } | |
67 | impl_element! { i16 } | |
68 | impl_element! { i32 } | |
69 | impl_element! { i64 } | |
70 | impl_element! { isize } | |
71 | ||
72 | /// A SIMD vector mask for `LANES` elements of width specified by `Element`. | |
73 | /// | |
74 | /// The layout of this type is unspecified. | |
75 | #[repr(transparent)] | |
76 | pub struct Mask<T, const LANES: usize>(mask_impl::Mask<T, LANES>) | |
77 | where | |
78 | T: MaskElement, | |
79 | LaneCount<LANES>: SupportedLaneCount; | |
80 | ||
81 | impl<T, const LANES: usize> Copy for Mask<T, LANES> | |
82 | where | |
83 | T: MaskElement, | |
84 | LaneCount<LANES>: SupportedLaneCount, | |
85 | { | |
86 | } | |
87 | ||
88 | impl<T, const LANES: usize> Clone for Mask<T, LANES> | |
89 | where | |
90 | T: MaskElement, | |
91 | LaneCount<LANES>: SupportedLaneCount, | |
92 | { | |
93 | fn clone(&self) -> Self { | |
94 | *self | |
95 | } | |
96 | } | |
97 | ||
98 | impl<T, const LANES: usize> Mask<T, LANES> | |
99 | where | |
100 | T: MaskElement, | |
101 | LaneCount<LANES>: SupportedLaneCount, | |
102 | { | |
103 | /// Construct a mask by setting all lanes to the given value. | |
104 | pub fn splat(value: bool) -> Self { | |
105 | Self(mask_impl::Mask::splat(value)) | |
106 | } | |
107 | ||
108 | /// Converts an array to a SIMD vector. | |
109 | pub fn from_array(array: [bool; LANES]) -> Self { | |
110 | let mut vector = Self::splat(false); | |
111 | for (i, v) in array.iter().enumerate() { | |
112 | vector.set(i, *v); | |
113 | } | |
114 | vector | |
115 | } | |
116 | ||
117 | /// Converts a SIMD vector to an array. | |
118 | pub fn to_array(self) -> [bool; LANES] { | |
119 | let mut array = [false; LANES]; | |
120 | for (i, v) in array.iter_mut().enumerate() { | |
121 | *v = self.test(i); | |
122 | } | |
123 | array | |
124 | } | |
125 | ||
126 | /// Converts a vector of integers to a mask, where 0 represents `false` and -1 | |
127 | /// represents `true`. | |
128 | /// | |
129 | /// # Safety | |
130 | /// All lanes must be either 0 or -1. | |
131 | #[inline] | |
132 | pub unsafe fn from_int_unchecked(value: Simd<T, LANES>) -> Self { | |
133 | unsafe { Self(mask_impl::Mask::from_int_unchecked(value)) } | |
134 | } | |
135 | ||
136 | /// Converts a vector of integers to a mask, where 0 represents `false` and -1 | |
137 | /// represents `true`. | |
138 | /// | |
139 | /// # Panics | |
140 | /// Panics if any lane is not 0 or -1. | |
141 | #[inline] | |
142 | pub fn from_int(value: Simd<T, LANES>) -> Self { | |
143 | assert!(T::valid(value), "all values must be either 0 or -1",); | |
144 | unsafe { Self::from_int_unchecked(value) } | |
145 | } | |
146 | ||
147 | /// Converts the mask to a vector of integers, where 0 represents `false` and -1 | |
148 | /// represents `true`. | |
149 | #[inline] | |
150 | pub fn to_int(self) -> Simd<T, LANES> { | |
151 | self.0.to_int() | |
152 | } | |
153 | ||
154 | /// Tests the value of the specified lane. | |
155 | /// | |
156 | /// # Safety | |
157 | /// `lane` must be less than `LANES`. | |
158 | #[inline] | |
159 | pub unsafe fn test_unchecked(&self, lane: usize) -> bool { | |
160 | unsafe { self.0.test_unchecked(lane) } | |
161 | } | |
162 | ||
163 | /// Tests the value of the specified lane. | |
164 | /// | |
165 | /// # Panics | |
166 | /// Panics if `lane` is greater than or equal to the number of lanes in the vector. | |
167 | #[inline] | |
168 | pub fn test(&self, lane: usize) -> bool { | |
169 | assert!(lane < LANES, "lane index out of range"); | |
170 | unsafe { self.test_unchecked(lane) } | |
171 | } | |
172 | ||
173 | /// Sets the value of the specified lane. | |
174 | /// | |
175 | /// # Safety | |
176 | /// `lane` must be less than `LANES`. | |
177 | #[inline] | |
178 | pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) { | |
179 | unsafe { | |
180 | self.0.set_unchecked(lane, value); | |
181 | } | |
182 | } | |
183 | ||
184 | /// Sets the value of the specified lane. | |
185 | /// | |
186 | /// # Panics | |
187 | /// Panics if `lane` is greater than or equal to the number of lanes in the vector. | |
188 | #[inline] | |
189 | pub fn set(&mut self, lane: usize, value: bool) { | |
190 | assert!(lane < LANES, "lane index out of range"); | |
191 | unsafe { | |
192 | self.set_unchecked(lane, value); | |
193 | } | |
194 | } | |
195 | ||
196 | /// Convert this mask to a bitmask, with one bit set per lane. | |
197 | #[cfg(feature = "generic_const_exprs")] | |
198 | pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] { | |
199 | self.0.to_bitmask() | |
200 | } | |
201 | ||
202 | /// Convert a bitmask to a mask. | |
203 | #[cfg(feature = "generic_const_exprs")] | |
204 | pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self { | |
205 | Self(mask_impl::Mask::from_bitmask(bitmask)) | |
206 | } | |
207 | ||
208 | /// Returns true if any lane is set, or false otherwise. | |
209 | #[inline] | |
210 | pub fn any(self) -> bool { | |
211 | self.0.any() | |
212 | } | |
213 | ||
214 | /// Returns true if all lanes are set, or false otherwise. | |
215 | #[inline] | |
216 | pub fn all(self) -> bool { | |
217 | self.0.all() | |
218 | } | |
219 | } | |
220 | ||
221 | // vector/array conversion | |
222 | impl<T, const LANES: usize> From<[bool; LANES]> for Mask<T, LANES> | |
223 | where | |
224 | T: MaskElement, | |
225 | LaneCount<LANES>: SupportedLaneCount, | |
226 | { | |
227 | fn from(array: [bool; LANES]) -> Self { | |
228 | Self::from_array(array) | |
229 | } | |
230 | } | |
231 | ||
232 | impl<T, const LANES: usize> From<Mask<T, LANES>> for [bool; LANES] | |
233 | where | |
234 | T: MaskElement, | |
235 | LaneCount<LANES>: SupportedLaneCount, | |
236 | { | |
237 | fn from(vector: Mask<T, LANES>) -> Self { | |
238 | vector.to_array() | |
239 | } | |
240 | } | |
241 | ||
242 | impl<T, const LANES: usize> Default for Mask<T, LANES> | |
243 | where | |
244 | T: MaskElement, | |
245 | LaneCount<LANES>: SupportedLaneCount, | |
246 | { | |
247 | #[inline] | |
248 | fn default() -> Self { | |
249 | Self::splat(false) | |
250 | } | |
251 | } | |
252 | ||
253 | impl<T, const LANES: usize> PartialEq for Mask<T, LANES> | |
254 | where | |
255 | T: MaskElement + PartialEq, | |
256 | LaneCount<LANES>: SupportedLaneCount, | |
257 | { | |
258 | #[inline] | |
259 | fn eq(&self, other: &Self) -> bool { | |
260 | self.0 == other.0 | |
261 | } | |
262 | } | |
263 | ||
264 | impl<T, const LANES: usize> PartialOrd for Mask<T, LANES> | |
265 | where | |
266 | T: MaskElement + PartialOrd, | |
267 | LaneCount<LANES>: SupportedLaneCount, | |
268 | { | |
269 | #[inline] | |
270 | fn partial_cmp(&self, other: &Self) -> Option<Ordering> { | |
271 | self.0.partial_cmp(&other.0) | |
272 | } | |
273 | } | |
274 | ||
275 | impl<T, const LANES: usize> fmt::Debug for Mask<T, LANES> | |
276 | where | |
277 | T: MaskElement + fmt::Debug, | |
278 | LaneCount<LANES>: SupportedLaneCount, | |
279 | { | |
280 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
281 | f.debug_list() | |
282 | .entries((0..LANES).map(|lane| self.test(lane))) | |
283 | .finish() | |
284 | } | |
285 | } | |
286 | ||
287 | impl<T, const LANES: usize> core::ops::BitAnd for Mask<T, LANES> | |
288 | where | |
289 | T: MaskElement, | |
290 | LaneCount<LANES>: SupportedLaneCount, | |
291 | { | |
292 | type Output = Self; | |
293 | #[inline] | |
294 | fn bitand(self, rhs: Self) -> Self { | |
295 | Self(self.0 & rhs.0) | |
296 | } | |
297 | } | |
298 | ||
299 | impl<T, const LANES: usize> core::ops::BitAnd<bool> for Mask<T, LANES> | |
300 | where | |
301 | T: MaskElement, | |
302 | LaneCount<LANES>: SupportedLaneCount, | |
303 | { | |
304 | type Output = Self; | |
305 | #[inline] | |
306 | fn bitand(self, rhs: bool) -> Self { | |
307 | self & Self::splat(rhs) | |
308 | } | |
309 | } | |
310 | ||
311 | impl<T, const LANES: usize> core::ops::BitAnd<Mask<T, LANES>> for bool | |
312 | where | |
313 | T: MaskElement, | |
314 | LaneCount<LANES>: SupportedLaneCount, | |
315 | { | |
316 | type Output = Mask<T, LANES>; | |
317 | #[inline] | |
318 | fn bitand(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> { | |
319 | Mask::splat(self) & rhs | |
320 | } | |
321 | } | |
322 | ||
323 | impl<T, const LANES: usize> core::ops::BitOr for Mask<T, LANES> | |
324 | where | |
325 | T: MaskElement, | |
326 | LaneCount<LANES>: SupportedLaneCount, | |
327 | { | |
328 | type Output = Self; | |
329 | #[inline] | |
330 | fn bitor(self, rhs: Self) -> Self { | |
331 | Self(self.0 | rhs.0) | |
332 | } | |
333 | } | |
334 | ||
335 | impl<T, const LANES: usize> core::ops::BitOr<bool> for Mask<T, LANES> | |
336 | where | |
337 | T: MaskElement, | |
338 | LaneCount<LANES>: SupportedLaneCount, | |
339 | { | |
340 | type Output = Self; | |
341 | #[inline] | |
342 | fn bitor(self, rhs: bool) -> Self { | |
343 | self | Self::splat(rhs) | |
344 | } | |
345 | } | |
346 | ||
347 | impl<T, const LANES: usize> core::ops::BitOr<Mask<T, LANES>> for bool | |
348 | where | |
349 | T: MaskElement, | |
350 | LaneCount<LANES>: SupportedLaneCount, | |
351 | { | |
352 | type Output = Mask<T, LANES>; | |
353 | #[inline] | |
354 | fn bitor(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> { | |
355 | Mask::splat(self) | rhs | |
356 | } | |
357 | } | |
358 | ||
359 | impl<T, const LANES: usize> core::ops::BitXor for Mask<T, LANES> | |
360 | where | |
361 | T: MaskElement, | |
362 | LaneCount<LANES>: SupportedLaneCount, | |
363 | { | |
364 | type Output = Self; | |
365 | #[inline] | |
366 | fn bitxor(self, rhs: Self) -> Self::Output { | |
367 | Self(self.0 ^ rhs.0) | |
368 | } | |
369 | } | |
370 | ||
371 | impl<T, const LANES: usize> core::ops::BitXor<bool> for Mask<T, LANES> | |
372 | where | |
373 | T: MaskElement, | |
374 | LaneCount<LANES>: SupportedLaneCount, | |
375 | { | |
376 | type Output = Self; | |
377 | #[inline] | |
378 | fn bitxor(self, rhs: bool) -> Self::Output { | |
379 | self ^ Self::splat(rhs) | |
380 | } | |
381 | } | |
382 | ||
383 | impl<T, const LANES: usize> core::ops::BitXor<Mask<T, LANES>> for bool | |
384 | where | |
385 | T: MaskElement, | |
386 | LaneCount<LANES>: SupportedLaneCount, | |
387 | { | |
388 | type Output = Mask<T, LANES>; | |
389 | #[inline] | |
390 | fn bitxor(self, rhs: Mask<T, LANES>) -> Self::Output { | |
391 | Mask::splat(self) ^ rhs | |
392 | } | |
393 | } | |
394 | ||
395 | impl<T, const LANES: usize> core::ops::Not for Mask<T, LANES> | |
396 | where | |
397 | T: MaskElement, | |
398 | LaneCount<LANES>: SupportedLaneCount, | |
399 | { | |
400 | type Output = Mask<T, LANES>; | |
401 | #[inline] | |
402 | fn not(self) -> Self::Output { | |
403 | Self(!self.0) | |
404 | } | |
405 | } | |
406 | ||
407 | impl<T, const LANES: usize> core::ops::BitAndAssign for Mask<T, LANES> | |
408 | where | |
409 | T: MaskElement, | |
410 | LaneCount<LANES>: SupportedLaneCount, | |
411 | { | |
412 | #[inline] | |
413 | fn bitand_assign(&mut self, rhs: Self) { | |
414 | self.0 = self.0 & rhs.0; | |
415 | } | |
416 | } | |
417 | ||
418 | impl<T, const LANES: usize> core::ops::BitAndAssign<bool> for Mask<T, LANES> | |
419 | where | |
420 | T: MaskElement, | |
421 | LaneCount<LANES>: SupportedLaneCount, | |
422 | { | |
423 | #[inline] | |
424 | fn bitand_assign(&mut self, rhs: bool) { | |
425 | *self &= Self::splat(rhs); | |
426 | } | |
427 | } | |
428 | ||
429 | impl<T, const LANES: usize> core::ops::BitOrAssign for Mask<T, LANES> | |
430 | where | |
431 | T: MaskElement, | |
432 | LaneCount<LANES>: SupportedLaneCount, | |
433 | { | |
434 | #[inline] | |
435 | fn bitor_assign(&mut self, rhs: Self) { | |
436 | self.0 = self.0 | rhs.0; | |
437 | } | |
438 | } | |
439 | ||
440 | impl<T, const LANES: usize> core::ops::BitOrAssign<bool> for Mask<T, LANES> | |
441 | where | |
442 | T: MaskElement, | |
443 | LaneCount<LANES>: SupportedLaneCount, | |
444 | { | |
445 | #[inline] | |
446 | fn bitor_assign(&mut self, rhs: bool) { | |
447 | *self |= Self::splat(rhs); | |
448 | } | |
449 | } | |
450 | ||
451 | impl<T, const LANES: usize> core::ops::BitXorAssign for Mask<T, LANES> | |
452 | where | |
453 | T: MaskElement, | |
454 | LaneCount<LANES>: SupportedLaneCount, | |
455 | { | |
456 | #[inline] | |
457 | fn bitxor_assign(&mut self, rhs: Self) { | |
458 | self.0 = self.0 ^ rhs.0; | |
459 | } | |
460 | } | |
461 | ||
462 | impl<T, const LANES: usize> core::ops::BitXorAssign<bool> for Mask<T, LANES> | |
463 | where | |
464 | T: MaskElement, | |
465 | LaneCount<LANES>: SupportedLaneCount, | |
466 | { | |
467 | #[inline] | |
468 | fn bitxor_assign(&mut self, rhs: bool) { | |
469 | *self ^= Self::splat(rhs); | |
470 | } | |
471 | } | |
472 | ||
473 | /// Vector of eight 8-bit masks | |
474 | pub type mask8x8 = Mask<i8, 8>; | |
475 | ||
476 | /// Vector of 16 8-bit masks | |
477 | pub type mask8x16 = Mask<i8, 16>; | |
478 | ||
479 | /// Vector of 32 8-bit masks | |
480 | pub type mask8x32 = Mask<i8, 32>; | |
481 | ||
482 | /// Vector of 16 8-bit masks | |
483 | pub type mask8x64 = Mask<i8, 64>; | |
484 | ||
485 | /// Vector of four 16-bit masks | |
486 | pub type mask16x4 = Mask<i16, 4>; | |
487 | ||
488 | /// Vector of eight 16-bit masks | |
489 | pub type mask16x8 = Mask<i16, 8>; | |
490 | ||
491 | /// Vector of 16 16-bit masks | |
492 | pub type mask16x16 = Mask<i16, 16>; | |
493 | ||
494 | /// Vector of 32 16-bit masks | |
495 | pub type mask16x32 = Mask<i32, 32>; | |
496 | ||
497 | /// Vector of two 32-bit masks | |
498 | pub type mask32x2 = Mask<i32, 2>; | |
499 | ||
500 | /// Vector of four 32-bit masks | |
501 | pub type mask32x4 = Mask<i32, 4>; | |
502 | ||
503 | /// Vector of eight 32-bit masks | |
504 | pub type mask32x8 = Mask<i32, 8>; | |
505 | ||
506 | /// Vector of 16 32-bit masks | |
507 | pub type mask32x16 = Mask<i32, 16>; | |
508 | ||
509 | /// Vector of two 64-bit masks | |
510 | pub type mask64x2 = Mask<i64, 2>; | |
511 | ||
512 | /// Vector of four 64-bit masks | |
513 | pub type mask64x4 = Mask<i64, 4>; | |
514 | ||
515 | /// Vector of eight 64-bit masks | |
516 | pub type mask64x8 = Mask<i64, 8>; | |
517 | ||
518 | /// Vector of two pointer-width masks | |
519 | pub type masksizex2 = Mask<isize, 2>; | |
520 | ||
521 | /// Vector of four pointer-width masks | |
522 | pub type masksizex4 = Mask<isize, 4>; | |
523 | ||
524 | /// Vector of eight pointer-width masks | |
525 | pub type masksizex8 = Mask<isize, 8>; | |
526 | ||
527 | macro_rules! impl_from { | |
528 | { $from:ty => $($to:ty),* } => { | |
529 | $( | |
530 | impl<const LANES: usize> From<Mask<$from, LANES>> for Mask<$to, LANES> | |
531 | where | |
532 | LaneCount<LANES>: SupportedLaneCount, | |
533 | { | |
534 | fn from(value: Mask<$from, LANES>) -> Self { | |
535 | Self(value.0.convert()) | |
536 | } | |
537 | } | |
538 | )* | |
539 | } | |
540 | } | |
541 | impl_from! { i8 => i16, i32, i64, isize } | |
542 | impl_from! { i16 => i32, i64, isize, i8 } | |
543 | impl_from! { i32 => i64, isize, i8, i16 } | |
544 | impl_from! { i64 => isize, i8, i16, i32 } | |
545 | impl_from! { isize => i8, i16, i32, i64 } |