]> git.lizzy.rs Git - rust.git/blob - crates/core_simd/src/masks/bitmask.rs
Merge pull request #61 from rust-lang/feature/masks
[rust.git] / crates / core_simd / src / masks / bitmask.rs
1 /// Implemented for bitmask sizes that are supported by the implementation.
2 pub trait LanesAtMost64 {}
3 impl LanesAtMost64 for BitMask<1> {}
4 impl LanesAtMost64 for BitMask<2> {}
5 impl LanesAtMost64 for BitMask<4> {}
6 impl LanesAtMost64 for BitMask<8> {}
7 impl LanesAtMost64 for BitMask<16> {}
8 impl LanesAtMost64 for BitMask<32> {}
9 impl LanesAtMost64 for BitMask<64> {}
10
11 /// A mask where each lane is represented by a single bit.
12 #[derive(Copy, Clone, Debug)]
13 #[repr(transparent)]
14 pub struct BitMask<const LANES: usize>(u64)
15 where
16     BitMask<LANES>: LanesAtMost64;
17
18 impl<const LANES: usize> BitMask<LANES>
19 where
20     Self: LanesAtMost64,
21 {
22     /// Construct a mask by setting all lanes to the given value.
23     pub fn splat(value: bool) -> Self {
24         if value {
25             Self(u64::MAX)
26         } else {
27             Self(u64::MIN)
28         }
29     }
30
31     /// Tests the value of the specified lane.
32     ///
33     /// # Panics
34     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
35     #[inline]
36     pub fn test(&self, lane: usize) -> bool {
37         assert!(lane < LANES, "lane index out of range");
38         (self.0 >> lane) & 0x1 > 0
39     }
40
41     /// Sets the value of the specified lane.
42     ///
43     /// # Panics
44     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
45     #[inline]
46     pub fn set(&mut self, lane: usize, value: bool) {
47         assert!(lane < LANES, "lane index out of range");
48         self.0 ^= ((value ^ self.test(lane)) as u64) << lane
49     }
50 }
51
52 impl<const LANES: usize> core::ops::BitAnd for BitMask<LANES>
53 where
54     Self: LanesAtMost64,
55 {
56     type Output = Self;
57     #[inline]
58     fn bitand(self, rhs: Self) -> Self {
59         Self(self.0 & rhs.0)
60     }
61 }
62
63 impl<const LANES: usize> core::ops::BitAnd<bool> for BitMask<LANES>
64 where
65     Self: LanesAtMost64,
66 {
67     type Output = Self;
68     #[inline]
69     fn bitand(self, rhs: bool) -> Self {
70         self & Self::splat(rhs)
71     }
72 }
73
74 impl<const LANES: usize> core::ops::BitAnd<BitMask<LANES>> for bool
75 where
76     BitMask<LANES>: LanesAtMost64,
77 {
78     type Output = BitMask<LANES>;
79     #[inline]
80     fn bitand(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
81         BitMask::<LANES>::splat(self) & rhs
82     }
83 }
84
85 impl<const LANES: usize> core::ops::BitOr for BitMask<LANES>
86 where
87     Self: LanesAtMost64,
88 {
89     type Output = Self;
90     #[inline]
91     fn bitor(self, rhs: Self) -> Self {
92         Self(self.0 | rhs.0)
93     }
94 }
95
96 impl<const LANES: usize> core::ops::BitOr<bool> for BitMask<LANES>
97 where
98     Self: LanesAtMost64,
99 {
100     type Output = Self;
101     #[inline]
102     fn bitor(self, rhs: bool) -> Self {
103         self | Self::splat(rhs)
104     }
105 }
106
107 impl<const LANES: usize> core::ops::BitOr<BitMask<LANES>> for bool
108 where
109     BitMask<LANES>: LanesAtMost64,
110 {
111     type Output = BitMask<LANES>;
112     #[inline]
113     fn bitor(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
114         BitMask::<LANES>::splat(self) | rhs
115     }
116 }
117
118 impl<const LANES: usize> core::ops::BitXor for BitMask<LANES>
119 where
120     Self: LanesAtMost64,
121 {
122     type Output = Self;
123     #[inline]
124     fn bitxor(self, rhs: Self) -> Self::Output {
125         Self(self.0 ^ rhs.0)
126     }
127 }
128
129 impl<const LANES: usize> core::ops::BitXor<bool> for BitMask<LANES>
130 where
131     Self: LanesAtMost64,
132 {
133     type Output = Self;
134     #[inline]
135     fn bitxor(self, rhs: bool) -> Self::Output {
136         self ^ Self::splat(rhs)
137     }
138 }
139
140 impl<const LANES: usize> core::ops::BitXor<BitMask<LANES>> for bool
141 where
142     BitMask<LANES>: LanesAtMost64,
143 {
144     type Output = BitMask<LANES>;
145     #[inline]
146     fn bitxor(self, rhs: BitMask<LANES>) -> Self::Output {
147         BitMask::<LANES>::splat(self) ^ rhs
148     }
149 }
150
151 impl<const LANES: usize> core::ops::Not for BitMask<LANES>
152 where
153     Self: LanesAtMost64,
154 {
155     type Output = BitMask<LANES>;
156     #[inline]
157     fn not(self) -> Self::Output {
158         Self(!self.0)
159     }
160 }
161
162 impl<const LANES: usize> core::ops::BitAndAssign for BitMask<LANES>
163 where
164     Self: LanesAtMost64,
165 {
166     #[inline]
167     fn bitand_assign(&mut self, rhs: Self) {
168         self.0 &= rhs.0;
169     }
170 }
171
172 impl<const LANES: usize> core::ops::BitAndAssign<bool> for BitMask<LANES>
173 where
174     Self: LanesAtMost64,
175 {
176     #[inline]
177     fn bitand_assign(&mut self, rhs: bool) {
178         *self &= Self::splat(rhs);
179     }
180 }
181
182 impl<const LANES: usize> core::ops::BitOrAssign for BitMask<LANES>
183 where
184     Self: LanesAtMost64,
185 {
186     #[inline]
187     fn bitor_assign(&mut self, rhs: Self) {
188         self.0 |= rhs.0;
189     }
190 }
191
192 impl<const LANES: usize> core::ops::BitOrAssign<bool> for BitMask<LANES>
193 where
194     Self: LanesAtMost64,
195 {
196     #[inline]
197     fn bitor_assign(&mut self, rhs: bool) {
198         *self |= Self::splat(rhs);
199     }
200 }
201
202 impl<const LANES: usize> core::ops::BitXorAssign for BitMask<LANES>
203 where
204     Self: LanesAtMost64,
205 {
206     #[inline]
207     fn bitxor_assign(&mut self, rhs: Self) {
208         self.0 ^= rhs.0;
209     }
210 }
211
212 impl<const LANES: usize> core::ops::BitXorAssign<bool> for BitMask<LANES>
213 where
214     Self: LanesAtMost64,
215 {
216     #[inline]
217     fn bitxor_assign(&mut self, rhs: bool) {
218         *self ^= Self::splat(rhs);
219     }
220 }