]> git.lizzy.rs Git - rust.git/blob - crates/core_simd/src/masks.rs
Renovate for Edition 2021
[rust.git] / crates / core_simd / src / masks.rs
1 //! Types and traits associated with masking lanes of vectors.
2 //! Types representing
3 #![allow(non_camel_case_types)]
4
5 #[cfg_attr(
6     not(all(target_arch = "x86_64", target_feature = "avx512f")),
7     path = "masks/full_masks.rs"
8 )]
9 #[cfg_attr(
10     all(target_arch = "x86_64", target_feature = "avx512f"),
11     path = "masks/bitmask.rs"
12 )]
13 mod mask_impl;
14
15 use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
16 use core::cmp::Ordering;
17 use core::fmt;
18
19 /// Marker trait for types that may be used as SIMD mask elements.
20 pub unsafe trait MaskElement: SimdElement {
21     #[doc(hidden)]
22     fn valid<const LANES: usize>(values: Simd<Self, LANES>) -> bool
23     where
24         LaneCount<LANES>: SupportedLaneCount;
25
26     #[doc(hidden)]
27     fn eq(self, other: Self) -> bool;
28
29     #[doc(hidden)]
30     const TRUE: Self;
31
32     #[doc(hidden)]
33     const FALSE: Self;
34 }
35
36 macro_rules! impl_element {
37     { $ty:ty } => {
38         unsafe impl MaskElement for $ty {
39             fn valid<const LANES: usize>(value: Simd<Self, LANES>) -> bool
40             where
41                 LaneCount<LANES>: SupportedLaneCount,
42             {
43                 (value.lanes_eq(Simd::splat(0)) | value.lanes_eq(Simd::splat(-1))).all()
44             }
45
46             fn eq(self, other: Self) -> bool { self == other }
47
48             const TRUE: Self = -1;
49             const FALSE: Self = 0;
50         }
51     }
52 }
53
54 impl_element! { i8 }
55 impl_element! { i16 }
56 impl_element! { i32 }
57 impl_element! { i64 }
58 impl_element! { isize }
59
60 /// A SIMD vector mask for `LANES` elements of width specified by `Element`.
61 ///
62 /// The layout of this type is unspecified.
63 #[repr(transparent)]
64 pub struct Mask<T, const LANES: usize>(mask_impl::Mask<T, LANES>)
65 where
66     T: MaskElement,
67     LaneCount<LANES>: SupportedLaneCount;
68
69 impl<T, const LANES: usize> Copy for Mask<T, LANES>
70 where
71     T: MaskElement,
72     LaneCount<LANES>: SupportedLaneCount,
73 {
74 }
75
76 impl<T, const LANES: usize> Clone for Mask<T, LANES>
77 where
78     T: MaskElement,
79     LaneCount<LANES>: SupportedLaneCount,
80 {
81     fn clone(&self) -> Self {
82         *self
83     }
84 }
85
86 impl<T, const LANES: usize> Mask<T, LANES>
87 where
88     T: MaskElement,
89     LaneCount<LANES>: SupportedLaneCount,
90 {
91     /// Construct a mask by setting all lanes to the given value.
92     pub fn splat(value: bool) -> Self {
93         Self(mask_impl::Mask::splat(value))
94     }
95
96     /// Converts an array to a SIMD vector.
97     pub fn from_array(array: [bool; LANES]) -> Self {
98         let mut vector = Self::splat(false);
99         for (i, v) in array.iter().enumerate() {
100             vector.set(i, *v);
101         }
102         vector
103     }
104
105     /// Converts a SIMD vector to an array.
106     pub fn to_array(self) -> [bool; LANES] {
107         let mut array = [false; LANES];
108         for (i, v) in array.iter_mut().enumerate() {
109             *v = self.test(i);
110         }
111         array
112     }
113
114     /// Converts a vector of integers to a mask, where 0 represents `false` and -1
115     /// represents `true`.
116     ///
117     /// # Safety
118     /// All lanes must be either 0 or -1.
119     #[inline]
120     pub unsafe fn from_int_unchecked(value: Simd<T, LANES>) -> Self {
121         unsafe { Self(mask_impl::Mask::from_int_unchecked(value)) }
122     }
123
124     /// Converts a vector of integers to a mask, where 0 represents `false` and -1
125     /// represents `true`.
126     ///
127     /// # Panics
128     /// Panics if any lane is not 0 or -1.
129     #[inline]
130     pub fn from_int(value: Simd<T, LANES>) -> Self {
131         assert!(T::valid(value), "all values must be either 0 or -1",);
132         unsafe { Self::from_int_unchecked(value) }
133     }
134
135     /// Converts the mask to a vector of integers, where 0 represents `false` and -1
136     /// represents `true`.
137     #[inline]
138     pub fn to_int(self) -> Simd<T, LANES> {
139         self.0.to_int()
140     }
141
142     /// Tests the value of the specified lane.
143     ///
144     /// # Safety
145     /// `lane` must be less than `LANES`.
146     #[inline]
147     pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
148         unsafe { self.0.test_unchecked(lane) }
149     }
150
151     /// Tests the value of the specified lane.
152     ///
153     /// # Panics
154     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
155     #[inline]
156     pub fn test(&self, lane: usize) -> bool {
157         assert!(lane < LANES, "lane index out of range");
158         unsafe { self.test_unchecked(lane) }
159     }
160
161     /// Sets the value of the specified lane.
162     ///
163     /// # Safety
164     /// `lane` must be less than `LANES`.
165     #[inline]
166     pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
167         unsafe {
168             self.0.set_unchecked(lane, value);
169         }
170     }
171
172     /// Sets the value of the specified lane.
173     ///
174     /// # Panics
175     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
176     #[inline]
177     pub fn set(&mut self, lane: usize, value: bool) {
178         assert!(lane < LANES, "lane index out of range");
179         unsafe {
180             self.set_unchecked(lane, value);
181         }
182     }
183
184     /// Convert this mask to a bitmask, with one bit set per lane.
185     #[cfg(feature = "generic_const_exprs")]
186     pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
187         self.0.to_bitmask()
188     }
189
190     /// Convert a bitmask to a mask.
191     #[cfg(feature = "generic_const_exprs")]
192     pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
193         Self(mask_impl::Mask::from_bitmask(bitmask))
194     }
195
196     /// Returns true if any lane is set, or false otherwise.
197     #[inline]
198     pub fn any(self) -> bool {
199         self.0.any()
200     }
201
202     /// Returns true if all lanes are set, or false otherwise.
203     #[inline]
204     pub fn all(self) -> bool {
205         self.0.all()
206     }
207 }
208
209 // vector/array conversion
210 impl<T, const LANES: usize> From<[bool; LANES]> for Mask<T, LANES>
211 where
212     T: MaskElement,
213     LaneCount<LANES>: SupportedLaneCount,
214 {
215     fn from(array: [bool; LANES]) -> Self {
216         Self::from_array(array)
217     }
218 }
219
220 impl<T, const LANES: usize> From<Mask<T, LANES>> for [bool; LANES]
221 where
222     T: MaskElement,
223     LaneCount<LANES>: SupportedLaneCount,
224 {
225     fn from(vector: Mask<T, LANES>) -> Self {
226         vector.to_array()
227     }
228 }
229
230 impl<T, const LANES: usize> Default for Mask<T, LANES>
231 where
232     T: MaskElement,
233     LaneCount<LANES>: SupportedLaneCount,
234 {
235     #[inline]
236     fn default() -> Self {
237         Self::splat(false)
238     }
239 }
240
241 impl<T, const LANES: usize> PartialEq for Mask<T, LANES>
242 where
243     T: MaskElement + PartialEq,
244     LaneCount<LANES>: SupportedLaneCount,
245 {
246     #[inline]
247     fn eq(&self, other: &Self) -> bool {
248         self.0 == other.0
249     }
250 }
251
252 impl<T, const LANES: usize> PartialOrd for Mask<T, LANES>
253 where
254     T: MaskElement + PartialOrd,
255     LaneCount<LANES>: SupportedLaneCount,
256 {
257     #[inline]
258     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
259         self.0.partial_cmp(&other.0)
260     }
261 }
262
263 impl<T, const LANES: usize> fmt::Debug for Mask<T, LANES>
264 where
265     T: MaskElement + fmt::Debug,
266     LaneCount<LANES>: SupportedLaneCount,
267 {
268     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269         f.debug_list()
270             .entries((0..LANES).map(|lane| self.test(lane)))
271             .finish()
272     }
273 }
274
275 impl<T, const LANES: usize> core::ops::BitAnd for Mask<T, LANES>
276 where
277     T: MaskElement,
278     LaneCount<LANES>: SupportedLaneCount,
279 {
280     type Output = Self;
281     #[inline]
282     fn bitand(self, rhs: Self) -> Self {
283         Self(self.0 & rhs.0)
284     }
285 }
286
287 impl<T, const LANES: usize> core::ops::BitAnd<bool> for Mask<T, LANES>
288 where
289     T: MaskElement,
290     LaneCount<LANES>: SupportedLaneCount,
291 {
292     type Output = Self;
293     #[inline]
294     fn bitand(self, rhs: bool) -> Self {
295         self & Self::splat(rhs)
296     }
297 }
298
299 impl<T, const LANES: usize> core::ops::BitAnd<Mask<T, LANES>> for bool
300 where
301     T: MaskElement,
302     LaneCount<LANES>: SupportedLaneCount,
303 {
304     type Output = Mask<T, LANES>;
305     #[inline]
306     fn bitand(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> {
307         Mask::splat(self) & rhs
308     }
309 }
310
311 impl<T, const LANES: usize> core::ops::BitOr for Mask<T, LANES>
312 where
313     T: MaskElement,
314     LaneCount<LANES>: SupportedLaneCount,
315 {
316     type Output = Self;
317     #[inline]
318     fn bitor(self, rhs: Self) -> Self {
319         Self(self.0 | rhs.0)
320     }
321 }
322
323 impl<T, const LANES: usize> core::ops::BitOr<bool> for Mask<T, LANES>
324 where
325     T: MaskElement,
326     LaneCount<LANES>: SupportedLaneCount,
327 {
328     type Output = Self;
329     #[inline]
330     fn bitor(self, rhs: bool) -> Self {
331         self | Self::splat(rhs)
332     }
333 }
334
335 impl<T, const LANES: usize> core::ops::BitOr<Mask<T, LANES>> for bool
336 where
337     T: MaskElement,
338     LaneCount<LANES>: SupportedLaneCount,
339 {
340     type Output = Mask<T, LANES>;
341     #[inline]
342     fn bitor(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> {
343         Mask::splat(self) | rhs
344     }
345 }
346
347 impl<T, const LANES: usize> core::ops::BitXor for Mask<T, LANES>
348 where
349     T: MaskElement,
350     LaneCount<LANES>: SupportedLaneCount,
351 {
352     type Output = Self;
353     #[inline]
354     fn bitxor(self, rhs: Self) -> Self::Output {
355         Self(self.0 ^ rhs.0)
356     }
357 }
358
359 impl<T, const LANES: usize> core::ops::BitXor<bool> for Mask<T, LANES>
360 where
361     T: MaskElement,
362     LaneCount<LANES>: SupportedLaneCount,
363 {
364     type Output = Self;
365     #[inline]
366     fn bitxor(self, rhs: bool) -> Self::Output {
367         self ^ Self::splat(rhs)
368     }
369 }
370
371 impl<T, const LANES: usize> core::ops::BitXor<Mask<T, LANES>> for bool
372 where
373     T: MaskElement,
374     LaneCount<LANES>: SupportedLaneCount,
375 {
376     type Output = Mask<T, LANES>;
377     #[inline]
378     fn bitxor(self, rhs: Mask<T, LANES>) -> Self::Output {
379         Mask::splat(self) ^ rhs
380     }
381 }
382
383 impl<T, const LANES: usize> core::ops::Not for Mask<T, LANES>
384 where
385     T: MaskElement,
386     LaneCount<LANES>: SupportedLaneCount,
387 {
388     type Output = Mask<T, LANES>;
389     #[inline]
390     fn not(self) -> Self::Output {
391         Self(!self.0)
392     }
393 }
394
395 impl<T, const LANES: usize> core::ops::BitAndAssign for Mask<T, LANES>
396 where
397     T: MaskElement,
398     LaneCount<LANES>: SupportedLaneCount,
399 {
400     #[inline]
401     fn bitand_assign(&mut self, rhs: Self) {
402         self.0 = self.0 & rhs.0;
403     }
404 }
405
406 impl<T, const LANES: usize> core::ops::BitAndAssign<bool> for Mask<T, LANES>
407 where
408     T: MaskElement,
409     LaneCount<LANES>: SupportedLaneCount,
410 {
411     #[inline]
412     fn bitand_assign(&mut self, rhs: bool) {
413         *self &= Self::splat(rhs);
414     }
415 }
416
417 impl<T, const LANES: usize> core::ops::BitOrAssign for Mask<T, LANES>
418 where
419     T: MaskElement,
420     LaneCount<LANES>: SupportedLaneCount,
421 {
422     #[inline]
423     fn bitor_assign(&mut self, rhs: Self) {
424         self.0 = self.0 | rhs.0;
425     }
426 }
427
428 impl<T, const LANES: usize> core::ops::BitOrAssign<bool> for Mask<T, LANES>
429 where
430     T: MaskElement,
431     LaneCount<LANES>: SupportedLaneCount,
432 {
433     #[inline]
434     fn bitor_assign(&mut self, rhs: bool) {
435         *self |= Self::splat(rhs);
436     }
437 }
438
439 impl<T, const LANES: usize> core::ops::BitXorAssign for Mask<T, LANES>
440 where
441     T: MaskElement,
442     LaneCount<LANES>: SupportedLaneCount,
443 {
444     #[inline]
445     fn bitxor_assign(&mut self, rhs: Self) {
446         self.0 = self.0 ^ rhs.0;
447     }
448 }
449
450 impl<T, const LANES: usize> core::ops::BitXorAssign<bool> for Mask<T, LANES>
451 where
452     T: MaskElement,
453     LaneCount<LANES>: SupportedLaneCount,
454 {
455     #[inline]
456     fn bitxor_assign(&mut self, rhs: bool) {
457         *self ^= Self::splat(rhs);
458     }
459 }
460
461 /// Vector of eight 8-bit masks
462 pub type mask8x8 = Mask<i8, 8>;
463
464 /// Vector of 16 8-bit masks
465 pub type mask8x16 = Mask<i8, 16>;
466
467 /// Vector of 32 8-bit masks
468 pub type mask8x32 = Mask<i8, 32>;
469
470 /// Vector of 16 8-bit masks
471 pub type mask8x64 = Mask<i8, 64>;
472
473 /// Vector of four 16-bit masks
474 pub type mask16x4 = Mask<i16, 4>;
475
476 /// Vector of eight 16-bit masks
477 pub type mask16x8 = Mask<i16, 8>;
478
479 /// Vector of 16 16-bit masks
480 pub type mask16x16 = Mask<i16, 16>;
481
482 /// Vector of 32 16-bit masks
483 pub type mask16x32 = Mask<i32, 32>;
484
485 /// Vector of two 32-bit masks
486 pub type mask32x2 = Mask<i32, 2>;
487
488 /// Vector of four 32-bit masks
489 pub type mask32x4 = Mask<i32, 4>;
490
491 /// Vector of eight 32-bit masks
492 pub type mask32x8 = Mask<i32, 8>;
493
494 /// Vector of 16 32-bit masks
495 pub type mask32x16 = Mask<i32, 16>;
496
497 /// Vector of two 64-bit masks
498 pub type mask64x2 = Mask<i64, 2>;
499
500 /// Vector of four 64-bit masks
501 pub type mask64x4 = Mask<i64, 4>;
502
503 /// Vector of eight 64-bit masks
504 pub type mask64x8 = Mask<i64, 8>;
505
506 /// Vector of two pointer-width masks
507 pub type masksizex2 = Mask<isize, 2>;
508
509 /// Vector of four pointer-width masks
510 pub type masksizex4 = Mask<isize, 4>;
511
512 /// Vector of eight pointer-width masks
513 pub type masksizex8 = Mask<isize, 8>;
514
515 macro_rules! impl_from {
516     { $from:ty  => $($to:ty),* } => {
517         $(
518         impl<const LANES: usize> From<Mask<$from, LANES>> for Mask<$to, LANES>
519         where
520             LaneCount<LANES>: SupportedLaneCount,
521         {
522             fn from(value: Mask<$from, LANES>) -> Self {
523                 Self(value.0.convert())
524             }
525         }
526         )*
527     }
528 }
529 impl_from! { i8 => i16, i32, i64, isize }
530 impl_from! { i16 => i32, i64, isize, i8 }
531 impl_from! { i32 => i64, isize, i8, i16 }
532 impl_from! { i64 => isize, i8, i16, i32 }
533 impl_from! { isize => i8, i16, i32, i64 }