]> git.lizzy.rs Git - rust.git/blobdiff - compiler/rustc_mir/src/transform/simplify_comparison_integral.rs
Refactor how SwitchInt stores jump targets
[rust.git] / compiler / rustc_mir / src / transform / simplify_comparison_integral.rs
index 9f837cf78a60842631faffe69d4c2883e9d02021..50f5fbb3bc22234ff99f998e2005db176940662f 100644 (file)
@@ -1,8 +1,10 @@
+use std::iter;
+
 use super::MirPass;
 use rustc_middle::{
     mir::{
         interpret::Scalar, BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement,
-        StatementKind, TerminatorKind,
+        StatementKind, SwitchTargets, TerminatorKind,
     },
     ty::{Ty, TyCtxt},
 };
@@ -43,19 +45,21 @@ fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
                 Scalar::Ptr(_) => continue,
             };
             const FALSE: u128 = 0;
-            let mut new_targets = opt.targets.clone();
-            let first_is_false_target = opt.values[0] == FALSE;
+
+            let mut new_targets = opt.targets;
+            let first_value = new_targets.iter().next().unwrap().0;
+            let first_is_false_target = first_value == FALSE;
             match opt.op {
                 BinOp::Eq => {
                     // if the assignment was Eq we want the true case to be first
                     if first_is_false_target {
-                        new_targets.swap(0, 1);
+                        new_targets.all_targets_mut().swap(0, 1);
                     }
                 }
                 BinOp::Ne => {
                     // if the assignment was Ne we want the false case to be first
                     if !first_is_false_target {
-                        new_targets.swap(0, 1);
+                        new_targets.all_targets_mut().swap(0, 1);
                     }
                 }
                 _ => unreachable!(),
@@ -96,7 +100,7 @@ fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
                 }
                 storage_deads_to_remove.push((stmt_idx, opt.bb_idx));
                 // if we have StorageDeads to remove then make sure to insert them at the top of each target
-                for bb_idx in new_targets.iter() {
+                for bb_idx in new_targets.all_targets() {
                     storage_deads_to_insert.push((
                         *bb_idx,
                         Statement {
@@ -107,13 +111,18 @@ fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
                 }
             }
 
-            let terminator = bb.terminator_mut();
+            let [bb_cond, bb_otherwise] = match new_targets.all_targets() {
+                [a, b] => [*a, *b],
+                e => bug!("expected 2 switch targets, got: {:?}", e),
+            };
+
+            let targets = SwitchTargets::new(iter::once((new_value, bb_cond)), bb_otherwise);
 
+            let terminator = bb.terminator_mut();
             terminator.kind = TerminatorKind::SwitchInt {
                 discr: Operand::Move(opt.to_switch_on),
                 switch_ty: opt.branch_value_ty,
-                values: vec![new_value].into(),
-                targets: new_targets,
+                targets,
             };
         }
 
@@ -138,15 +147,13 @@ fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> {
             .iter_enumerated()
             .filter_map(|(bb_idx, bb)| {
                 // find switch
-                let (place_switched_on, values, targets, place_switched_on_moved) = match &bb
-                    .terminator()
-                    .kind
-                {
-                    rustc_middle::mir::TerminatorKind::SwitchInt {
-                        discr, values, targets, ..
-                    } => Some((discr.place()?, values, targets, discr.is_move())),
-                    _ => None,
-                }?;
+                let (place_switched_on, targets, place_switched_on_moved) =
+                    match &bb.terminator().kind {
+                        rustc_middle::mir::TerminatorKind::SwitchInt { discr, targets, .. } => {
+                            Some((discr.place()?, targets, discr.is_move()))
+                        }
+                        _ => None,
+                    }?;
 
                 // find the statement that assigns the place being switched on
                 bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| {
@@ -167,7 +174,6 @@ fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> {
                                         branch_value_scalar,
                                         branch_value_ty,
                                         op: *op,
-                                        values: values.clone().into_owned(),
                                         targets: targets.clone(),
                                     })
                                 }
@@ -220,8 +226,6 @@ struct OptimizationInfo<'tcx> {
     branch_value_ty: Ty<'tcx>,
     /// Either Eq or Ne
     op: BinOp,
-    /// Current values used in the switch target. This needs to be replaced with the branch_value
-    values: Vec<u128>,
     /// Current targets used in the switch
-    targets: Vec<BasicBlock>,
+    targets: SwitchTargets<'tcx>,
 }