]> git.lizzy.rs Git - rust.git/blob - library/portable-simd/crates/core_simd/src/ops.rs
Merge commit 'a8385522ade6f67853edac730b5bf164ddb298fd' into simd-remove-autosplats
[rust.git] / library / portable-simd / crates / core_simd / src / ops.rs
1 use crate::simd::intrinsics;
2 use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
3 use core::ops::{Add, Mul};
4 use core::ops::{BitAnd, BitOr, BitXor};
5 use core::ops::{Div, Rem, Sub};
6 use core::ops::{Shl, Shr};
7
8 mod assign;
9 mod deref;
10 mod unary;
11
12 impl<I, T, const LANES: usize> core::ops::Index<I> for Simd<T, LANES>
13 where
14     T: SimdElement,
15     LaneCount<LANES>: SupportedLaneCount,
16     I: core::slice::SliceIndex<[T]>,
17 {
18     type Output = I::Output;
19     fn index(&self, index: I) -> &Self::Output {
20         &self.as_array()[index]
21     }
22 }
23
24 impl<I, T, const LANES: usize> core::ops::IndexMut<I> for Simd<T, LANES>
25 where
26     T: SimdElement,
27     LaneCount<LANES>: SupportedLaneCount,
28     I: core::slice::SliceIndex<[T]>,
29 {
30     fn index_mut(&mut self, index: I) -> &mut Self::Output {
31         &mut self.as_mut_array()[index]
32     }
33 }
34
/// Checks if the right-hand side argument of a left- or right-shift would cause overflow.
///
/// Returns `true` when `rhs` compares below `T::default()` (zero for the primitive
/// integers, i.e. a negative shift amount) or when it is greater than or equal to the
/// bit width of `T`, computed as `8 * size_of::<T>()`.
///
/// # Panics
///
/// Panics if the bit width of `T` cannot be converted into a `T`.
fn invalid_shift_rhs<T>(rhs: T) -> bool
where
    T: Default + PartialOrd + core::convert::TryFrom<usize>,
    <T as core::convert::TryFrom<usize>>::Error: core::fmt::Debug,
{
    let bit_width = 8 * core::mem::size_of::<T>();
    // The smallest shift amount that is already out of range.
    let first_invalid =
        T::try_from(bit_width).expect("bit width of `T` must be representable as a `T`");
    rhs < T::default() || rhs >= first_invalid
}
44
/// Forwards a single binary-operator `impl` block through unchanged.
///
/// The only arm defined here matches an
/// `impl<const N: usize> core::ops::Trait<Rhs> for Type` block that has one
/// `LaneCount<N>: SupportedLaneCount` bound, an `Output` type, optional attributes and a
/// single method, and re-emits exactly that impl.
///
/// Despite the macro's name, no by-reference (`&Simd`) operator impls are generated by
/// this arm. NOTE(review): those are presumably provided elsewhere (e.g. by the `deref`
/// submodule) — confirm before relying on it.
macro_rules! impl_ref_ops {
    // binary op
    {
        impl<const $lanes:ident: usize> core::ops::$trait:ident<$rhs:ty> for $type:ty
        where
            LaneCount<$lanes2:ident>: SupportedLaneCount,
        {
            type Output = $output:ty;

            $(#[$attrs:meta])*
            fn $fn:ident($self_tok:ident, $rhs_arg:ident: $rhs_arg_ty:ty) -> Self::Output $body:tt
        }
    } => {
        impl<const $lanes: usize> core::ops::$trait<$rhs> for $type
        where
            LaneCount<$lanes2>: SupportedLaneCount,
        {
            type Output = $output;

            $(#[$attrs])*
            fn $fn($self_tok, $rhs_arg: $rhs_arg_ty) -> Self::Output $body
        }
    };
}
70
/// Generates the lane-wise `core::ops` impl named in the invocation for `Simd<T, LANES>`.
///
/// Each public arm (`impl Add for T`, `impl Sub for T`, ...) selects the trait, its method
/// and the function in `intrinsics` to call, then hands them to the internal `@binary` arm.
macro_rules! impl_op {
    { impl Add for $elem:ty } => {
        impl_op! { @binary $elem, Add::add, simd_add }
    };
    { impl Sub for $elem:ty } => {
        impl_op! { @binary $elem, Sub::sub, simd_sub }
    };
    { impl Mul for $elem:ty } => {
        impl_op! { @binary $elem, Mul::mul, simd_mul }
    };
    { impl Div for $elem:ty } => {
        impl_op! { @binary $elem, Div::div, simd_div }
    };
    { impl Rem for $elem:ty } => {
        impl_op! { @binary $elem, Rem::rem, simd_rem }
    };
    { impl Shl for $elem:ty } => {
        impl_op! { @binary $elem, Shl::shl, simd_shl }
    };
    { impl Shr for $elem:ty } => {
        impl_op! { @binary $elem, Shr::shr, simd_shr }
    };
    { impl BitAnd for $elem:ty } => {
        impl_op! { @binary $elem, BitAnd::bitand, simd_and }
    };
    { impl BitOr for $elem:ty } => {
        impl_op! { @binary $elem, BitOr::bitor, simd_or }
    };
    { impl BitXor for $elem:ty } => {
        impl_op! { @binary $elem, BitXor::bitxor, simd_xor }
    };

    // Internal arm: `Trait<Self>` with `Output = Self`, forwarding both operands
    // straight to the chosen intrinsic with no checks in between.
    { @binary $elem:ty, $op:ident :: $method:ident, $simd_fn:ident } => {
        impl_ref_ops! {
            impl<const LANES: usize> core::ops::$op<Self> for Simd<$elem, LANES>
            where
                LaneCount<LANES>: SupportedLaneCount,
            {
                type Output = Self;

                #[inline]
                fn $method(self, rhs: Self) -> Self::Output {
                    unsafe { intrinsics::$simd_fn(self, rhs) }
                }
            }
        }
    };
}
123
/// Generates the arithmetic operator impls for floating-point element types.
///
/// For every listed type this expands `impl_op!` once each for `Add`, `Sub`, `Mul`,
/// `Div` and `Rem`.
macro_rules! impl_float_ops {
    { $($float:ty),* } => {
        $(
            impl_op! { impl Add for $float }
            impl_op! { impl Sub for $float }
            impl_op! { impl Mul for $float }
            impl_op! { impl Div for $float }
            impl_op! { impl Rem for $float }
        )*
    };
}
136
/// Implements integer operators for the provided types.
///
/// Despite its name, this macro is also instantiated for signed element types through
/// `impl_signed_int_ops!`; the `MIN` / `-1` guards below only take effect when
/// `<$scalar>::MIN != 0`, i.e. for signed types.
///
/// `Add`, `Sub`, `Mul`, `BitAnd`, `BitOr` and `BitXor` are generated via `impl_op!`.
/// `Div`, `Rem`, `Shl` and `Shr` are written out here because each one inspects the
/// operands lane by lane and panics before calling the intrinsic:
///
/// * `Div` / `Rem` panic if any lane of `rhs` is zero, or if any lane pair is
///   `MIN` and `-1`.
/// * `Shl` / `Shr` panic if `invalid_shift_rhs` rejects any lane of `rhs`.
macro_rules! impl_unsigned_int_ops {
    { $($scalar:ty),* } => {
        $(
            impl_op! { impl Add for $scalar }
            impl_op! { impl Sub for $scalar }
            impl_op! { impl Mul for $scalar }
            impl_op! { impl BitAnd for $scalar }
            impl_op! { impl BitOr  for $scalar }
            impl_op! { impl BitXor for $scalar }

            // Integers panic on divide by 0
            impl_ref_ops! {
                impl<const LANES: usize> core::ops::Div<Self> for Simd<$scalar, LANES>
                where
                    LaneCount<LANES>: SupportedLaneCount,
                {
                    type Output = Self;

                    #[inline]
                    fn div(self, rhs: Self) -> Self::Output {
                        if rhs.as_array().contains(&0) {
                            panic!("attempt to divide by zero");
                        }

                        // Guards for div(MIN, -1),
                        // this check only applies to signed ints
                        if <$scalar>::MIN != 0 && self.as_array().iter()
                                .zip(rhs.as_array().iter())
                                .any(|(x,y)| *x == <$scalar>::MIN && *y == -1 as _) {
                            panic!("attempt to divide with overflow");
                        }
                        // SAFETY: the guards above have already rejected zero divisors
                        // and `MIN / -1` in every lane.
                        unsafe { intrinsics::simd_div(self, rhs) }
                    }
                }
            }

            // remainder panics on zero divisor
            impl_ref_ops! {
                impl<const LANES: usize> core::ops::Rem<Self> for Simd<$scalar, LANES>
                where
                    LaneCount<LANES>: SupportedLaneCount,
                {
                    type Output = Self;

                    #[inline]
                    fn rem(self, rhs: Self) -> Self::Output {
                        if rhs.as_array().contains(&0) {
                            panic!("attempt to calculate the remainder with a divisor of zero");
                        }

                        // Guards for rem(MIN, -1)
                        // this branch applies the check only to signed ints
                        if <$scalar>::MIN != 0 && self.as_array().iter()
                                .zip(rhs.as_array().iter())
                                .any(|(x,y)| *x == <$scalar>::MIN && *y == -1 as _) {
                            panic!("attempt to calculate the remainder with overflow");
                        }
                        // SAFETY: the guards above have already rejected zero divisors
                        // and `MIN % -1` in every lane.
                        unsafe { intrinsics::simd_rem(self, rhs) }
                    }
                }
            }

            // shifts panic on overflow
            impl_ref_ops! {
                impl<const LANES: usize> core::ops::Shl<Self> for Simd<$scalar, LANES>
                where
                    LaneCount<LANES>: SupportedLaneCount,
                {
                    type Output = Self;

                    #[inline]
                    fn shl(self, rhs: Self) -> Self::Output {
                        // TODO there is probably a better way of doing this
                        if rhs.as_array()
                            .iter()
                            .copied()
                            .any(invalid_shift_rhs)
                        {
                            panic!("attempt to shift left with overflow");
                        }
                        // SAFETY: every lane of `rhs` was accepted by `invalid_shift_rhs`,
                        // so no shift amount is negative or >= the lane's bit width.
                        unsafe { intrinsics::simd_shl(self, rhs) }
                    }
                }
            }

            impl_ref_ops! {
                impl<const LANES: usize> core::ops::Shr<Self> for Simd<$scalar, LANES>
                where
                    LaneCount<LANES>: SupportedLaneCount,
                {
                    type Output = Self;

                    #[inline]
                    fn shr(self, rhs: Self) -> Self::Output {
                        // TODO there is probably a better way of doing this
                        if rhs.as_array()
                            .iter()
                            .copied()
                            .any(invalid_shift_rhs)
                        {
                            // Worded to mirror the `Shl` message above.
                            panic!("attempt to shift right with overflow");
                        }
                        // SAFETY: every lane of `rhs` was accepted by `invalid_shift_rhs`,
                        // so no shift amount is negative or >= the lane's bit width.
                        unsafe { intrinsics::simd_shr(self, rhs) }
                    }
                }
            }
        )*
    };
}
253
/// Implements signed integer operators for the provided types.
///
/// This forwards the whole type list to `impl_unsigned_int_ops!`; no additional impls
/// are generated here.
macro_rules! impl_signed_int_ops {
    { $($scalar:ty),* } => {
        impl_unsigned_int_ops! { $($scalar),* }
    };
}
260
261 impl_unsigned_int_ops! { u8, u16, u32, u64, usize }
262 impl_signed_int_ops! { i8, i16, i32, i64, isize }
263 impl_float_ops! { f32, f64 }