]> git.lizzy.rs Git - rust.git/blobdiff - compiler/rustc_mir/src/transform/simplify_try.rs
Refactor how SwitchInt stores jump targets
[rust.git] / compiler / rustc_mir / src / transform / simplify_try.rs
index b2c889bfd05e8af1883c0c010005b3dedec27f5f..27bb1def726e1a51ed5cf6f68e3ea848c29f2e84 100644 (file)
@@ -9,7 +9,7 @@
 //!
 //! into just `x`.
 
-use crate::transform::{simplify, MirPass, MirSource};
+use crate::transform::{simplify, MirPass};
 use itertools::Itertools as _;
 use rustc_index::{bit_set::BitSet, vec::IndexVec};
 use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
@@ -367,13 +367,15 @@ fn optimization_applies<'tcx>(
 }
 
 impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
-    fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         // FIXME(77359): This optimization can result in unsoundness.
         if !tcx.sess.opts.debugging_opts.unsound_mir_opts {
             return;
         }
 
+        let source = body.source;
         trace!("running SimplifyArmIdentity on {:?}", source);
+
         let local_uses = LocalUseCounter::get_local_uses(body);
         let (basic_blocks, local_decls, debug_info) =
             body.basic_blocks_local_decls_mut_and_var_debug_info();
@@ -528,8 +530,8 @@ fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarFiel
 pub struct SimplifyBranchSame;
 
 impl<'tcx> MirPass<'tcx> for SimplifyBranchSame {
-    fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
-        trace!("Running SimplifyBranchSame on {:?}", source);
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        trace!("Running SimplifyBranchSame on {:?}", body.source);
         let finder = SimplifyBranchSameOptimizationFinder { body, tcx };
         let opts = finder.find();
 
@@ -574,15 +576,13 @@ fn find(&self) -> Vec<SimplifyBranchSameOptimization> {
             .iter_enumerated()
             .filter_map(|(bb_idx, bb)| {
                 let (discr_switched_on, targets_and_values) = match &bb.terminator().kind {
-                    TerminatorKind::SwitchInt { targets, discr, values, .. } => {
-                        // if values.len() == targets.len() - 1, we need to include None where no value is present
-                        // such that the zip does not throw away targets. If no `otherwise` case is in targets, the zip will simply throw away the added None
-                        let values_extended = values.iter().map(|x|Some(*x)).chain(once(None));
-                        let targets_and_values:Vec<_> = targets.iter().zip(values_extended)
-                            .map(|(target, value)| SwitchTargetAndValue{target:*target, value})
+                    TerminatorKind::SwitchInt { targets, discr, .. } => {
+                        let targets_and_values: Vec<_> = targets.iter()
+                            .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) })
+                            .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None }))
                             .collect();
-                        assert_eq!(targets.len(), targets_and_values.len());
-                        (discr, targets_and_values)},
+                        (discr, targets_and_values)
+                    },
                     _ => return None,
                 };
 
@@ -628,7 +628,8 @@ fn find(&self) -> Vec<SimplifyBranchSameOptimization> {
                 // All successor basic blocks must be equal or contain statements that are pairwise considered equal.
                 for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() {
                     let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup
-                    && bb_l.terminator().kind == bb_r.terminator().kind;
+                                            && bb_l.terminator().kind == bb_r.terminator().kind
+                                            && bb_l.statements.len() == bb_r.statements.len();
                     let statement_check = || {
                         bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| {
                             let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r);