git.lizzy.rs Git - rust.git/commitdiff
implement simd_reduce_min/max
author Ralf Jung <post@ralfj.de>
Sun, 6 Mar 2022 19:31:45 +0000 (14:31 -0500)
committer Ralf Jung <post@ralfj.de>
Mon, 7 Mar 2022 14:40:23 +0000 (09:40 -0500)
src/shims/intrinsics.rs
tests/run-pass/portable-simd.rs

index 2f29ec4553d48cdf8de1b01ecaa0ffadf440707e..6f16853698035c917804f8d4f63381471cd74787 100644 (file)
@@ -433,7 +433,9 @@ enum Op {
             | "simd_reduce_or"
             | "simd_reduce_xor"
             | "simd_reduce_any"
-            | "simd_reduce_all" => {
+            | "simd_reduce_all"
+            | "simd_reduce_max"
+            | "simd_reduce_min" => {
                 use mir::BinOp;
 
                 let &[ref op] = check_arg_count(args)?;
@@ -445,19 +447,27 @@ enum Op {
                 enum Op {
                     MirOp(BinOp),
                     MirOpBool(BinOp),
+                    Max,
+                    Min,
                 }
-                // The initial value is the neutral element.
-                let (which, init) = match intrinsic_name {
-                    "simd_reduce_and" => (Op::MirOp(BinOp::BitAnd), ImmTy::from_int(-1, dest.layout)),
-                    "simd_reduce_or" => (Op::MirOp(BinOp::BitOr), ImmTy::from_int(0, dest.layout)),
-                    "simd_reduce_xor" => (Op::MirOp(BinOp::BitXor), ImmTy::from_int(0, dest.layout)),
-                    "simd_reduce_any" => (Op::MirOpBool(BinOp::BitOr), imm_from_bool(false)),
-                    "simd_reduce_all" => (Op::MirOpBool(BinOp::BitAnd), imm_from_bool(true)),
+                let which = match intrinsic_name {
+                    "simd_reduce_and" => Op::MirOp(BinOp::BitAnd),
+                    "simd_reduce_or" => Op::MirOp(BinOp::BitOr),
+                    "simd_reduce_xor" => Op::MirOp(BinOp::BitXor),
+                    "simd_reduce_any" => Op::MirOpBool(BinOp::BitOr),
+                    "simd_reduce_all" => Op::MirOpBool(BinOp::BitAnd),
+                    "simd_reduce_max" => Op::Max,
+                    "simd_reduce_min" => Op::Min,
                     _ => unreachable!(),
                 };
 
-                let mut res = init;
-                for i in 0..op_len {
+                // Initialize with first lane, then proceed with the rest.
+                let mut res = this.read_immediate(&this.mplace_index(&op, 0)?.into())?;
+                if matches!(which, Op::MirOpBool(_)) {
+                    // Convert to `bool` scalar.
+                    res = imm_from_bool(simd_element_to_bool(res)?);
+                }
+                for i in 1..op_len {
                     let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
                     res = match which {
                         Op::MirOp(mir_op) => {
@@ -467,6 +477,26 @@ enum Op {
                             let op = imm_from_bool(simd_element_to_bool(op)?);
                             this.binary_op(mir_op, &res, &op)?
                         }
+                        Op::Max => {
+                            // if `op > res`...
+                            if this.binary_op(BinOp::Gt, &op, &res)?.to_scalar()?.to_bool()? {
+                                // update accumulator
+                                op
+                            } else {
+                                // no change
+                                res
+                            }
+                        }
+                        Op::Min => {
+                            // if `op < res`...
+                            if this.binary_op(BinOp::Lt, &op, &res)?.to_scalar()?.to_bool()? {
+                                // update accumulator
+                                op
+                            } else {
+                                // no change
+                                res
+                            }
+                        }
                     };
                 }
                 this.write_immediate(*res, dest)?;
index 022e8c91f970dee08f2d824cdbbce49b991eb6a4..ccedf61a38109ddf30b3f70e75cc478aae4224c0 100644 (file)
@@ -24,6 +24,10 @@ fn simd_ops_f32() {
     assert_eq!(b.horizontal_sum(), 2.0);
     assert_eq!(a.horizontal_product(), 100.0 * 100.0);
     assert_eq!(b.horizontal_product(), -24.0);
+    assert_eq!(a.horizontal_max(), 10.0);
+    assert_eq!(b.horizontal_max(), 3.0);
+    assert_eq!(a.horizontal_min(), 10.0);
+    assert_eq!(b.horizontal_min(), -4.0);
 }
 
 fn simd_ops_f64() {
@@ -49,6 +53,10 @@ fn simd_ops_f64() {
     assert_eq!(b.horizontal_sum(), 2.0);
     assert_eq!(a.horizontal_product(), 100.0 * 100.0);
     assert_eq!(b.horizontal_product(), -24.0);
+    assert_eq!(a.horizontal_max(), 10.0);
+    assert_eq!(b.horizontal_max(), 3.0);
+    assert_eq!(a.horizontal_min(), 10.0);
+    assert_eq!(b.horizontal_min(), -4.0);
 }
 
 fn simd_ops_i32() {
@@ -86,6 +94,10 @@ fn simd_ops_i32() {
     assert_eq!(b.horizontal_sum(), 2);
     assert_eq!(a.horizontal_product(), 100 * 100);
     assert_eq!(b.horizontal_product(), -24);
+    assert_eq!(a.horizontal_max(), 10);
+    assert_eq!(b.horizontal_max(), 3);
+    assert_eq!(a.horizontal_min(), 10);
+    assert_eq!(b.horizontal_min(), -4);
 }
 
 fn simd_mask() {