]> git.lizzy.rs Git - rust.git/commitdiff
implement SIMD float rounding functions
authorRalf Jung <post@ralfj.de>
Wed, 16 Mar 2022 22:51:34 +0000 (18:51 -0400)
committerRalf Jung <post@ralfj.de>
Wed, 16 Mar 2022 22:53:36 +0000 (18:53 -0400)
src/shims/intrinsics.rs
tests/run-pass/portable-simd.rs

index 24aeb448912b4665e6057a6c226dbb4735a25b4a..9b6df483b92a8c7518b8937492cf129a341950f4 100644 (file)
@@ -325,20 +325,37 @@ fn call_intrinsic(
             // SIMD operations
             #[rustfmt::skip]
             | "simd_neg"
-            | "simd_fabs" => {
+            | "simd_fabs"
+            | "simd_ceil"
+            | "simd_floor"
+            | "simd_round"
+            | "simd_trunc" => {
                 let &[ref op] = check_arg_count(args)?;
                 let (op, op_len) = this.operand_to_simd(op)?;
                 let (dest, dest_len) = this.place_to_simd(dest)?;
 
                 assert_eq!(dest_len, op_len);
 
+                #[derive(Copy, Clone)]
+                enum HostFloatOp {
+                    Ceil,
+                    Floor,
+                    Round,
+                    Trunc,
+                }
+                #[derive(Copy, Clone)]
                 enum Op {
                     MirOp(mir::UnOp),
                     Abs,
+                    HostOp(HostFloatOp),
                 }
                 let which = match intrinsic_name {
                     "simd_neg" => Op::MirOp(mir::UnOp::Neg),
                     "simd_fabs" => Op::Abs,
+                    "simd_ceil" => Op::HostOp(HostFloatOp::Ceil),
+                    "simd_floor" => Op::HostOp(HostFloatOp::Floor),
+                    "simd_round" => Op::HostOp(HostFloatOp::Round),
+                    "simd_trunc" => Op::HostOp(HostFloatOp::Trunc),
                     _ => unreachable!(),
                 };
 
@@ -350,7 +367,7 @@ enum Op {
                         Op::Abs => {
                             // Works for f32 and f64.
                             let ty::Float(float_ty) = op.layout.ty.kind() else {
-                                bug!("simd_fabs operand is not a float")
+                                bug!("{} operand is not a float", intrinsic_name)
                             };
                             let op = op.to_scalar()?;
                             match float_ty {
@@ -358,6 +375,35 @@ enum Op {
                                 FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
                             }
                         }
+                        Op::HostOp(host_op) => {
+                            let ty::Float(float_ty) = op.layout.ty.kind() else {
+                                bug!("{} operand is not a float", intrinsic_name)
+                            };
+                            // FIXME using host floats
+                            match float_ty {
+                                FloatTy::F32 => {
+                                    let f = f32::from_bits(op.to_scalar()?.to_u32()?);
+                                    let res = match host_op {
+                                        HostFloatOp::Ceil => f.ceil(),
+                                        HostFloatOp::Floor => f.floor(),
+                                        HostFloatOp::Round => f.round(),
+                                        HostFloatOp::Trunc => f.trunc(),
+                                    };
+                                    Scalar::from_u32(res.to_bits())
+                                }
+                                FloatTy::F64 => {
+                                    let f = f64::from_bits(op.to_scalar()?.to_u64()?);
+                                    let res = match host_op {
+                                        HostFloatOp::Ceil => f.ceil(),
+                                        HostFloatOp::Floor => f.floor(),
+                                        HostFloatOp::Round => f.round(),
+                                        HostFloatOp::Trunc => f.trunc(),
+                                    };
+                                    Scalar::from_u64(res.to_bits())
+                                }
+                            }
+
+                        }
                     };
                     this.write_scalar(val, &dest.into())?;
                 }
index 28b9a1b03d94c5af61799f4bd42b55fdb1b707d1..a15a0a3b1e003d0a9cb85ac09b9dda60f0b1ecd1 100644 (file)
@@ -106,19 +106,39 @@ fn simd_ops_i32() {
     assert_eq!(a.min(b * i32x4::splat(4)), i32x4::from_array([4, 8, 10, -16]));
 
     assert_eq!(
-        i8x4::from_array([i8::MAX, -23, 23, i8::MIN]).saturating_add(i8x4::from_array([1, i8::MIN, i8::MAX, 28])),
+        i8x4::from_array([i8::MAX, -23, 23, i8::MIN]).saturating_add(i8x4::from_array([
+            1,
+            i8::MIN,
+            i8::MAX,
+            28
+        ])),
         i8x4::from_array([i8::MAX, i8::MIN, i8::MAX, -100])
     );
     assert_eq!(
-        i8x4::from_array([i8::MAX, -28, 27, 42]).saturating_sub(i8x4::from_array([1, i8::MAX, i8::MAX, -80])),
+        i8x4::from_array([i8::MAX, -28, 27, 42]).saturating_sub(i8x4::from_array([
+            1,
+            i8::MAX,
+            i8::MAX,
+            -80
+        ])),
         i8x4::from_array([126, i8::MIN, -100, 122])
     );
     assert_eq!(
-        u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_add(u8x4::from_array([1, 1, u8::MAX, 200])),
+        u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_add(u8x4::from_array([
+            1,
+            1,
+            u8::MAX,
+            200
+        ])),
         u8x4::from_array([u8::MAX, 1, u8::MAX, 242])
     );
     assert_eq!(
-        u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_sub(u8x4::from_array([1, 1, u8::MAX, 200])),
+        u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_sub(u8x4::from_array([
+            1,
+            1,
+            u8::MAX,
+            200
+        ])),
         u8x4::from_array([254, 0, 0, 0])
     );
 
@@ -259,6 +279,42 @@ fn simd_gather_scatter() {
     assert_eq!(vec, vec![124, 11, 12, 82, 14, 15, 16, 17, 18]);
 }
 
+fn simd_round() {
+    assert_eq!(
+        f32x4::from_array([0.9, 1.001, 2.0, -4.5]).ceil(),
+        f32x4::from_array([1.0, 2.0, 2.0, -4.0])
+    );
+    assert_eq!(
+        f32x4::from_array([0.9, 1.001, 2.0, -4.5]).floor(),
+        f32x4::from_array([0.0, 1.0, 2.0, -5.0])
+    );
+    assert_eq!(
+        f32x4::from_array([0.9, 1.001, 2.0, -4.5]).round(),
+        f32x4::from_array([1.0, 1.0, 2.0, -5.0])
+    );
+    assert_eq!(
+        f32x4::from_array([0.9, 1.001, 2.0, -4.5]).trunc(),
+        f32x4::from_array([0.0, 1.0, 2.0, -4.0])
+    );
+
+    assert_eq!(
+        f64x4::from_array([0.9, 1.001, 2.0, -4.5]).ceil(),
+        f64x4::from_array([1.0, 2.0, 2.0, -4.0])
+    );
+    assert_eq!(
+        f64x4::from_array([0.9, 1.001, 2.0, -4.5]).floor(),
+        f64x4::from_array([0.0, 1.0, 2.0, -5.0])
+    );
+    assert_eq!(
+        f64x4::from_array([0.9, 1.001, 2.0, -4.5]).round(),
+        f64x4::from_array([1.0, 1.0, 2.0, -5.0])
+    );
+    assert_eq!(
+        f64x4::from_array([0.9, 1.001, 2.0, -4.5]).trunc(),
+        f64x4::from_array([0.0, 1.0, 2.0, -4.0])
+    );
+}
+
 fn simd_intrinsics() {
     extern "platform-intrinsic" {
         fn simd_eq<T, U>(x: T, y: T) -> U;
@@ -299,5 +355,6 @@ fn main() {
     simd_cast();
     simd_swizzle();
     simd_gather_scatter();
+    simd_round();
     simd_intrinsics();
 }