]> git.lizzy.rs Git - rust.git/blob - crates/core_simd/src/masks.rs
Merge pull request #154 from rust-lang/feature/generic-element-type
[rust.git] / crates / core_simd / src / masks.rs
1 //! Types and traits associated with masking lanes of vectors.
2 //! Types representing
3 #![allow(non_camel_case_types)]
4
5 #[cfg_attr(
6     not(all(target_arch = "x86_64", target_feature = "avx512f")),
7     path = "masks/full_masks.rs"
8 )]
9 #[cfg_attr(
10     all(target_arch = "x86_64", target_feature = "avx512f"),
11     path = "masks/bitmask.rs"
12 )]
13 mod mask_impl;
14
15 use crate::{LaneCount, Simd, SimdElement, SupportedLaneCount};
16
17 /// Marker trait for types that may be used as SIMD mask elements.
18 pub unsafe trait MaskElement: SimdElement {
19     #[doc(hidden)]
20     fn valid<const LANES: usize>(values: Simd<Self, LANES>) -> bool
21     where
22         LaneCount<LANES>: SupportedLaneCount;
23
24     #[doc(hidden)]
25     fn eq(self, other: Self) -> bool;
26
27     #[doc(hidden)]
28     const TRUE: Self;
29
30     #[doc(hidden)]
31     const FALSE: Self;
32 }
33
34 macro_rules! impl_element {
35     { $ty:ty } => {
36         unsafe impl MaskElement for $ty {
37             fn valid<const LANES: usize>(value: Simd<Self, LANES>) -> bool
38             where
39                 LaneCount<LANES>: SupportedLaneCount,
40             {
41                 (value.lanes_eq(Simd::splat(0)) | value.lanes_eq(Simd::splat(-1))).all()
42             }
43
44             fn eq(self, other: Self) -> bool { self == other }
45
46             const TRUE: Self = -1;
47             const FALSE: Self = 0;
48         }
49     }
50 }
51
52 impl_element! { i8 }
53 impl_element! { i16 }
54 impl_element! { i32 }
55 impl_element! { i64 }
56 impl_element! { isize }
57
58 /// A SIMD vector mask for `LANES` elements of width specified by `Element`.
59 ///
60 /// The layout of this type is unspecified.
61 #[repr(transparent)]
62 pub struct Mask<T, const LANES: usize>(mask_impl::Mask<T, LANES>)
63 where
64     T: MaskElement,
65     LaneCount<LANES>: SupportedLaneCount;
66
67 impl<T, const LANES: usize> Copy for Mask<T, LANES>
68 where
69     T: MaskElement,
70     LaneCount<LANES>: SupportedLaneCount,
71 {
72 }
73
74 impl<T, const LANES: usize> Clone for Mask<T, LANES>
75 where
76     T: MaskElement,
77     LaneCount<LANES>: SupportedLaneCount,
78 {
79     fn clone(&self) -> Self {
80         *self
81     }
82 }
83
84 impl<T, const LANES: usize> Mask<T, LANES>
85 where
86     T: MaskElement,
87     LaneCount<LANES>: SupportedLaneCount,
88 {
89     /// Construct a mask by setting all lanes to the given value.
90     pub fn splat(value: bool) -> Self {
91         Self(mask_impl::Mask::splat(value))
92     }
93
94     /// Converts an array to a SIMD vector.
95     pub fn from_array(array: [bool; LANES]) -> Self {
96         let mut vector = Self::splat(false);
97         for (i, v) in array.iter().enumerate() {
98             vector.set(i, *v);
99         }
100         vector
101     }
102
103     /// Converts a SIMD vector to an array.
104     pub fn to_array(self) -> [bool; LANES] {
105         let mut array = [false; LANES];
106         for (i, v) in array.iter_mut().enumerate() {
107             *v = self.test(i);
108         }
109         array
110     }
111
112     /// Converts a vector of integers to a mask, where 0 represents `false` and -1
113     /// represents `true`.
114     ///
115     /// # Safety
116     /// All lanes must be either 0 or -1.
117     #[inline]
118     pub unsafe fn from_int_unchecked(value: Simd<T, LANES>) -> Self {
119         Self(mask_impl::Mask::from_int_unchecked(value))
120     }
121
122     /// Converts a vector of integers to a mask, where 0 represents `false` and -1
123     /// represents `true`.
124     ///
125     /// # Panics
126     /// Panics if any lane is not 0 or -1.
127     #[inline]
128     pub fn from_int(value: Simd<T, LANES>) -> Self {
129         assert!(T::valid(value), "all values must be either 0 or -1",);
130         unsafe { Self::from_int_unchecked(value) }
131     }
132
133     /// Converts the mask to a vector of integers, where 0 represents `false` and -1
134     /// represents `true`.
135     #[inline]
136     pub fn to_int(self) -> Simd<T, LANES> {
137         self.0.to_int()
138     }
139
140     /// Tests the value of the specified lane.
141     ///
142     /// # Safety
143     /// `lane` must be less than `LANES`.
144     #[inline]
145     pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
146         self.0.test_unchecked(lane)
147     }
148
149     /// Tests the value of the specified lane.
150     ///
151     /// # Panics
152     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
153     #[inline]
154     pub fn test(&self, lane: usize) -> bool {
155         assert!(lane < LANES, "lane index out of range");
156         unsafe { self.test_unchecked(lane) }
157     }
158
159     /// Sets the value of the specified lane.
160     ///
161     /// # Safety
162     /// `lane` must be less than `LANES`.
163     #[inline]
164     pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
165         self.0.set_unchecked(lane, value);
166     }
167
168     /// Sets the value of the specified lane.
169     ///
170     /// # Panics
171     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
172     #[inline]
173     pub fn set(&mut self, lane: usize, value: bool) {
174         assert!(lane < LANES, "lane index out of range");
175         unsafe {
176             self.set_unchecked(lane, value);
177         }
178     }
179
180     /// Convert this mask to a bitmask, with one bit set per lane.
181     pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
182         self.0.to_bitmask()
183     }
184
185     /// Convert a bitmask to a mask.
186     pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
187         Self(mask_impl::Mask::from_bitmask(bitmask))
188     }
189
190     /// Returns true if any lane is set, or false otherwise.
191     #[inline]
192     pub fn any(self) -> bool {
193         self.0.any()
194     }
195
196     /// Returns true if all lanes are set, or false otherwise.
197     #[inline]
198     pub fn all(self) -> bool {
199         self.0.all()
200     }
201 }
202
203 // vector/array conversion
204 impl<T, const LANES: usize> From<[bool; LANES]> for Mask<T, LANES>
205 where
206     T: MaskElement,
207     LaneCount<LANES>: SupportedLaneCount,
208 {
209     fn from(array: [bool; LANES]) -> Self {
210         Self::from_array(array)
211     }
212 }
213
214 impl<T, const LANES: usize> From<Mask<T, LANES>> for [bool; LANES]
215 where
216     T: MaskElement,
217     LaneCount<LANES>: SupportedLaneCount,
218 {
219     fn from(vector: Mask<T, LANES>) -> Self {
220         vector.to_array()
221     }
222 }
223
224 impl<T, const LANES: usize> Default for Mask<T, LANES>
225 where
226     T: MaskElement,
227     LaneCount<LANES>: SupportedLaneCount,
228 {
229     #[inline]
230     fn default() -> Self {
231         Self::splat(false)
232     }
233 }
234
235 impl<T, const LANES: usize> PartialEq for Mask<T, LANES>
236 where
237     T: MaskElement + PartialEq,
238     LaneCount<LANES>: SupportedLaneCount,
239 {
240     #[inline]
241     fn eq(&self, other: &Self) -> bool {
242         self.0 == other.0
243     }
244 }
245
246 impl<T, const LANES: usize> PartialOrd for Mask<T, LANES>
247 where
248     T: MaskElement + PartialOrd,
249     LaneCount<LANES>: SupportedLaneCount,
250 {
251     #[inline]
252     fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
253         self.0.partial_cmp(&other.0)
254     }
255 }
256
257 impl<T, const LANES: usize> core::fmt::Debug for Mask<T, LANES>
258 where
259     T: MaskElement + core::fmt::Debug,
260     LaneCount<LANES>: SupportedLaneCount,
261 {
262     fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
263         f.debug_list()
264             .entries((0..LANES).map(|lane| self.test(lane)))
265             .finish()
266     }
267 }
268
269 impl<T, const LANES: usize> core::ops::BitAnd for Mask<T, LANES>
270 where
271     T: MaskElement,
272     LaneCount<LANES>: SupportedLaneCount,
273 {
274     type Output = Self;
275     #[inline]
276     fn bitand(self, rhs: Self) -> Self {
277         Self(self.0 & rhs.0)
278     }
279 }
280
281 impl<T, const LANES: usize> core::ops::BitAnd<bool> for Mask<T, LANES>
282 where
283     T: MaskElement,
284     LaneCount<LANES>: SupportedLaneCount,
285 {
286     type Output = Self;
287     #[inline]
288     fn bitand(self, rhs: bool) -> Self {
289         self & Self::splat(rhs)
290     }
291 }
292
293 impl<T, const LANES: usize> core::ops::BitAnd<Mask<T, LANES>> for bool
294 where
295     T: MaskElement,
296     LaneCount<LANES>: SupportedLaneCount,
297 {
298     type Output = Mask<T, LANES>;
299     #[inline]
300     fn bitand(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> {
301         Mask::splat(self) & rhs
302     }
303 }
304
305 impl<T, const LANES: usize> core::ops::BitOr for Mask<T, LANES>
306 where
307     T: MaskElement,
308     LaneCount<LANES>: SupportedLaneCount,
309 {
310     type Output = Self;
311     #[inline]
312     fn bitor(self, rhs: Self) -> Self {
313         Self(self.0 | rhs.0)
314     }
315 }
316
317 impl<T, const LANES: usize> core::ops::BitOr<bool> for Mask<T, LANES>
318 where
319     T: MaskElement,
320     LaneCount<LANES>: SupportedLaneCount,
321 {
322     type Output = Self;
323     #[inline]
324     fn bitor(self, rhs: bool) -> Self {
325         self | Self::splat(rhs)
326     }
327 }
328
329 impl<T, const LANES: usize> core::ops::BitOr<Mask<T, LANES>> for bool
330 where
331     T: MaskElement,
332     LaneCount<LANES>: SupportedLaneCount,
333 {
334     type Output = Mask<T, LANES>;
335     #[inline]
336     fn bitor(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> {
337         Mask::splat(self) | rhs
338     }
339 }
340
341 impl<T, const LANES: usize> core::ops::BitXor for Mask<T, LANES>
342 where
343     T: MaskElement,
344     LaneCount<LANES>: SupportedLaneCount,
345 {
346     type Output = Self;
347     #[inline]
348     fn bitxor(self, rhs: Self) -> Self::Output {
349         Self(self.0 ^ rhs.0)
350     }
351 }
352
353 impl<T, const LANES: usize> core::ops::BitXor<bool> for Mask<T, LANES>
354 where
355     T: MaskElement,
356     LaneCount<LANES>: SupportedLaneCount,
357 {
358     type Output = Self;
359     #[inline]
360     fn bitxor(self, rhs: bool) -> Self::Output {
361         self ^ Self::splat(rhs)
362     }
363 }
364
365 impl<T, const LANES: usize> core::ops::BitXor<Mask<T, LANES>> for bool
366 where
367     T: MaskElement,
368     LaneCount<LANES>: SupportedLaneCount,
369 {
370     type Output = Mask<T, LANES>;
371     #[inline]
372     fn bitxor(self, rhs: Mask<T, LANES>) -> Self::Output {
373         Mask::splat(self) ^ rhs
374     }
375 }
376
377 impl<T, const LANES: usize> core::ops::Not for Mask<T, LANES>
378 where
379     T: MaskElement,
380     LaneCount<LANES>: SupportedLaneCount,
381 {
382     type Output = Mask<T, LANES>;
383     #[inline]
384     fn not(self) -> Self::Output {
385         Self(!self.0)
386     }
387 }
388
389 impl<T, const LANES: usize> core::ops::BitAndAssign for Mask<T, LANES>
390 where
391     T: MaskElement,
392     LaneCount<LANES>: SupportedLaneCount,
393 {
394     #[inline]
395     fn bitand_assign(&mut self, rhs: Self) {
396         self.0 = self.0 & rhs.0;
397     }
398 }
399
400 impl<T, const LANES: usize> core::ops::BitAndAssign<bool> for Mask<T, LANES>
401 where
402     T: MaskElement,
403     LaneCount<LANES>: SupportedLaneCount,
404 {
405     #[inline]
406     fn bitand_assign(&mut self, rhs: bool) {
407         *self &= Self::splat(rhs);
408     }
409 }
410
411 impl<T, const LANES: usize> core::ops::BitOrAssign for Mask<T, LANES>
412 where
413     T: MaskElement,
414     LaneCount<LANES>: SupportedLaneCount,
415 {
416     #[inline]
417     fn bitor_assign(&mut self, rhs: Self) {
418         self.0 = self.0 | rhs.0;
419     }
420 }
421
422 impl<T, const LANES: usize> core::ops::BitOrAssign<bool> for Mask<T, LANES>
423 where
424     T: MaskElement,
425     LaneCount<LANES>: SupportedLaneCount,
426 {
427     #[inline]
428     fn bitor_assign(&mut self, rhs: bool) {
429         *self |= Self::splat(rhs);
430     }
431 }
432
433 impl<T, const LANES: usize> core::ops::BitXorAssign for Mask<T, LANES>
434 where
435     T: MaskElement,
436     LaneCount<LANES>: SupportedLaneCount,
437 {
438     #[inline]
439     fn bitxor_assign(&mut self, rhs: Self) {
440         self.0 = self.0 ^ rhs.0;
441     }
442 }
443
444 impl<T, const LANES: usize> core::ops::BitXorAssign<bool> for Mask<T, LANES>
445 where
446     T: MaskElement,
447     LaneCount<LANES>: SupportedLaneCount,
448 {
449     #[inline]
450     fn bitxor_assign(&mut self, rhs: bool) {
451         *self ^= Self::splat(rhs);
452     }
453 }
454
455 /// Vector of eight 8-bit masks
456 pub type mask8x8 = Mask<i8, 8>;
457
458 /// Vector of 16 8-bit masks
459 pub type mask8x16 = Mask<i8, 16>;
460
461 /// Vector of 32 8-bit masks
462 pub type mask8x32 = Mask<i8, 32>;
463
464 /// Vector of 16 8-bit masks
465 pub type mask8x64 = Mask<i8, 64>;
466
467 /// Vector of four 16-bit masks
468 pub type mask16x4 = Mask<i16, 4>;
469
470 /// Vector of eight 16-bit masks
471 pub type mask16x8 = Mask<i16, 8>;
472
473 /// Vector of 16 16-bit masks
474 pub type mask16x16 = Mask<i16, 16>;
475
476 /// Vector of 32 16-bit masks
477 pub type mask16x32 = Mask<i32, 32>;
478
479 /// Vector of two 32-bit masks
480 pub type mask32x2 = Mask<i32, 2>;
481
482 /// Vector of four 32-bit masks
483 pub type mask32x4 = Mask<i32, 4>;
484
485 /// Vector of eight 32-bit masks
486 pub type mask32x8 = Mask<i32, 8>;
487
488 /// Vector of 16 32-bit masks
489 pub type mask32x16 = Mask<i32, 16>;
490
491 /// Vector of two 64-bit masks
492 pub type mask64x2 = Mask<i64, 2>;
493
494 /// Vector of four 64-bit masks
495 pub type mask64x4 = Mask<i64, 4>;
496
497 /// Vector of eight 64-bit masks
498 pub type mask64x8 = Mask<i64, 8>;
499
500 /// Vector of two pointer-width masks
501 pub type masksizex2 = Mask<isize, 2>;
502
503 /// Vector of four pointer-width masks
504 pub type masksizex4 = Mask<isize, 4>;
505
506 /// Vector of eight pointer-width masks
507 pub type masksizex8 = Mask<isize, 8>;
508
509 macro_rules! impl_from {
510     { $from:ty  => $($to:ty),* } => {
511         $(
512         impl<const LANES: usize> From<Mask<$from, LANES>> for Mask<$to, LANES>
513         where
514             LaneCount<LANES>: SupportedLaneCount,
515         {
516             fn from(value: Mask<$from, LANES>) -> Self {
517                 Self(value.0.convert())
518             }
519         }
520         )*
521     }
522 }
523 impl_from! { i8 => i16, i32, i64, isize }
524 impl_from! { i16 => i32, i64, isize, i8 }
525 impl_from! { i32 => i64, isize, i8, i16 }
526 impl_from! { i64 => isize, i8, i16, i32 }
527 impl_from! { isize => i8, i16, i32, i64 }