]> git.lizzy.rs Git - rust.git/blob - crates/core_simd/src/masks/bitmask.rs
Renovate for Edition 2021
[rust.git] / crates / core_simd / src / masks / bitmask.rs
1 use super::MaskElement;
2 use crate::simd::intrinsics;
3 use crate::simd::{LaneCount, Simd, SupportedLaneCount};
4 use core::marker::PhantomData;
5
6 /// A mask where each lane is represented by a single bit.
7 #[repr(transparent)]
8 pub struct Mask<T, const LANES: usize>(
9     <LaneCount<LANES> as SupportedLaneCount>::BitMask,
10     PhantomData<T>,
11 )
12 where
13     T: MaskElement,
14     LaneCount<LANES>: SupportedLaneCount;
15
16 impl<T, const LANES: usize> Copy for Mask<T, LANES>
17 where
18     T: MaskElement,
19     LaneCount<LANES>: SupportedLaneCount,
20 {
21 }
22
23 impl<T, const LANES: usize> Clone for Mask<T, LANES>
24 where
25     T: MaskElement,
26     LaneCount<LANES>: SupportedLaneCount,
27 {
28     fn clone(&self) -> Self {
29         *self
30     }
31 }
32
33 impl<T, const LANES: usize> PartialEq for Mask<T, LANES>
34 where
35     T: MaskElement,
36     LaneCount<LANES>: SupportedLaneCount,
37 {
38     fn eq(&self, other: &Self) -> bool {
39         self.0.as_ref() == other.0.as_ref()
40     }
41 }
42
43 impl<T, const LANES: usize> PartialOrd for Mask<T, LANES>
44 where
45     T: MaskElement,
46     LaneCount<LANES>: SupportedLaneCount,
47 {
48     fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
49         self.0.as_ref().partial_cmp(other.0.as_ref())
50     }
51 }
52
53 impl<T, const LANES: usize> Eq for Mask<T, LANES>
54 where
55     T: MaskElement,
56     LaneCount<LANES>: SupportedLaneCount,
57 {
58 }
59
60 impl<T, const LANES: usize> Ord for Mask<T, LANES>
61 where
62     T: MaskElement,
63     LaneCount<LANES>: SupportedLaneCount,
64 {
65     fn cmp(&self, other: &Self) -> core::cmp::Ordering {
66         self.0.as_ref().cmp(other.0.as_ref())
67     }
68 }
69
70 impl<T, const LANES: usize> Mask<T, LANES>
71 where
72     T: MaskElement,
73     LaneCount<LANES>: SupportedLaneCount,
74 {
75     #[inline]
76     pub fn splat(value: bool) -> Self {
77         let mut mask = <LaneCount<LANES> as SupportedLaneCount>::BitMask::default();
78         if value {
79             mask.as_mut().fill(u8::MAX)
80         } else {
81             mask.as_mut().fill(u8::MIN)
82         }
83         if LANES % 8 > 0 {
84             *mask.as_mut().last_mut().unwrap() &= u8::MAX >> (8 - LANES % 8);
85         }
86         Self(mask, PhantomData)
87     }
88
89     #[inline]
90     pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
91         (self.0.as_ref()[lane / 8] >> (lane % 8)) & 0x1 > 0
92     }
93
94     #[inline]
95     pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
96         unsafe {
97             self.0.as_mut()[lane / 8] ^= ((value ^ self.test_unchecked(lane)) as u8) << (lane % 8)
98         }
99     }
100
101     #[inline]
102     pub fn to_int(self) -> Simd<T, LANES> {
103         unsafe {
104             let mask: <LaneCount<LANES> as SupportedLaneCount>::IntBitMask =
105                 core::mem::transmute_copy(&self);
106             intrinsics::simd_select_bitmask(mask, Simd::splat(T::TRUE), Simd::splat(T::FALSE))
107         }
108     }
109
110     #[inline]
111     pub unsafe fn from_int_unchecked(value: Simd<T, LANES>) -> Self {
112         // TODO remove the transmute when rustc is more flexible
113         assert_eq!(
114             core::mem::size_of::<<LaneCount::<LANES> as SupportedLaneCount>::BitMask>(),
115             core::mem::size_of::<<LaneCount::<LANES> as SupportedLaneCount>::IntBitMask>(),
116         );
117         unsafe {
118             let mask: <LaneCount<LANES> as SupportedLaneCount>::IntBitMask =
119                 intrinsics::simd_bitmask(value);
120             Self(core::mem::transmute_copy(&mask), PhantomData)
121         }
122     }
123
124     #[cfg(feature = "generic_const_exprs")]
125     #[inline]
126     pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
127         // Safety: these are the same type and we are laundering the generic
128         unsafe { core::mem::transmute_copy(&self.0) }
129     }
130
131     #[cfg(feature = "generic_const_exprs")]
132     #[inline]
133     pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
134         // Safety: these are the same type and we are laundering the generic
135         Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
136     }
137
138     #[inline]
139     pub fn convert<U>(self) -> Mask<U, LANES>
140     where
141         U: MaskElement,
142     {
143         unsafe { core::mem::transmute_copy(&self) }
144     }
145
146     #[inline]
147     pub fn any(self) -> bool {
148         self != Self::splat(false)
149     }
150
151     #[inline]
152     pub fn all(self) -> bool {
153         self == Self::splat(true)
154     }
155 }
156
157 impl<T, const LANES: usize> core::ops::BitAnd for Mask<T, LANES>
158 where
159     T: MaskElement,
160     LaneCount<LANES>: SupportedLaneCount,
161     <LaneCount<LANES> as SupportedLaneCount>::BitMask: AsRef<[u8]> + AsMut<[u8]>,
162 {
163     type Output = Self;
164     #[inline]
165     fn bitand(mut self, rhs: Self) -> Self {
166         for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) {
167             *l &= r;
168         }
169         self
170     }
171 }
172
173 impl<T, const LANES: usize> core::ops::BitOr for Mask<T, LANES>
174 where
175     T: MaskElement,
176     LaneCount<LANES>: SupportedLaneCount,
177     <LaneCount<LANES> as SupportedLaneCount>::BitMask: AsRef<[u8]> + AsMut<[u8]>,
178 {
179     type Output = Self;
180     #[inline]
181     fn bitor(mut self, rhs: Self) -> Self {
182         for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) {
183             *l |= r;
184         }
185         self
186     }
187 }
188
189 impl<T, const LANES: usize> core::ops::BitXor for Mask<T, LANES>
190 where
191     T: MaskElement,
192     LaneCount<LANES>: SupportedLaneCount,
193 {
194     type Output = Self;
195     #[inline]
196     fn bitxor(mut self, rhs: Self) -> Self::Output {
197         for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) {
198             *l ^= r;
199         }
200         self
201     }
202 }
203
204 impl<T, const LANES: usize> core::ops::Not for Mask<T, LANES>
205 where
206     T: MaskElement,
207     LaneCount<LANES>: SupportedLaneCount,
208 {
209     type Output = Self;
210     #[inline]
211     fn not(mut self) -> Self::Output {
212         for x in self.0.as_mut() {
213             *x = !*x;
214         }
215         if LANES % 8 > 0 {
216             *self.0.as_mut().last_mut().unwrap() &= u8::MAX >> (8 - LANES % 8);
217         }
218         self
219     }
220 }