]> git.lizzy.rs Git - rust.git/commitdiff
Auto merge of #75382 - JulianKnodt:match_branches, r=oli-obk
authorbors <bors@rust-lang.org>
Thu, 13 Aug 2020 19:26:35 +0000 (19:26 +0000)
committerbors <bors@rust-lang.org>
Thu, 13 Aug 2020 19:26:35 +0000 (19:26 +0000)
First iteration of simplify match branches

This is a simple MIR pass that attempts to convert
```
   bb0: {
        StorageLive(_2);
        _3 = discriminant(_1);
        switchInt(move _3) -> [0isize: bb2, otherwise: bb1];
    }

    bb1: {
        _2 = const false;
        goto -> bb3;
    }

    bb2: {
        _2 = const true;
        goto -> bb3;
    }
```
into
```
    bb0: {
        StorageLive(_2);
        _3 = discriminant(_1);
        _2 = _3 == 0;
        goto -> bb3;
    }
```
There are still missing components(like checking if the assignments are bools).
Was hoping that this could get some review though.

Handles #75141

r? @oli-obk

src/librustc_mir/transform/match_branches.rs [new file with mode: 0644]
src/librustc_mir/transform/mod.rs
src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.32bit [new file with mode: 0644]
src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.64bit [new file with mode: 0644]
src/test/mir-opt/matches_reduce_branches.rs [new file with mode: 0644]

diff --git a/src/librustc_mir/transform/match_branches.rs b/src/librustc_mir/transform/match_branches.rs
new file mode 100644 (file)
index 0000000..74da6d5
--- /dev/null
@@ -0,0 +1,93 @@
+use crate::transform::{MirPass, MirSource};
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+
+pub struct MatchBranchSimplification;
+
+// What's the intent of this pass?
+// If one block is found that switches between blocks which both go to the same place
+// AND both of these blocks set a similar const in their ->
+// condense into 1 block based on discriminant AND goto the destination afterwards
+
+impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) {
+        let param_env = tcx.param_env(src.def_id());
+        let bbs = body.basic_blocks_mut();
+        'outer: for bb_idx in bbs.indices() {
+            let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind {
+                TerminatorKind::SwitchInt {
+                    discr: Operand::Move(ref place),
+                    switch_ty,
+                    ref targets,
+                    ref values,
+                    ..
+                } if targets.len() == 2 && values.len() == 1 => {
+                    (place, values[0], switch_ty, targets[0], targets[1])
+                }
+                // Only optimize switch int statements
+                _ => continue,
+            };
+
+            // Check that destinations are identical, and if not, then don't optimize this block
+            if &bbs[first].terminator().kind != &bbs[second].terminator().kind {
+                continue;
+            }
+
+            // Check that blocks are assignments of consts to the same place or same statement,
+            // and match up 1-1, if not don't optimize this block.
+            let first_stmts = &bbs[first].statements;
+            let scnd_stmts = &bbs[second].statements;
+            if first_stmts.len() != scnd_stmts.len() {
+                continue;
+            }
+            for (f, s) in first_stmts.iter().zip(scnd_stmts.iter()) {
+                match (&f.kind, &s.kind) {
+                    // If two statements are exactly the same just ignore them.
+                    (f_s, s_s) if f_s == s_s => (),
+
+                    (
+                        StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
+                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
+                    ) if lhs_f == lhs_s => {
+                        if let Some(f_c) = f_c.literal.try_eval_bool(tcx, param_env) {
+                            // This should also be a bool because it's writing to the same place
+                            let s_c = s_c.literal.try_eval_bool(tcx, param_env).unwrap();
+                            if f_c != s_c {
+                                // have to check this here because f_c & s_c might have
+                                // different spans.
+                                continue;
+                            }
+                        }
+                        continue 'outer;
+                    }
+                    // If there are not exclusively assignments, then ignore this
+                    _ => continue 'outer,
+                }
+            }
+            // Take owenership of items now that we know we can optimize.
+            let discr = discr.clone();
+            let (from, first) = bbs.pick2_mut(bb_idx, first);
+
+            let new_stmts = first.statements.iter().cloned().map(|mut s| {
+                if let StatementKind::Assign(box (_, ref mut rhs)) = s.kind {
+                    if let Rvalue::Use(Operand::Constant(c)) = rhs {
+                        let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size;
+                        let const_cmp = Operand::const_from_scalar(
+                            tcx,
+                            switch_ty,
+                            crate::interpret::Scalar::from_uint(val, size),
+                            rustc_span::DUMMY_SP,
+                        );
+                        if let Some(c) = c.literal.try_eval_bool(tcx, param_env) {
+                            let op = if c { BinOp::Eq } else { BinOp::Ne };
+                            *rhs = Rvalue::BinaryOp(op, Operand::Move(discr), const_cmp);
+                        }
+                    }
+                }
+                s
+            });
+            from.statements.extend(new_stmts);
+            from.terminator_mut().kind = first.terminator().kind.clone();
+        }
+    }
+}
index 3803ee78fd4d9e5543550ab9bea710c2fb8c6c59..4f26f3bb45973b178c36fc8821c9bc24025221b8 100644 (file)
@@ -29,6 +29,7 @@
 pub mod inline;
 pub mod instcombine;
 pub mod instrument_coverage;
+pub mod match_branches;
 pub mod no_landing_pads;
 pub mod nrvo;
 pub mod promote_consts;
@@ -440,6 +441,7 @@ fn run_optimization_passes<'tcx>(
         // with async primitives.
         &generator::StateTransform,
         &instcombine::InstCombine,
+        &match_branches::MatchBranchSimplification,
         &const_prop::ConstProp,
         &simplify_branches::SimplifyBranches::new("after-const-prop"),
         &simplify_try::SimplifyArmIdentity,
diff --git a/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.32bit b/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.32bit
new file mode 100644 (file)
index 0000000..df94c89
--- /dev/null
@@ -0,0 +1,66 @@
+- // MIR for `foo` before MatchBranchSimplification
++ // MIR for `foo` after MatchBranchSimplification
+  
+  fn foo(_1: std::option::Option<()>) -> () {
+      debug bar => _1;                     // in scope 0 at $DIR/matches_reduce_branches.rs:4:8: 4:11
+      let mut _0: ();                      // return place in scope 0 at $DIR/matches_reduce_branches.rs:4:25: 4:25
+      let mut _2: bool;                    // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+      let mut _3: isize;                   // in scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
+  
+      bb0: {
+          StorageLive(_2);                 // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+          _3 = discriminant(_1);           // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
+-         switchInt(move _3) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
++         _2 = Eq(move _3, const 0_isize); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
++                                          // ty::Const
++                                          // + ty: isize
++                                          // + val: Value(Scalar(0x00000000))
++                                          // mir::Constant
++                                          // + span: $DIR/matches_reduce_branches.rs:1:1: 1:1
++                                          // + literal: Const { ty: isize, val: Value(Scalar(0x00000000)) }
++         goto -> bb3;                     // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
+      }
+  
+      bb1: {
+          _2 = const false;                // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+                                           // ty::Const
+                                           // + ty: bool
+                                           // + val: Value(Scalar(0x00))
+                                           // mir::Constant
+                                           // + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
+                                           // + literal: Const { ty: bool, val: Value(Scalar(0x00)) }
+          goto -> bb3;                     // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+      }
+  
+      bb2: {
+          _2 = const true;                 // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+                                           // ty::Const
+                                           // + ty: bool
+                                           // + val: Value(Scalar(0x01))
+                                           // mir::Constant
+                                           // + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
+                                           // + literal: Const { ty: bool, val: Value(Scalar(0x01)) }
+          goto -> bb3;                     // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+      }
+  
+      bb3: {
+          switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
+      }
+  
+      bb4: {
+          _0 = const ();                   // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
+                                           // ty::Const
+                                           // + ty: ()
+                                           // + val: Value(Scalar(<ZST>))
+                                           // mir::Constant
+                                           // + span: $DIR/matches_reduce_branches.rs:5:5: 7:6
+                                           // + literal: Const { ty: (), val: Value(Scalar(<ZST>)) }
+          goto -> bb5;                     // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
+      }
+  
+      bb5: {
+          StorageDead(_2);                 // scope 0 at $DIR/matches_reduce_branches.rs:8:1: 8:2
+          return;                          // scope 0 at $DIR/matches_reduce_branches.rs:8:2: 8:2
+      }
+  }
+  
diff --git a/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.64bit b/src/test/mir-opt/matches_reduce_branches.foo.MatchBranchSimplification.diff.64bit
new file mode 100644 (file)
index 0000000..06849b4
--- /dev/null
@@ -0,0 +1,66 @@
+- // MIR for `foo` before MatchBranchSimplification
++ // MIR for `foo` after MatchBranchSimplification
+  
+  fn foo(_1: std::option::Option<()>) -> () {
+      debug bar => _1;                     // in scope 0 at $DIR/matches_reduce_branches.rs:4:8: 4:11
+      let mut _0: ();                      // return place in scope 0 at $DIR/matches_reduce_branches.rs:4:25: 4:25
+      let mut _2: bool;                    // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+      let mut _3: isize;                   // in scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
+  
+      bb0: {
+          StorageLive(_2);                 // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+          _3 = discriminant(_1);           // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
+-         switchInt(move _3) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
++         _2 = Eq(move _3, const 0_isize); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
++                                          // ty::Const
++                                          // + ty: isize
++                                          // + val: Value(Scalar(0x0000000000000000))
++                                          // mir::Constant
++                                          // + span: $DIR/matches_reduce_branches.rs:1:1: 1:1
++                                          // + literal: Const { ty: isize, val: Value(Scalar(0x0000000000000000)) }
++         goto -> bb3;                     // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
+      }
+  
+      bb1: {
+          _2 = const false;                // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+                                           // ty::Const
+                                           // + ty: bool
+                                           // + val: Value(Scalar(0x00))
+                                           // mir::Constant
+                                           // + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
+                                           // + literal: Const { ty: bool, val: Value(Scalar(0x00)) }
+          goto -> bb3;                     // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+      }
+  
+      bb2: {
+          _2 = const true;                 // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+                                           // ty::Const
+                                           // + ty: bool
+                                           // + val: Value(Scalar(0x01))
+                                           // mir::Constant
+                                           // + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
+                                           // + literal: Const { ty: bool, val: Value(Scalar(0x01)) }
+          goto -> bb3;                     // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+      }
+  
+      bb3: {
+          switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
+      }
+  
+      bb4: {
+          _0 = const ();                   // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
+                                           // ty::Const
+                                           // + ty: ()
+                                           // + val: Value(Scalar(<ZST>))
+                                           // mir::Constant
+                                           // + span: $DIR/matches_reduce_branches.rs:5:5: 7:6
+                                           // + literal: Const { ty: (), val: Value(Scalar(<ZST>)) }
+          goto -> bb5;                     // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
+      }
+  
+      bb5: {
+          StorageDead(_2);                 // scope 0 at $DIR/matches_reduce_branches.rs:8:1: 8:2
+          return;                          // scope 0 at $DIR/matches_reduce_branches.rs:8:2: 8:2
+      }
+  }
+  
diff --git a/src/test/mir-opt/matches_reduce_branches.rs b/src/test/mir-opt/matches_reduce_branches.rs
new file mode 100644 (file)
index 0000000..91b6bfc
--- /dev/null
@@ -0,0 +1,13 @@
+// EMIT_MIR_FOR_EACH_BIT_WIDTH
+// EMIT_MIR matches_reduce_branches.foo.MatchBranchSimplification.diff
+
+fn foo(bar: Option<()>) {
+    if matches!(bar, None) {
+      ()
+    }
+}
+
+fn main() {
+  let _ = foo(None);
+  let _ = foo(Some(()));
+}