]> git.lizzy.rs Git - rust.git/blob - crates/core_simd/src/masks/full_masks.rs
Add bitmask that supports up to 64 lanes. Simplify mask op API.
[rust.git] / crates / core_simd / src / masks / full_masks.rs
1 //! Masks that take up full SIMD vector registers.
2
3 /// The error type returned when converting an integer to a mask fails.
4 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
5 pub struct TryFromMaskError(());
6
7 impl core::fmt::Display for TryFromMaskError {
8     fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
9         write!(
10             f,
11             "mask vector must have all bits set or unset in each lane"
12         )
13     }
14 }
15
16 macro_rules! define_mask {
17     { $(#[$attr:meta])* struct $name:ident<const $lanes:ident: usize>($type:ty); } => {
18         $(#[$attr])*
19         #[derive(Copy, Clone, Default, PartialEq, PartialOrd, Eq, Ord, Hash)]
20         #[repr(transparent)]
21         pub struct $name<const $lanes: usize>($type);
22
23         impl<const $lanes: usize> $name<$lanes> {
24             /// Construct a mask by setting all lanes to the given value.
25             pub fn splat(value: bool) -> Self {
26                 Self(<$type>::splat(
27                     if value {
28                         -1
29                     } else {
30                         0
31                     }
32                 ))
33             }
34
35             /// Tests the value of the specified lane.
36             ///
37             /// # Panics
38             /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
39             #[inline]
40             pub fn test(&self, lane: usize) -> bool {
41                 assert!(lane < LANES, "lane index out of range");
42                 self.0[lane] == -1
43             }
44
45             /// Sets the value of the specified lane.
46             ///
47             /// # Panics
48             /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
49             #[inline]
50             pub fn set(&mut self, lane: usize, value: bool) {
51                 assert!(lane < LANES, "lane index out of range");
52                 self.0[lane] = if value {
53                     -1
54                 } else {
55                     0
56                 }
57             }
58         }
59
60         impl<const $lanes: usize> core::convert::From<bool> for $name<$lanes> {
61             fn from(value: bool) -> Self {
62                 Self::splat(value)
63             }
64         }
65
66         impl<const $lanes: usize> core::convert::TryFrom<$type> for $name<$lanes> {
67             type Error = TryFromMaskError;
68             fn try_from(value: $type) -> Result<Self, Self::Error> {
69                 if value.as_slice().iter().all(|x| *x == 0 || *x == -1) {
70                     Ok(Self(value))
71                 } else {
72                     Err(TryFromMaskError(()))
73                 }
74             }
75         }
76
77         impl<const $lanes: usize> core::convert::From<$name<$lanes>> for $type {
78             fn from(value: $name<$lanes>) -> Self {
79                 value.0
80             }
81         }
82
83         impl<const $lanes: usize> core::convert::From<crate::BitMask<$lanes>> for $name<$lanes>
84         where
85             crate::BitMask<$lanes>: crate::LanesAtMost64,
86         {
87             fn from(value: crate::BitMask<$lanes>) -> Self {
88                 // TODO use an intrinsic to do this efficiently (with LLVM's sext instruction)
89                 let mut mask = Self::splat(false);
90                 for lane in 0..LANES {
91                     mask.set(lane, value.test(lane));
92                 }
93                 mask
94             }
95         }
96
97         impl<const $lanes: usize> core::convert::From<$name<$lanes>> for crate::BitMask<$lanes>
98         where
99             crate::BitMask<$lanes>: crate::LanesAtMost64,
100         {
101             fn from(value: $name<$lanes>) -> Self {
102                 // TODO use an intrinsic to do this efficiently (with LLVM's trunc instruction)
103                 let mut mask = Self::splat(false);
104                 for lane in 0..LANES {
105                     mask.set(lane, value.test(lane));
106                 }
107                 mask
108             }
109         }
110
111         impl<const $lanes: usize> core::fmt::Debug for $name<$lanes> {
112             fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
113                 f.debug_list()
114                     .entries((0..LANES).map(|lane| self.test(lane)))
115                     .finish()
116             }
117         }
118
119         impl<const $lanes: usize> core::fmt::Binary for $name<$lanes> {
120             fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
121                 core::fmt::Binary::fmt(&self.0, f)
122             }
123         }
124
125         impl<const $lanes: usize> core::fmt::Octal for $name<$lanes> {
126             fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
127                 core::fmt::Octal::fmt(&self.0, f)
128             }
129         }
130
131         impl<const $lanes: usize> core::fmt::LowerHex for $name<$lanes> {
132             fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
133                 core::fmt::LowerHex::fmt(&self.0, f)
134             }
135         }
136
137         impl<const $lanes: usize> core::fmt::UpperHex for $name<$lanes> {
138             fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
139                 core::fmt::UpperHex::fmt(&self.0, f)
140             }
141         }
142
143         impl<const LANES: usize> core::ops::BitAnd for $name<LANES> {
144             type Output = Self;
145             #[inline]
146             fn bitand(self, rhs: Self) -> Self {
147                 Self(self.0 & rhs.0)
148             }
149         }
150
151         impl<const LANES: usize> core::ops::BitAnd<bool> for $name<LANES> {
152             type Output = Self;
153             #[inline]
154             fn bitand(self, rhs: bool) -> Self {
155                 self & Self::splat(rhs)
156             }
157         }
158
159         impl<const LANES: usize> core::ops::BitAnd<$name<LANES>> for bool {
160             type Output = $name<LANES>;
161             #[inline]
162             fn bitand(self, rhs: $name<LANES>) -> $name<LANES> {
163                 $name::<LANES>::splat(self) & rhs
164             }
165         }
166
167         impl<const LANES: usize> core::ops::BitOr for $name<LANES> {
168             type Output = Self;
169             #[inline]
170             fn bitor(self, rhs: Self) -> Self {
171                 Self(self.0 | rhs.0)
172             }
173         }
174
175         impl<const LANES: usize> core::ops::BitOr<bool> for $name<LANES> {
176             type Output = Self;
177             #[inline]
178             fn bitor(self, rhs: bool) -> Self {
179                 self | Self::splat(rhs)
180             }
181         }
182
183         impl<const LANES: usize> core::ops::BitOr<$name<LANES>> for bool {
184             type Output = $name<LANES>;
185             #[inline]
186             fn bitor(self, rhs: $name<LANES>) -> $name<LANES> {
187                 $name::<LANES>::splat(self) | rhs
188             }
189         }
190
191         impl<const LANES: usize> core::ops::BitXor for $name<LANES> {
192             type Output = Self;
193             #[inline]
194             fn bitxor(self, rhs: Self) -> Self::Output {
195                 Self(self.0 ^ rhs.0)
196             }
197         }
198
199         impl<const LANES: usize> core::ops::BitXor<bool> for $name<LANES> {
200             type Output = Self;
201             #[inline]
202             fn bitxor(self, rhs: bool) -> Self::Output {
203                 self ^ Self::splat(rhs)
204             }
205         }
206
207         impl<const LANES: usize> core::ops::BitXor<$name<LANES>> for bool {
208             type Output = $name<LANES>;
209             #[inline]
210             fn bitxor(self, rhs: $name<LANES>) -> Self::Output {
211                 $name::<LANES>::splat(self) ^ rhs
212             }
213         }
214
215         impl<const LANES: usize> core::ops::Not for $name<LANES> {
216             type Output = $name<LANES>;
217             #[inline]
218             fn not(self) -> Self::Output {
219                 Self(!self.0)
220             }
221         }
222
223         impl<const LANES: usize> core::ops::BitAndAssign for $name<LANES> {
224             #[inline]
225             fn bitand_assign(&mut self, rhs: Self) {
226                 self.0 &= rhs.0;
227             }
228         }
229
230         impl<const LANES: usize> core::ops::BitAndAssign<bool> for $name<LANES> {
231             #[inline]
232             fn bitand_assign(&mut self, rhs: bool) {
233                 *self &= Self::splat(rhs);
234             }
235         }
236
237         impl<const LANES: usize> core::ops::BitOrAssign for $name<LANES> {
238             #[inline]
239             fn bitor_assign(&mut self, rhs: Self) {
240                 self.0 |= rhs.0;
241             }
242         }
243
244         impl<const LANES: usize> core::ops::BitOrAssign<bool> for $name<LANES> {
245             #[inline]
246             fn bitor_assign(&mut self, rhs: bool) {
247                 *self |= Self::splat(rhs);
248             }
249         }
250
251         impl<const LANES: usize> core::ops::BitXorAssign for $name<LANES> {
252             #[inline]
253             fn bitxor_assign(&mut self, rhs: Self) {
254                 self.0 ^= rhs.0;
255             }
256         }
257
258         impl<const LANES: usize> core::ops::BitXorAssign<bool> for $name<LANES> {
259             #[inline]
260             fn bitxor_assign(&mut self, rhs: bool) {
261                 *self ^= Self::splat(rhs);
262             }
263         }
264     }
265 }
266
267 define_mask! {
268     /// A mask equivalent to [SimdI8](crate::SimdI8), where all bits in the lane must be either set
269     /// or unset.
270     struct SimdMask8<const LANES: usize>(crate::SimdI8<LANES>);
271 }
272
273 define_mask! {
274     /// A mask equivalent to [SimdI16](crate::SimdI16), where all bits in the lane must be either set
275     /// or unset.
276     struct SimdMask16<const LANES: usize>(crate::SimdI16<LANES>);
277 }
278
279 define_mask! {
280     /// A mask equivalent to [SimdI32](crate::SimdI32), where all bits in the lane must be either set
281     /// or unset.
282     struct SimdMask32<const LANES: usize>(crate::SimdI32<LANES>);
283 }
284
285 define_mask! {
286     /// A mask equivalent to [SimdI64](crate::SimdI64), where all bits in the lane must be either set
287     /// or unset.
288     struct SimdMask64<const LANES: usize>(crate::SimdI64<LANES>);
289 }
290
291 define_mask! {
292     /// A mask equivalent to [SimdI128](crate::SimdI128), where all bits in the lane must be either set
293     /// or unset.
294     struct SimdMask128<const LANES: usize>(crate::SimdI64<LANES>);
295 }
296
297 define_mask! {
298     /// A mask equivalent to [SimdIsize](crate::SimdIsize), where all bits in the lane must be either set
299     /// or unset.
300     struct SimdMaskSize<const LANES: usize>(crate::SimdI64<LANES>);
301 }