]> git.lizzy.rs Git - rust.git/blob - library/portable-simd/crates/core_simd/src/ops.rs
Rollup merge of #104419 - Ayush1325:test-issue-30490, r=lcnr
[rust.git] / library / portable-simd / crates / core_simd / src / ops.rs
1 use crate::simd::{LaneCount, Simd, SimdElement, SimdPartialEq, SupportedLaneCount};
2 use core::ops::{Add, Mul};
3 use core::ops::{BitAnd, BitOr, BitXor};
4 use core::ops::{Div, Rem, Sub};
5 use core::ops::{Shl, Shr};
6
7 mod assign;
8 mod deref;
9 mod unary;
10
11 impl<I, T, const LANES: usize> core::ops::Index<I> for Simd<T, LANES>
12 where
13     T: SimdElement,
14     LaneCount<LANES>: SupportedLaneCount,
15     I: core::slice::SliceIndex<[T]>,
16 {
17     type Output = I::Output;
18     fn index(&self, index: I) -> &Self::Output {
19         &self.as_array()[index]
20     }
21 }
22
23 impl<I, T, const LANES: usize> core::ops::IndexMut<I> for Simd<T, LANES>
24 where
25     T: SimdElement,
26     LaneCount<LANES>: SupportedLaneCount,
27     I: core::slice::SliceIndex<[T]>,
28 {
29     fn index_mut(&mut self, index: I) -> &mut Self::Output {
30         &mut self.as_mut_array()[index]
31     }
32 }
33
34 macro_rules! unsafe_base {
35     ($lhs:ident, $rhs:ident, {$simd_call:ident}, $($_:tt)*) => {
36         // Safety: $lhs and $rhs are vectors
37         unsafe { $crate::simd::intrinsics::$simd_call($lhs, $rhs) }
38     };
39 }
40
41 /// SAFETY: This macro should not be used for anything except Shl or Shr, and passed the appropriate shift intrinsic.
42 /// It handles performing a bitand in addition to calling the shift operator, so that the result
43 /// is well-defined: LLVM can return a poison value if you shl, lshr, or ashr if `rhs >= <Int>::BITS`
44 /// At worst, this will maybe add another instruction and cycle,
45 /// at best, it may open up more optimization opportunities,
46 /// or simply be elided entirely, especially for SIMD ISAs which default to this.
47 ///
48 // FIXME: Consider implementing this in cg_llvm instead?
49 // cg_clif defaults to this, and scalar MIR shifts also default to wrapping
50 macro_rules! wrap_bitshift {
51     ($lhs:ident, $rhs:ident, {$simd_call:ident}, $int:ident) => {
52         #[allow(clippy::suspicious_arithmetic_impl)]
53         // Safety: $lhs and the bitand result are vectors
54         unsafe {
55             $crate::simd::intrinsics::$simd_call(
56                 $lhs,
57                 $rhs.bitand(Simd::splat(<$int>::BITS as $int - 1)),
58             )
59         }
60     };
61 }
62
63 /// SAFETY: This macro must only be used to impl Div or Rem and given the matching intrinsic.
64 /// It guards against LLVM's UB conditions for integer div or rem using masks and selects,
65 /// thus guaranteeing a Rust value returns instead.
66 ///
67 /// |                  | LLVM | Rust
68 /// | :--------------: | :--- | :----------
69 /// | N {/,%} 0        | UB   | panic!()
70 /// | <$int>::MIN / -1 | UB   | <$int>::MIN
71 /// | <$int>::MIN % -1 | UB   | 0
72 ///
73 macro_rules! int_divrem_guard {
74     (   $lhs:ident,
75         $rhs:ident,
76         {   const PANIC_ZERO: &'static str = $zero:literal;
77             $simd_call:ident
78         },
79         $int:ident ) => {
80         if $rhs.simd_eq(Simd::splat(0 as _)).any() {
81             panic!($zero);
82         } else {
83             // Prevent otherwise-UB overflow on the MIN / -1 case.
84             let rhs = if <$int>::MIN != 0 {
85                 // This should, at worst, optimize to a few branchless logical ops
86                 // Ideally, this entire conditional should evaporate
87                 // Fire LLVM and implement those manually if it doesn't get the hint
88                 ($lhs.simd_eq(Simd::splat(<$int>::MIN))
89                 // type inference can break here, so cut an SInt to size
90                 & $rhs.simd_eq(Simd::splat(-1i64 as _)))
91                 .select(Simd::splat(1 as _), $rhs)
92             } else {
93                 // Nice base case to make it easy to const-fold away the other branch.
94                 $rhs
95             };
96             // Safety: $lhs and rhs are vectors
97             unsafe { $crate::simd::intrinsics::$simd_call($lhs, rhs) }
98         }
99     };
100 }
101
102 macro_rules! for_base_types {
103     (   T = ($($scalar:ident),*);
104         type Lhs = Simd<T, N>;
105         type Rhs = Simd<T, N>;
106         type Output = $out:ty;
107
108         impl $op:ident::$call:ident {
109             $macro_impl:ident $inner:tt
110         }) => {
111             $(
112                 impl<const N: usize> $op<Self> for Simd<$scalar, N>
113                 where
114                     $scalar: SimdElement,
115                     LaneCount<N>: SupportedLaneCount,
116                 {
117                     type Output = $out;
118
119                     #[inline]
120                     #[must_use = "operator returns a new vector without mutating the inputs"]
121                     fn $call(self, rhs: Self) -> Self::Output {
122                         $macro_impl!(self, rhs, $inner, $scalar)
123                     }
124                 })*
125     }
126 }
127
128 // A "TokenTree muncher": takes a set of scalar types `T = {};`
129 // type parameters for the ops it implements, `Op::fn` names,
130 // and a macro that expands into an expr, substituting in an intrinsic.
131 // It passes that to for_base_types, which expands an impl for the types,
132 // using the expanded expr in the function, and recurses with itself.
133 //
134 // tl;dr impls a set of ops::{Traits} for a set of types
135 macro_rules! for_base_ops {
136     (
137         T = $types:tt;
138         type Lhs = Simd<T, N>;
139         type Rhs = Simd<T, N>;
140         type Output = $out:ident;
141         impl $op:ident::$call:ident
142             $inner:tt
143         $($rest:tt)*
144     ) => {
145         for_base_types! {
146             T = $types;
147             type Lhs = Simd<T, N>;
148             type Rhs = Simd<T, N>;
149             type Output = $out;
150             impl $op::$call
151                 $inner
152         }
153         for_base_ops! {
154             T = $types;
155             type Lhs = Simd<T, N>;
156             type Rhs = Simd<T, N>;
157             type Output = $out;
158             $($rest)*
159         }
160     };
161     ($($done:tt)*) => {
162         // Done.
163     }
164 }
165
166 // Integers can always accept add, mul, sub, bitand, bitor, and bitxor.
167 // For all of these operations, simd_* intrinsics apply wrapping logic.
168 for_base_ops! {
169     T = (i8, i16, i32, i64, isize, u8, u16, u32, u64, usize);
170     type Lhs = Simd<T, N>;
171     type Rhs = Simd<T, N>;
172     type Output = Self;
173
174     impl Add::add {
175         unsafe_base { simd_add }
176     }
177
178     impl Mul::mul {
179         unsafe_base { simd_mul }
180     }
181
182     impl Sub::sub {
183         unsafe_base { simd_sub }
184     }
185
186     impl BitAnd::bitand {
187         unsafe_base { simd_and }
188     }
189
190     impl BitOr::bitor {
191         unsafe_base { simd_or }
192     }
193
194     impl BitXor::bitxor {
195         unsafe_base { simd_xor }
196     }
197
198     impl Div::div {
199         int_divrem_guard {
200             const PANIC_ZERO: &'static str = "attempt to divide by zero";
201             simd_div
202         }
203     }
204
205     impl Rem::rem {
206         int_divrem_guard {
207             const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero";
208             simd_rem
209         }
210     }
211
212     // The only question is how to handle shifts >= <Int>::BITS?
213     // Our current solution uses wrapping logic.
214     impl Shl::shl {
215         wrap_bitshift { simd_shl }
216     }
217
218     impl Shr::shr {
219         wrap_bitshift {
220             // This automatically monomorphizes to lshr or ashr, depending,
221             // so it's fine to use it for both UInts and SInts.
222             simd_shr
223         }
224     }
225 }
226
227 // We don't need any special precautions here:
228 // Floats always accept arithmetic ops, but may become NaN.
229 for_base_ops! {
230     T = (f32, f64);
231     type Lhs = Simd<T, N>;
232     type Rhs = Simd<T, N>;
233     type Output = Self;
234
235     impl Add::add {
236         unsafe_base { simd_add }
237     }
238
239     impl Mul::mul {
240         unsafe_base { simd_mul }
241     }
242
243     impl Sub::sub {
244         unsafe_base { simd_sub }
245     }
246
247     impl Div::div {
248         unsafe_base { simd_div }
249     }
250
251     impl Rem::rem {
252         unsafe_base { simd_rem }
253     }
254 }