]> git.lizzy.rs Git - rust.git/blob - crates/core_simd/src/masks.rs
Convert all masks to a single type
[rust.git] / crates / core_simd / src / masks.rs
1 //! Types and traits associated with masking lanes of vectors.
2 //! Types representing
3 #![allow(non_camel_case_types)]
4
5 #[cfg_attr(
6     not(all(target_arch = "x86_64", target_feature = "avx512f")),
7     path = "masks/full_masks.rs"
8 )]
9 #[cfg_attr(
10     all(target_arch = "x86_64", target_feature = "avx512f"),
11     path = "masks/bitmask.rs"
12 )]
13 mod mask_impl;
14
15 use crate::{LaneCount, Simd, SimdElement, SupportedLaneCount};
16
17 /// Marker trait for types that may be used as SIMD mask elements.
18 pub unsafe trait MaskElement: SimdElement {
19     #[doc(hidden)]
20     fn valid<const LANES: usize>(values: Simd<Self, LANES>) -> bool
21     where
22         LaneCount<LANES>: SupportedLaneCount;
23
24     #[doc(hidden)]
25     fn eq(self, other: Self) -> bool;
26
27     #[doc(hidden)]
28     const TRUE: Self;
29
30     #[doc(hidden)]
31     const FALSE: Self;
32 }
33
34 macro_rules! impl_element {
35     { $ty:ty } => {
36         unsafe impl MaskElement for $ty {
37             fn valid<const LANES: usize>(value: Simd<Self, LANES>) -> bool
38             where
39                 LaneCount<LANES>: SupportedLaneCount,
40             {
41                 (value.lanes_eq(Simd::splat(0)) | value.lanes_eq(Simd::splat(-1))).all()
42             }
43
44             fn eq(self, other: Self) -> bool { self == other }
45
46             const TRUE: Self = -1;
47             const FALSE: Self = 0;
48         }
49     }
50 }
51
52 impl_element! { i8 }
53 impl_element! { i16 }
54 impl_element! { i32 }
55 impl_element! { i64 }
56 impl_element! { isize }
57
58 /// A SIMD vector mask for `LANES` elements of width specified by `Element`.
59 ///
60 /// The layout of this type is unspecified.
61 #[repr(transparent)]
62 pub struct Mask<Element, const LANES: usize>(mask_impl::Mask<Element, LANES>)
63 where
64     Element: MaskElement,
65     LaneCount<LANES>: SupportedLaneCount;
66
67 impl<Element, const LANES: usize> Copy for Mask<Element, LANES>
68 where
69     Element: MaskElement,
70     LaneCount<LANES>: SupportedLaneCount,
71 {
72 }
73
74 impl<Element, const LANES: usize> Clone for Mask<Element, LANES>
75 where
76     Element: MaskElement,
77     LaneCount<LANES>: SupportedLaneCount,
78 {
79     fn clone(&self) -> Self {
80         *self
81     }
82 }
83
84 impl<Element, const LANES: usize> Mask<Element, LANES>
85 where
86     Element: MaskElement,
87     LaneCount<LANES>: SupportedLaneCount,
88 {
89     /// Construct a mask by setting all lanes to the given value.
90     pub fn splat(value: bool) -> Self {
91         Self(mask_impl::Mask::splat(value))
92     }
93
94     /// Converts an array to a SIMD vector.
95     pub fn from_array(array: [bool; LANES]) -> Self {
96         let mut vector = Self::splat(false);
97         for (i, v) in array.iter().enumerate() {
98             vector.set(i, *v);
99         }
100         vector
101     }
102
103     /// Converts a SIMD vector to an array.
104     pub fn to_array(self) -> [bool; LANES] {
105         let mut array = [false; LANES];
106         for (i, v) in array.iter_mut().enumerate() {
107             *v = self.test(i);
108         }
109         array
110     }
111
112     /// Converts a vector of integers to a mask, where 0 represents `false` and -1
113     /// represents `true`.
114     ///
115     /// # Safety
116     /// All lanes must be either 0 or -1.
117     #[inline]
118     pub unsafe fn from_int_unchecked(value: Simd<Element, LANES>) -> Self {
119         Self(mask_impl::Mask::from_int_unchecked(value))
120     }
121
122     /// Converts a vector of integers to a mask, where 0 represents `false` and -1
123     /// represents `true`.
124     ///
125     /// # Panics
126     /// Panics if any lane is not 0 or -1.
127     #[inline]
128     pub fn from_int(value: Simd<Element, LANES>) -> Self {
129         assert!(Element::valid(value), "all values must be either 0 or -1",);
130         unsafe { Self::from_int_unchecked(value) }
131     }
132
133     /// Converts the mask to a vector of integers, where 0 represents `false` and -1
134     /// represents `true`.
135     #[inline]
136     pub fn to_int(self) -> Simd<Element, LANES> {
137         self.0.to_int()
138     }
139
140     /// Tests the value of the specified lane.
141     ///
142     /// # Safety
143     /// `lane` must be less than `LANES`.
144     #[inline]
145     pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
146         self.0.test_unchecked(lane)
147     }
148
149     /// Tests the value of the specified lane.
150     ///
151     /// # Panics
152     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
153     #[inline]
154     pub fn test(&self, lane: usize) -> bool {
155         assert!(lane < LANES, "lane index out of range");
156         unsafe { self.test_unchecked(lane) }
157     }
158
159     /// Sets the value of the specified lane.
160     ///
161     /// # Safety
162     /// `lane` must be less than `LANES`.
163     #[inline]
164     pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
165         self.0.set_unchecked(lane, value);
166     }
167
168     /// Sets the value of the specified lane.
169     ///
170     /// # Panics
171     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
172     #[inline]
173     pub fn set(&mut self, lane: usize, value: bool) {
174         assert!(lane < LANES, "lane index out of range");
175         unsafe {
176             self.set_unchecked(lane, value);
177         }
178     }
179
180     /// Convert this mask to a bitmask, with one bit set per lane.
181     pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
182         self.0.to_bitmask()
183     }
184
185     /// Convert a bitmask to a mask.
186     pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
187         Self(mask_impl::Mask::from_bitmask(bitmask))
188     }
189
190     /// Returns true if any lane is set, or false otherwise.
191     #[inline]
192     pub fn any(self) -> bool {
193         self.0.any()
194     }
195
196     /// Returns true if all lanes are set, or false otherwise.
197     #[inline]
198     pub fn all(self) -> bool {
199         self.0.all()
200     }
201 }
202
203 // vector/array conversion
204 impl<Element, const LANES: usize> From<[bool; LANES]> for Mask<Element, LANES>
205 where
206     Element: MaskElement,
207     LaneCount<LANES>: SupportedLaneCount,
208 {
209     fn from(array: [bool; LANES]) -> Self {
210         Self::from_array(array)
211     }
212 }
213
214 impl<Element, const LANES: usize> From<Mask<Element, LANES>> for [bool; LANES]
215 where
216     Element: MaskElement,
217     LaneCount<LANES>: SupportedLaneCount,
218 {
219     fn from(vector: Mask<Element, LANES>) -> Self {
220         vector.to_array()
221     }
222 }
223
224 impl<Element, const LANES: usize> Default for Mask<Element, LANES>
225 where
226     Element: MaskElement,
227     LaneCount<LANES>: SupportedLaneCount,
228 {
229     #[inline]
230     fn default() -> Self {
231         Self::splat(false)
232     }
233 }
234
235 impl<Element, const LANES: usize> PartialEq for Mask<Element, LANES>
236 where
237     Element: MaskElement + PartialEq,
238     LaneCount<LANES>: SupportedLaneCount,
239 {
240     #[inline]
241     fn eq(&self, other: &Self) -> bool {
242         self.0 == other.0
243     }
244 }
245
246 impl<Element, const LANES: usize> PartialOrd for Mask<Element, LANES>
247 where
248     Element: MaskElement + PartialOrd,
249     LaneCount<LANES>: SupportedLaneCount,
250 {
251     #[inline]
252     fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
253         self.0.partial_cmp(&other.0)
254     }
255 }
256
257 impl<Element, const LANES: usize> core::fmt::Debug for Mask<Element, LANES>
258 where
259     Element: MaskElement + core::fmt::Debug,
260     LaneCount<LANES>: SupportedLaneCount,
261 {
262     fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
263         f.debug_list()
264             .entries((0..LANES).map(|lane| self.test(lane)))
265             .finish()
266     }
267 }
268
269 impl<Element, const LANES: usize> core::ops::BitAnd for Mask<Element, LANES>
270 where
271     Element: MaskElement,
272     LaneCount<LANES>: SupportedLaneCount,
273 {
274     type Output = Self;
275     #[inline]
276     fn bitand(self, rhs: Self) -> Self {
277         Self(self.0 & rhs.0)
278     }
279 }
280
281 impl<Element, const LANES: usize> core::ops::BitAnd<bool> for Mask<Element, LANES>
282 where
283     Element: MaskElement,
284     LaneCount<LANES>: SupportedLaneCount,
285 {
286     type Output = Self;
287     #[inline]
288     fn bitand(self, rhs: bool) -> Self {
289         self & Self::splat(rhs)
290     }
291 }
292
293 impl<Element, const LANES: usize> core::ops::BitAnd<Mask<Element, LANES>> for bool
294 where
295     Element: MaskElement,
296     LaneCount<LANES>: SupportedLaneCount,
297 {
298     type Output = Mask<Element, LANES>;
299     #[inline]
300     fn bitand(self, rhs: Mask<Element, LANES>) -> Mask<Element, LANES> {
301         Mask::splat(self) & rhs
302     }
303 }
304
305 impl<Element, const LANES: usize> core::ops::BitOr for Mask<Element, LANES>
306 where
307     Element: MaskElement,
308     LaneCount<LANES>: SupportedLaneCount,
309 {
310     type Output = Self;
311     #[inline]
312     fn bitor(self, rhs: Self) -> Self {
313         Self(self.0 | rhs.0)
314     }
315 }
316
317 impl<Element, const LANES: usize> core::ops::BitOr<bool> for Mask<Element, LANES>
318 where
319     Element: MaskElement,
320     LaneCount<LANES>: SupportedLaneCount,
321 {
322     type Output = Self;
323     #[inline]
324     fn bitor(self, rhs: bool) -> Self {
325         self | Self::splat(rhs)
326     }
327 }
328
329 impl<Element, const LANES: usize> core::ops::BitOr<Mask<Element, LANES>> for bool
330 where
331     Element: MaskElement,
332     LaneCount<LANES>: SupportedLaneCount,
333 {
334     type Output = Mask<Element, LANES>;
335     #[inline]
336     fn bitor(self, rhs: Mask<Element, LANES>) -> Mask<Element, LANES> {
337         Mask::splat(self) | rhs
338     }
339 }
340
341 impl<Element, const LANES: usize> core::ops::BitXor for Mask<Element, LANES>
342 where
343     Element: MaskElement,
344     LaneCount<LANES>: SupportedLaneCount,
345 {
346     type Output = Self;
347     #[inline]
348     fn bitxor(self, rhs: Self) -> Self::Output {
349         Self(self.0 ^ rhs.0)
350     }
351 }
352
353 impl<Element, const LANES: usize> core::ops::BitXor<bool> for Mask<Element, LANES>
354 where
355     Element: MaskElement,
356     LaneCount<LANES>: SupportedLaneCount,
357 {
358     type Output = Self;
359     #[inline]
360     fn bitxor(self, rhs: bool) -> Self::Output {
361         self ^ Self::splat(rhs)
362     }
363 }
364
365 impl<Element, const LANES: usize> core::ops::BitXor<Mask<Element, LANES>> for bool
366 where
367     Element: MaskElement,
368     LaneCount<LANES>: SupportedLaneCount,
369 {
370     type Output = Mask<Element, LANES>;
371     #[inline]
372     fn bitxor(self, rhs: Mask<Element, LANES>) -> Self::Output {
373         Mask::splat(self) ^ rhs
374     }
375 }
376
377 impl<Element, const LANES: usize> core::ops::Not for Mask<Element, LANES>
378 where
379     Element: MaskElement,
380     LaneCount<LANES>: SupportedLaneCount,
381 {
382     type Output = Mask<Element, LANES>;
383     #[inline]
384     fn not(self) -> Self::Output {
385         Self(!self.0)
386     }
387 }
388
389 impl<Element, const LANES: usize> core::ops::BitAndAssign for Mask<Element, LANES>
390 where
391     Element: MaskElement,
392     LaneCount<LANES>: SupportedLaneCount,
393 {
394     #[inline]
395     fn bitand_assign(&mut self, rhs: Self) {
396         self.0 = self.0 & rhs.0;
397     }
398 }
399
400 impl<Element, const LANES: usize> core::ops::BitAndAssign<bool> for Mask<Element, LANES>
401 where
402     Element: MaskElement,
403     LaneCount<LANES>: SupportedLaneCount,
404 {
405     #[inline]
406     fn bitand_assign(&mut self, rhs: bool) {
407         *self &= Self::splat(rhs);
408     }
409 }
410
411 impl<Element, const LANES: usize> core::ops::BitOrAssign for Mask<Element, LANES>
412 where
413     Element: MaskElement,
414     LaneCount<LANES>: SupportedLaneCount,
415 {
416     #[inline]
417     fn bitor_assign(&mut self, rhs: Self) {
418         self.0 = self.0 | rhs.0;
419     }
420 }
421
422 impl<Element, const LANES: usize> core::ops::BitOrAssign<bool> for Mask<Element, LANES>
423 where
424     Element: MaskElement,
425     LaneCount<LANES>: SupportedLaneCount,
426 {
427     #[inline]
428     fn bitor_assign(&mut self, rhs: bool) {
429         *self |= Self::splat(rhs);
430     }
431 }
432
433 impl<Element, const LANES: usize> core::ops::BitXorAssign for Mask<Element, LANES>
434 where
435     Element: MaskElement,
436     LaneCount<LANES>: SupportedLaneCount,
437 {
438     #[inline]
439     fn bitxor_assign(&mut self, rhs: Self) {
440         self.0 = self.0 ^ rhs.0;
441     }
442 }
443
444 impl<Element, const LANES: usize> core::ops::BitXorAssign<bool> for Mask<Element, LANES>
445 where
446     Element: MaskElement,
447     LaneCount<LANES>: SupportedLaneCount,
448 {
449     #[inline]
450     fn bitxor_assign(&mut self, rhs: bool) {
451         *self ^= Self::splat(rhs);
452     }
453 }
454
455 /// A SIMD mask of `LANES` 8-bit values.
456 pub type Mask8<const LANES: usize> = Mask<i8, LANES>;
457
458 /// A SIMD mask of `LANES` 16-bit values.
459 pub type Mask16<const LANES: usize> = Mask<i16, LANES>;
460
461 /// A SIMD mask of `LANES` 32-bit values.
462 pub type Mask32<const LANES: usize> = Mask<i32, LANES>;
463
464 /// A SIMD mask of `LANES` 64-bit values.
465 pub type Mask64<const LANES: usize> = Mask<i64, LANES>;
466
467 /// A SIMD mask of `LANES` pointer-width values.
468 pub type MaskSize<const LANES: usize> = Mask<isize, LANES>;
469
470 /// Vector of eight 8-bit masks
471 pub type mask8x8 = Mask8<8>;
472
473 /// Vector of 16 8-bit masks
474 pub type mask8x16 = Mask8<16>;
475
476 /// Vector of 32 8-bit masks
477 pub type mask8x32 = Mask8<32>;
478
479 /// Vector of 16 8-bit masks
480 pub type mask8x64 = Mask8<64>;
481
482 /// Vector of four 16-bit masks
483 pub type mask16x4 = Mask16<4>;
484
485 /// Vector of eight 16-bit masks
486 pub type mask16x8 = Mask16<8>;
487
488 /// Vector of 16 16-bit masks
489 pub type mask16x16 = Mask16<16>;
490
491 /// Vector of 32 16-bit masks
492 pub type mask16x32 = Mask32<32>;
493
494 /// Vector of two 32-bit masks
495 pub type mask32x2 = Mask32<2>;
496
497 /// Vector of four 32-bit masks
498 pub type mask32x4 = Mask32<4>;
499
500 /// Vector of eight 32-bit masks
501 pub type mask32x8 = Mask32<8>;
502
503 /// Vector of 16 32-bit masks
504 pub type mask32x16 = Mask32<16>;
505
506 /// Vector of two 64-bit masks
507 pub type mask64x2 = Mask64<2>;
508
509 /// Vector of four 64-bit masks
510 pub type mask64x4 = Mask64<4>;
511
512 /// Vector of eight 64-bit masks
513 pub type mask64x8 = Mask64<8>;
514
515 /// Vector of two pointer-width masks
516 pub type masksizex2 = MaskSize<2>;
517
518 /// Vector of four pointer-width masks
519 pub type masksizex4 = MaskSize<4>;
520
521 /// Vector of eight pointer-width masks
522 pub type masksizex8 = MaskSize<8>;
523
524 macro_rules! impl_from {
525     { $from:ident ($from_inner:ident) => $($to:ident ($to_inner:ident)),* } => {
526         $(
527         impl<const LANES: usize> From<$from<LANES>> for $to<LANES>
528         where
529             crate::LaneCount<LANES>: crate::SupportedLaneCount,
530         {
531             fn from(value: $from<LANES>) -> Self {
532                 Self(value.0.convert())
533             }
534         }
535         )*
536     }
537 }
538 impl_from! { Mask8 (SimdI8) => Mask16 (SimdI16), Mask32 (SimdI32), Mask64 (SimdI64), MaskSize (SimdIsize) }
539 impl_from! { Mask16 (SimdI16) => Mask32 (SimdI32), Mask64 (SimdI64), MaskSize (SimdIsize), Mask8 (SimdI8) }
540 impl_from! { Mask32 (SimdI32) => Mask64 (SimdI64), MaskSize (SimdIsize), Mask8 (SimdI8), Mask16 (SimdI16) }
541 impl_from! { Mask64 (SimdI64) => MaskSize (SimdIsize), Mask8 (SimdI8), Mask16 (SimdI16), Mask32 (SimdI32) }
542 impl_from! { MaskSize (SimdIsize) => Mask8 (SimdI8), Mask16 (SimdI16), Mask32 (SimdI32), Mask64 (SimdI64) }