]> git.lizzy.rs Git - rust.git/blob - library/portable-simd/crates/core_simd/src/masks.rs
Merge commit '57b3c4b90f4346b3990c1be387c3b3ca7b78412c' into clippyup
[rust.git] / library / portable-simd / crates / core_simd / src / masks.rs
1 //! Types and traits associated with masking lanes of vectors.
2 //! Types representing
3 #![allow(non_camel_case_types)]
4
5 #[cfg_attr(
6     not(all(target_arch = "x86_64", target_feature = "avx512f")),
7     path = "masks/full_masks.rs"
8 )]
9 #[cfg_attr(
10     all(target_arch = "x86_64", target_feature = "avx512f"),
11     path = "masks/bitmask.rs"
12 )]
13 mod mask_impl;
14
15 use crate::simd::intrinsics;
16 use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
17 use core::cmp::Ordering;
18 use core::{fmt, mem};
19
20 mod sealed {
21     use super::*;
22
23     /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits
24     /// from bleeding into the parent bounds.
25     ///
26     /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
27     /// prevent us from ever removing that bound, or from implementing `MaskElement` on
28     /// non-`PartialEq` types in the future.
29     pub trait Sealed {
30         fn valid<const LANES: usize>(values: Simd<Self, LANES>) -> bool
31         where
32             LaneCount<LANES>: SupportedLaneCount,
33             Self: SimdElement;
34
35         fn eq(self, other: Self) -> bool;
36
37         const TRUE: Self;
38
39         const FALSE: Self;
40     }
41 }
42 use sealed::Sealed;
43
44 /// Marker trait for types that may be used as SIMD mask elements.
45 pub unsafe trait MaskElement: SimdElement + Sealed {}
46
47 macro_rules! impl_element {
48     { $ty:ty } => {
49         impl Sealed for $ty {
50             fn valid<const LANES: usize>(value: Simd<Self, LANES>) -> bool
51             where
52                 LaneCount<LANES>: SupportedLaneCount,
53             {
54                 (value.lanes_eq(Simd::splat(0)) | value.lanes_eq(Simd::splat(-1))).all()
55             }
56
57             fn eq(self, other: Self) -> bool { self == other }
58
59             const TRUE: Self = -1;
60             const FALSE: Self = 0;
61         }
62
63         unsafe impl MaskElement for $ty {}
64     }
65 }
66
67 impl_element! { i8 }
68 impl_element! { i16 }
69 impl_element! { i32 }
70 impl_element! { i64 }
71 impl_element! { isize }
72
73 /// A SIMD vector mask for `LANES` elements of width specified by `Element`.
74 ///
75 /// The layout of this type is unspecified.
76 #[repr(transparent)]
77 pub struct Mask<T, const LANES: usize>(mask_impl::Mask<T, LANES>)
78 where
79     T: MaskElement,
80     LaneCount<LANES>: SupportedLaneCount;
81
82 impl<T, const LANES: usize> Copy for Mask<T, LANES>
83 where
84     T: MaskElement,
85     LaneCount<LANES>: SupportedLaneCount,
86 {
87 }
88
89 impl<T, const LANES: usize> Clone for Mask<T, LANES>
90 where
91     T: MaskElement,
92     LaneCount<LANES>: SupportedLaneCount,
93 {
94     fn clone(&self) -> Self {
95         *self
96     }
97 }
98
99 impl<T, const LANES: usize> Mask<T, LANES>
100 where
101     T: MaskElement,
102     LaneCount<LANES>: SupportedLaneCount,
103 {
104     /// Construct a mask by setting all lanes to the given value.
105     pub fn splat(value: bool) -> Self {
106         Self(mask_impl::Mask::splat(value))
107     }
108
109     /// Converts an array of bools to a SIMD mask.
110     pub fn from_array(array: [bool; LANES]) -> Self {
111         // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
112         //     true:    0b_0000_0001
113         //     false:   0b_0000_0000
114         // Thus, an array of bools is also a valid array of bytes: [u8; N]
115         // This would be hypothetically valid as an "in-place" transmute,
116         // but these are "dependently-sized" types, so copy elision it is!
117         unsafe {
118             let bytes: [u8; LANES] = mem::transmute_copy(&array);
119             let bools: Simd<i8, LANES> =
120                 intrinsics::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
121             Mask::from_int_unchecked(intrinsics::simd_cast(bools))
122         }
123     }
124
125     /// Converts a SIMD mask to an array of bools.
126     pub fn to_array(self) -> [bool; LANES] {
127         // This follows mostly the same logic as from_array.
128         // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
129         //     true:    0b_0000_0001
130         //     false:   0b_0000_0000
131         // Thus, an array of bools is also a valid array of bytes: [u8; N]
132         // Since our masks are equal to integers where all bits are set,
133         // we can simply convert them to i8s, and then bitand them by the
134         // bitpattern for Rust's "true" bool.
135         // This would be hypothetically valid as an "in-place" transmute,
136         // but these are "dependently-sized" types, so copy elision it is!
137         unsafe {
138             let mut bytes: Simd<i8, LANES> = intrinsics::simd_cast(self.to_int());
139             bytes &= Simd::splat(1i8);
140             mem::transmute_copy(&bytes)
141         }
142     }
143
144     /// Converts a vector of integers to a mask, where 0 represents `false` and -1
145     /// represents `true`.
146     ///
147     /// # Safety
148     /// All lanes must be either 0 or -1.
149     #[inline]
150     #[must_use = "method returns a new mask and does not mutate the original value"]
151     pub unsafe fn from_int_unchecked(value: Simd<T, LANES>) -> Self {
152         unsafe { Self(mask_impl::Mask::from_int_unchecked(value)) }
153     }
154
155     /// Converts a vector of integers to a mask, where 0 represents `false` and -1
156     /// represents `true`.
157     ///
158     /// # Panics
159     /// Panics if any lane is not 0 or -1.
160     #[inline]
161     #[must_use = "method returns a new mask and does not mutate the original value"]
162     pub fn from_int(value: Simd<T, LANES>) -> Self {
163         assert!(T::valid(value), "all values must be either 0 or -1",);
164         unsafe { Self::from_int_unchecked(value) }
165     }
166
167     /// Converts the mask to a vector of integers, where 0 represents `false` and -1
168     /// represents `true`.
169     #[inline]
170     #[must_use = "method returns a new vector and does not mutate the original value"]
171     pub fn to_int(self) -> Simd<T, LANES> {
172         self.0.to_int()
173     }
174
175     /// Tests the value of the specified lane.
176     ///
177     /// # Safety
178     /// `lane` must be less than `LANES`.
179     #[inline]
180     #[must_use = "method returns a new bool and does not mutate the original value"]
181     pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
182         unsafe { self.0.test_unchecked(lane) }
183     }
184
185     /// Tests the value of the specified lane.
186     ///
187     /// # Panics
188     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
189     #[inline]
190     #[must_use = "method returns a new bool and does not mutate the original value"]
191     pub fn test(&self, lane: usize) -> bool {
192         assert!(lane < LANES, "lane index out of range");
193         unsafe { self.test_unchecked(lane) }
194     }
195
196     /// Sets the value of the specified lane.
197     ///
198     /// # Safety
199     /// `lane` must be less than `LANES`.
200     #[inline]
201     pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
202         unsafe {
203             self.0.set_unchecked(lane, value);
204         }
205     }
206
207     /// Sets the value of the specified lane.
208     ///
209     /// # Panics
210     /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
211     #[inline]
212     pub fn set(&mut self, lane: usize, value: bool) {
213         assert!(lane < LANES, "lane index out of range");
214         unsafe {
215             self.set_unchecked(lane, value);
216         }
217     }
218
219     /// Convert this mask to a bitmask, with one bit set per lane.
220     #[cfg(feature = "generic_const_exprs")]
221     #[inline]
222     #[must_use = "method returns a new array and does not mutate the original value"]
223     pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
224         self.0.to_bitmask()
225     }
226
227     /// Convert a bitmask to a mask.
228     #[cfg(feature = "generic_const_exprs")]
229     #[inline]
230     #[must_use = "method returns a new mask and does not mutate the original value"]
231     pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
232         Self(mask_impl::Mask::from_bitmask(bitmask))
233     }
234
235     /// Returns true if any lane is set, or false otherwise.
236     #[inline]
237     #[must_use = "method returns a new bool and does not mutate the original value"]
238     pub fn any(self) -> bool {
239         self.0.any()
240     }
241
242     /// Returns true if all lanes are set, or false otherwise.
243     #[inline]
244     #[must_use = "method returns a new bool and does not mutate the original value"]
245     pub fn all(self) -> bool {
246         self.0.all()
247     }
248 }
249
250 // vector/array conversion
251 impl<T, const LANES: usize> From<[bool; LANES]> for Mask<T, LANES>
252 where
253     T: MaskElement,
254     LaneCount<LANES>: SupportedLaneCount,
255 {
256     fn from(array: [bool; LANES]) -> Self {
257         Self::from_array(array)
258     }
259 }
260
261 impl<T, const LANES: usize> From<Mask<T, LANES>> for [bool; LANES]
262 where
263     T: MaskElement,
264     LaneCount<LANES>: SupportedLaneCount,
265 {
266     fn from(vector: Mask<T, LANES>) -> Self {
267         vector.to_array()
268     }
269 }
270
271 impl<T, const LANES: usize> Default for Mask<T, LANES>
272 where
273     T: MaskElement,
274     LaneCount<LANES>: SupportedLaneCount,
275 {
276     #[inline]
277     #[must_use = "method returns a defaulted mask with all lanes set to false (0)"]
278     fn default() -> Self {
279         Self::splat(false)
280     }
281 }
282
283 impl<T, const LANES: usize> PartialEq for Mask<T, LANES>
284 where
285     T: MaskElement + PartialEq,
286     LaneCount<LANES>: SupportedLaneCount,
287 {
288     #[inline]
289     #[must_use = "method returns a new bool and does not mutate the original value"]
290     fn eq(&self, other: &Self) -> bool {
291         self.0 == other.0
292     }
293 }
294
295 impl<T, const LANES: usize> PartialOrd for Mask<T, LANES>
296 where
297     T: MaskElement + PartialOrd,
298     LaneCount<LANES>: SupportedLaneCount,
299 {
300     #[inline]
301     #[must_use = "method returns a new Ordering and does not mutate the original value"]
302     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
303         self.0.partial_cmp(&other.0)
304     }
305 }
306
307 impl<T, const LANES: usize> fmt::Debug for Mask<T, LANES>
308 where
309     T: MaskElement + fmt::Debug,
310     LaneCount<LANES>: SupportedLaneCount,
311 {
312     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313         f.debug_list()
314             .entries((0..LANES).map(|lane| self.test(lane)))
315             .finish()
316     }
317 }
318
319 impl<T, const LANES: usize> core::ops::BitAnd for Mask<T, LANES>
320 where
321     T: MaskElement,
322     LaneCount<LANES>: SupportedLaneCount,
323 {
324     type Output = Self;
325     #[inline]
326     #[must_use = "method returns a new mask and does not mutate the original value"]
327     fn bitand(self, rhs: Self) -> Self {
328         Self(self.0 & rhs.0)
329     }
330 }
331
332 impl<T, const LANES: usize> core::ops::BitAnd<bool> for Mask<T, LANES>
333 where
334     T: MaskElement,
335     LaneCount<LANES>: SupportedLaneCount,
336 {
337     type Output = Self;
338     #[inline]
339     #[must_use = "method returns a new mask and does not mutate the original value"]
340     fn bitand(self, rhs: bool) -> Self {
341         self & Self::splat(rhs)
342     }
343 }
344
345 impl<T, const LANES: usize> core::ops::BitAnd<Mask<T, LANES>> for bool
346 where
347     T: MaskElement,
348     LaneCount<LANES>: SupportedLaneCount,
349 {
350     type Output = Mask<T, LANES>;
351     #[inline]
352     #[must_use = "method returns a new mask and does not mutate the original value"]
353     fn bitand(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> {
354         Mask::splat(self) & rhs
355     }
356 }
357
358 impl<T, const LANES: usize> core::ops::BitOr for Mask<T, LANES>
359 where
360     T: MaskElement,
361     LaneCount<LANES>: SupportedLaneCount,
362 {
363     type Output = Self;
364     #[inline]
365     #[must_use = "method returns a new mask and does not mutate the original value"]
366     fn bitor(self, rhs: Self) -> Self {
367         Self(self.0 | rhs.0)
368     }
369 }
370
371 impl<T, const LANES: usize> core::ops::BitOr<bool> for Mask<T, LANES>
372 where
373     T: MaskElement,
374     LaneCount<LANES>: SupportedLaneCount,
375 {
376     type Output = Self;
377     #[inline]
378     #[must_use = "method returns a new mask and does not mutate the original value"]
379     fn bitor(self, rhs: bool) -> Self {
380         self | Self::splat(rhs)
381     }
382 }
383
384 impl<T, const LANES: usize> core::ops::BitOr<Mask<T, LANES>> for bool
385 where
386     T: MaskElement,
387     LaneCount<LANES>: SupportedLaneCount,
388 {
389     type Output = Mask<T, LANES>;
390     #[inline]
391     #[must_use = "method returns a new mask and does not mutate the original value"]
392     fn bitor(self, rhs: Mask<T, LANES>) -> Mask<T, LANES> {
393         Mask::splat(self) | rhs
394     }
395 }
396
397 impl<T, const LANES: usize> core::ops::BitXor for Mask<T, LANES>
398 where
399     T: MaskElement,
400     LaneCount<LANES>: SupportedLaneCount,
401 {
402     type Output = Self;
403     #[inline]
404     #[must_use = "method returns a new mask and does not mutate the original value"]
405     fn bitxor(self, rhs: Self) -> Self::Output {
406         Self(self.0 ^ rhs.0)
407     }
408 }
409
410 impl<T, const LANES: usize> core::ops::BitXor<bool> for Mask<T, LANES>
411 where
412     T: MaskElement,
413     LaneCount<LANES>: SupportedLaneCount,
414 {
415     type Output = Self;
416     #[inline]
417     #[must_use = "method returns a new mask and does not mutate the original value"]
418     fn bitxor(self, rhs: bool) -> Self::Output {
419         self ^ Self::splat(rhs)
420     }
421 }
422
423 impl<T, const LANES: usize> core::ops::BitXor<Mask<T, LANES>> for bool
424 where
425     T: MaskElement,
426     LaneCount<LANES>: SupportedLaneCount,
427 {
428     type Output = Mask<T, LANES>;
429     #[inline]
430     #[must_use = "method returns a new mask and does not mutate the original value"]
431     fn bitxor(self, rhs: Mask<T, LANES>) -> Self::Output {
432         Mask::splat(self) ^ rhs
433     }
434 }
435
436 impl<T, const LANES: usize> core::ops::Not for Mask<T, LANES>
437 where
438     T: MaskElement,
439     LaneCount<LANES>: SupportedLaneCount,
440 {
441     type Output = Mask<T, LANES>;
442     #[inline]
443     #[must_use = "method returns a new mask and does not mutate the original value"]
444     fn not(self) -> Self::Output {
445         Self(!self.0)
446     }
447 }
448
449 impl<T, const LANES: usize> core::ops::BitAndAssign for Mask<T, LANES>
450 where
451     T: MaskElement,
452     LaneCount<LANES>: SupportedLaneCount,
453 {
454     #[inline]
455     fn bitand_assign(&mut self, rhs: Self) {
456         self.0 = self.0 & rhs.0;
457     }
458 }
459
460 impl<T, const LANES: usize> core::ops::BitAndAssign<bool> for Mask<T, LANES>
461 where
462     T: MaskElement,
463     LaneCount<LANES>: SupportedLaneCount,
464 {
465     #[inline]
466     fn bitand_assign(&mut self, rhs: bool) {
467         *self &= Self::splat(rhs);
468     }
469 }
470
471 impl<T, const LANES: usize> core::ops::BitOrAssign for Mask<T, LANES>
472 where
473     T: MaskElement,
474     LaneCount<LANES>: SupportedLaneCount,
475 {
476     #[inline]
477     fn bitor_assign(&mut self, rhs: Self) {
478         self.0 = self.0 | rhs.0;
479     }
480 }
481
482 impl<T, const LANES: usize> core::ops::BitOrAssign<bool> for Mask<T, LANES>
483 where
484     T: MaskElement,
485     LaneCount<LANES>: SupportedLaneCount,
486 {
487     #[inline]
488     fn bitor_assign(&mut self, rhs: bool) {
489         *self |= Self::splat(rhs);
490     }
491 }
492
493 impl<T, const LANES: usize> core::ops::BitXorAssign for Mask<T, LANES>
494 where
495     T: MaskElement,
496     LaneCount<LANES>: SupportedLaneCount,
497 {
498     #[inline]
499     fn bitxor_assign(&mut self, rhs: Self) {
500         self.0 = self.0 ^ rhs.0;
501     }
502 }
503
504 impl<T, const LANES: usize> core::ops::BitXorAssign<bool> for Mask<T, LANES>
505 where
506     T: MaskElement,
507     LaneCount<LANES>: SupportedLaneCount,
508 {
509     #[inline]
510     fn bitxor_assign(&mut self, rhs: bool) {
511         *self ^= Self::splat(rhs);
512     }
513 }
514
515 /// Vector of eight 8-bit masks
516 pub type mask8x8 = Mask<i8, 8>;
517
518 /// Vector of 16 8-bit masks
519 pub type mask8x16 = Mask<i8, 16>;
520
521 /// Vector of 32 8-bit masks
522 pub type mask8x32 = Mask<i8, 32>;
523
524 /// Vector of 16 8-bit masks
525 pub type mask8x64 = Mask<i8, 64>;
526
527 /// Vector of four 16-bit masks
528 pub type mask16x4 = Mask<i16, 4>;
529
530 /// Vector of eight 16-bit masks
531 pub type mask16x8 = Mask<i16, 8>;
532
533 /// Vector of 16 16-bit masks
534 pub type mask16x16 = Mask<i16, 16>;
535
536 /// Vector of 32 16-bit masks
537 pub type mask16x32 = Mask<i16, 32>;
538
539 /// Vector of two 32-bit masks
540 pub type mask32x2 = Mask<i32, 2>;
541
542 /// Vector of four 32-bit masks
543 pub type mask32x4 = Mask<i32, 4>;
544
545 /// Vector of eight 32-bit masks
546 pub type mask32x8 = Mask<i32, 8>;
547
548 /// Vector of 16 32-bit masks
549 pub type mask32x16 = Mask<i32, 16>;
550
551 /// Vector of two 64-bit masks
552 pub type mask64x2 = Mask<i64, 2>;
553
554 /// Vector of four 64-bit masks
555 pub type mask64x4 = Mask<i64, 4>;
556
557 /// Vector of eight 64-bit masks
558 pub type mask64x8 = Mask<i64, 8>;
559
560 /// Vector of two pointer-width masks
561 pub type masksizex2 = Mask<isize, 2>;
562
563 /// Vector of four pointer-width masks
564 pub type masksizex4 = Mask<isize, 4>;
565
566 /// Vector of eight pointer-width masks
567 pub type masksizex8 = Mask<isize, 8>;
568
569 macro_rules! impl_from {
570     { $from:ty  => $($to:ty),* } => {
571         $(
572         impl<const LANES: usize> From<Mask<$from, LANES>> for Mask<$to, LANES>
573         where
574             LaneCount<LANES>: SupportedLaneCount,
575         {
576             fn from(value: Mask<$from, LANES>) -> Self {
577                 Self(value.0.convert())
578             }
579         }
580         )*
581     }
582 }
583 impl_from! { i8 => i16, i32, i64, isize }
584 impl_from! { i16 => i32, i64, isize, i8 }
585 impl_from! { i32 => i64, isize, i8, i16 }
586 impl_from! { i64 => isize, i8, i16, i32 }
587 impl_from! { isize => i8, i16, i32, i64 }