]> git.lizzy.rs Git - rust.git/blob - crates/core_simd/src/masks/mod.rs
Add bitmask that supports up to 64 lanes. Simplify mask op API.
[rust.git] / crates / core_simd / src / masks / mod.rs
1 //! Types and traits associated with masking lanes of vectors.
2 #![allow(non_camel_case_types)]
3
4 mod full_masks;
5 pub use full_masks::*;
6
7 mod bitmask;
8 pub use bitmask::*;
9
10 macro_rules! define_opaque_mask {
11     {
12         $(#[$attr:meta])*
13         struct $name:ident<const $lanes:ident: usize>($inner_ty:ty);
14     } => {
15         $(#[$attr])*
16         #[allow(non_camel_case_types)]
17         pub struct $name<const $lanes: usize>($inner_ty) where BitMask<LANES>: LanesAtMost64;
18
19         impl<const $lanes: usize> $name<$lanes> where BitMask<$lanes>: LanesAtMost64 {
20             /// Construct a mask by setting all lanes to the given value.
21             pub fn splat(value: bool) -> Self {
22                 Self(<$inner_ty>::splat(value))
23             }
24
25             /// Tests the value of the specified lane.
26             ///
27             /// # Panics
28             /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
29             #[inline]
30             pub fn test(&self, lane: usize) -> bool {
31                 self.0.test(lane)
32             }
33
34             /// Sets the value of the specified lane.
35             ///
36             /// # Panics
37             /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
38             #[inline]
39             pub fn set(&mut self, lane: usize, value: bool) {
40                 self.0.set(lane, value);
41             }
42         }
43
44         impl<const $lanes: usize> From<BitMask<$lanes>> for $name<$lanes>
45         where
46             BitMask<$lanes>: LanesAtMost64,
47         {
48             fn from(value: BitMask<$lanes>) -> Self {
49                 Self(value.into())
50             }
51         }
52
53         impl<const $lanes: usize> From<$name<$lanes>> for crate::BitMask<$lanes>
54         where
55             BitMask<$lanes>: LanesAtMost64,
56         {
57             fn from(value: $name<$lanes>) -> Self {
58                 value.0.into()
59             }
60         }
61
62         impl<const $lanes: usize> Copy for $name<$lanes> where BitMask<$lanes>: LanesAtMost64 {}
63
64         impl<const $lanes: usize> Clone for $name<$lanes> where BitMask<$lanes>: LanesAtMost64 {
65             #[inline]
66             fn clone(&self) -> Self {
67                 *self
68             }
69         }
70
71         impl<const $lanes: usize> Default for $name<$lanes> where BitMask<$lanes>: LanesAtMost64 {
72             #[inline]
73             fn default() -> Self {
74                 Self::splat(false)
75             }
76         }
77
78         impl<const $lanes: usize> PartialEq for $name<$lanes> where BitMask<$lanes>: LanesAtMost64 {
79             #[inline]
80             fn eq(&self, other: &Self) -> bool {
81                 self.0 == other.0
82             }
83         }
84
85         impl<const $lanes: usize> PartialOrd for $name<$lanes> where BitMask<$lanes>: LanesAtMost64 {
86             #[inline]
87             fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
88                 self.0.partial_cmp(&other.0)
89             }
90         }
91
92         impl<const $lanes: usize> core::fmt::Debug for $name<$lanes> where BitMask<$lanes>: LanesAtMost64 {
93             fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
94                 core::fmt::Debug::fmt(&self.0, f)
95             }
96         }
97
98         impl<const LANES: usize> core::ops::BitAnd for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
99             type Output = Self;
100             #[inline]
101             fn bitand(self, rhs: Self) -> Self {
102                 Self(self.0 & rhs.0)
103             }
104         }
105
106         impl<const LANES: usize> core::ops::BitAnd<bool> for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
107             type Output = Self;
108             #[inline]
109             fn bitand(self, rhs: bool) -> Self {
110                 self & Self::splat(rhs)
111             }
112         }
113
114         impl<const LANES: usize> core::ops::BitAnd<$name<LANES>> for bool where BitMask<LANES>: LanesAtMost64 {
115             type Output = $name<LANES>;
116             #[inline]
117             fn bitand(self, rhs: $name<LANES>) -> $name<LANES> {
118                 $name::<LANES>::splat(self) & rhs
119             }
120         }
121
122         impl<const LANES: usize> core::ops::BitOr for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
123             type Output = Self;
124             #[inline]
125             fn bitor(self, rhs: Self) -> Self {
126                 Self(self.0 | rhs.0)
127             }
128         }
129
130         impl<const LANES: usize> core::ops::BitOr<bool> for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
131             type Output = Self;
132             #[inline]
133             fn bitor(self, rhs: bool) -> Self {
134                 self | Self::splat(rhs)
135             }
136         }
137
138         impl<const LANES: usize> core::ops::BitOr<$name<LANES>> for bool where BitMask<LANES>: LanesAtMost64 {
139             type Output = $name<LANES>;
140             #[inline]
141             fn bitor(self, rhs: $name<LANES>) -> $name<LANES> {
142                 $name::<LANES>::splat(self) | rhs
143             }
144         }
145
146         impl<const LANES: usize> core::ops::BitXor for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
147             type Output = Self;
148             #[inline]
149             fn bitxor(self, rhs: Self) -> Self::Output {
150                 Self(self.0 ^ rhs.0)
151             }
152         }
153
154         impl<const LANES: usize> core::ops::BitXor<bool> for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
155             type Output = Self;
156             #[inline]
157             fn bitxor(self, rhs: bool) -> Self::Output {
158                 self ^ Self::splat(rhs)
159             }
160         }
161
162         impl<const LANES: usize> core::ops::BitXor<$name<LANES>> for bool where BitMask<LANES>: LanesAtMost64 {
163             type Output = $name<LANES>;
164             #[inline]
165             fn bitxor(self, rhs: $name<LANES>) -> Self::Output {
166                 $name::<LANES>::splat(self) ^ rhs
167             }
168         }
169
170         impl<const LANES: usize> core::ops::Not for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
171             type Output = $name<LANES>;
172             #[inline]
173             fn not(self) -> Self::Output {
174                 Self(!self.0)
175             }
176         }
177
178         impl<const LANES: usize> core::ops::BitAndAssign for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
179             #[inline]
180             fn bitand_assign(&mut self, rhs: Self) {
181                 self.0 &= rhs.0;
182             }
183         }
184
185         impl<const LANES: usize> core::ops::BitAndAssign<bool> for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
186             #[inline]
187             fn bitand_assign(&mut self, rhs: bool) {
188                 *self &= Self::splat(rhs);
189             }
190         }
191
192         impl<const LANES: usize> core::ops::BitOrAssign for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
193             #[inline]
194             fn bitor_assign(&mut self, rhs: Self) {
195                 self.0 |= rhs.0;
196             }
197         }
198
199         impl<const LANES: usize> core::ops::BitOrAssign<bool> for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
200             #[inline]
201             fn bitor_assign(&mut self, rhs: bool) {
202                 *self |= Self::splat(rhs);
203             }
204         }
205
206         impl<const LANES: usize> core::ops::BitXorAssign for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
207             #[inline]
208             fn bitxor_assign(&mut self, rhs: Self) {
209                 self.0 ^= rhs.0;
210             }
211         }
212
213         impl<const LANES: usize> core::ops::BitXorAssign<bool> for $name<LANES> where BitMask<LANES>: LanesAtMost64 {
214             #[inline]
215             fn bitxor_assign(&mut self, rhs: bool) {
216                 *self ^= Self::splat(rhs);
217             }
218         }
219     };
220 }
221
222 define_opaque_mask! {
223     /// Mask for vectors with `LANES` 8-bit elements.
224     ///
225     /// The layout of this type is unspecified.
226     struct Mask8<const LANES: usize>(SimdMask8<LANES>);
227 }
228
229 define_opaque_mask! {
230     /// Mask for vectors with `LANES` 16-bit elements.
231     ///
232     /// The layout of this type is unspecified.
233     struct Mask16<const LANES: usize>(SimdMask16<LANES>);
234 }
235
236 define_opaque_mask! {
237     /// Mask for vectors with `LANES` 32-bit elements.
238     ///
239     /// The layout of this type is unspecified.
240     struct Mask32<const LANES: usize>(SimdMask32<LANES>);
241 }
242
243 define_opaque_mask! {
244     /// Mask for vectors with `LANES` 64-bit elements.
245     ///
246     /// The layout of this type is unspecified.
247     struct Mask64<const LANES: usize>(SimdMask64<LANES>);
248 }
249
250 define_opaque_mask! {
251     /// Mask for vectors with `LANES` 128-bit elements.
252     ///
253     /// The layout of this type is unspecified.
254     struct Mask128<const LANES: usize>(SimdMask128<LANES>);
255 }
256
257 define_opaque_mask! {
258     /// Mask for vectors with `LANES` pointer-width elements.
259     ///
260     /// The layout of this type is unspecified.
261     struct MaskSize<const LANES: usize>(SimdMaskSize<LANES>);
262 }
263
264 macro_rules! implement_mask_ops {
265     { $($vector:ident => $mask:ident,)* } => {
266         $(
267             impl<const LANES: usize> crate::$vector<LANES> where BitMask<LANES>: LanesAtMost64 {
268                 /// Test if each lane is equal to the corresponding lane in `other`.
269                 #[inline]
270                 pub fn lanes_eq(&self, other: &Self) -> $mask<LANES> {
271                     unsafe { $mask(crate::intrinsics::simd_eq(self, other)) }
272                 }
273
274                 /// Test if each lane is not equal to the corresponding lane in `other`.
275                 #[inline]
276                 pub fn lanes_ne(&self, other: &Self) -> $mask<LANES> {
277                     unsafe { $mask(crate::intrinsics::simd_ne(self, other)) }
278                 }
279
280                 /// Test if each lane is less than the corresponding lane in `other`.
281                 #[inline]
282                 pub fn lanes_lt(&self, other: &Self) -> $mask<LANES> {
283                     unsafe { $mask(crate::intrinsics::simd_lt(self, other)) }
284                 }
285
286                 /// Test if each lane is greater than the corresponding lane in `other`.
287                 #[inline]
288                 pub fn lanes_gt(&self, other: &Self) -> $mask<LANES> {
289                     unsafe { $mask(crate::intrinsics::simd_gt(self, other)) }
290                 }
291
292                 /// Test if each lane is less than or equal to the corresponding lane in `other`.
293                 #[inline]
294                 pub fn lanes_le(&self, other: &Self) -> $mask<LANES> {
295                     unsafe { $mask(crate::intrinsics::simd_le(self, other)) }
296                 }
297
298                 /// Test if each lane is greater than or equal to the corresponding lane in `other`.
299                 #[inline]
300                 pub fn lanes_ge(&self, other: &Self) -> $mask<LANES> {
301                     unsafe { $mask(crate::intrinsics::simd_ge(self, other)) }
302                 }
303             }
304         )*
305     }
306 }
307
308 implement_mask_ops! {
309     SimdI8 => Mask8,
310     SimdI16 => Mask16,
311     SimdI32 => Mask32,
312     SimdI64 => Mask64,
313     SimdI128 => Mask128,
314     SimdIsize => MaskSize,
315
316     SimdU8 => Mask8,
317     SimdU16 => Mask16,
318     SimdU32 => Mask32,
319     SimdU64 => Mask64,
320     SimdU128 => Mask128,
321     SimdUsize => MaskSize,
322
323     SimdF32 => Mask32,
324     SimdF64 => Mask64,
325 }
326
327 /// Vector of eight 8-bit masks
328 pub type mask8x8 = Mask8<8>;
329
330 /// Vector of 16 8-bit masks
331 pub type mask8x16 = Mask8<16>;
332
333 /// Vector of 32 8-bit masks
334 pub type mask8x32 = Mask8<32>;
335
336 /// Vector of 16 8-bit masks
337 pub type mask8x64 = Mask8<64>;
338
339 /// Vector of four 16-bit masks
340 pub type mask16x4 = Mask16<4>;
341
342 /// Vector of eight 16-bit masks
343 pub type mask16x8 = Mask16<8>;
344
345 /// Vector of 16 16-bit masks
346 pub type mask16x16 = Mask16<16>;
347
348 /// Vector of 32 16-bit masks
349 pub type mask16x32 = Mask32<32>;
350
351 /// Vector of two 32-bit masks
352 pub type mask32x2 = Mask32<2>;
353
354 /// Vector of four 32-bit masks
355 pub type mask32x4 = Mask32<4>;
356
357 /// Vector of eight 32-bit masks
358 pub type mask32x8 = Mask32<8>;
359
360 /// Vector of 16 32-bit masks
361 pub type mask32x16 = Mask32<16>;
362
363 /// Vector of two 64-bit masks
364 pub type mask64x2 = Mask64<2>;
365
366 /// Vector of four 64-bit masks
367 pub type mask64x4 = Mask64<4>;
368
369 /// Vector of eight 64-bit masks
370 pub type mask64x8 = Mask64<8>;
371
372 /// Vector of two 128-bit masks
373 pub type mask128x2 = Mask128<2>;
374
375 /// Vector of four 128-bit masks
376 pub type mask128x4 = Mask128<4>;
377
378 /// Vector of two pointer-width masks
379 pub type masksizex2 = MaskSize<2>;
380
381 /// Vector of four pointer-width masks
382 pub type masksizex4 = MaskSize<4>;
383
384 /// Vector of eight pointer-width masks
385 pub type masksizex8 = MaskSize<8>;