]> git.lizzy.rs Git - rust.git/blobdiff - compiler/rustc_mir/src/transform/uninhabited_enum_branching.rs
Refactor how SwitchInt stores jump targets
[rust.git] / compiler / rustc_mir / src / transform / uninhabited_enum_branching.rs
index 4cca4d223c0cb6dd7c2f6cff5e71b762b549012b..465832c89fd00871b7f7117c0d28ee196152efc4 100644 (file)
@@ -1,8 +1,10 @@
 //! A pass that eliminates branches on uninhabited enum variants.
 
-use crate::transform::{MirPass, MirSource};
+use crate::transform::MirPass;
+use rustc_data_structures::stable_set::FxHashSet;
 use rustc_middle::mir::{
-    BasicBlock, BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, TerminatorKind,
+    BasicBlock, BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets,
+    TerminatorKind,
 };
 use rustc_middle::ty::layout::TyAndLayout;
 use rustc_middle::ty::{Ty, TyCtxt};
@@ -52,9 +54,13 @@ fn variant_discriminants<'tcx>(
     layout: &TyAndLayout<'tcx>,
     ty: Ty<'tcx>,
     tcx: TyCtxt<'tcx>,
-) -> Vec<u128> {
+) -> FxHashSet<u128> {
     match &layout.variants {
-        Variants::Single { index } => vec![index.as_u32() as u128],
+        Variants::Single { index } => {
+            let mut res = FxHashSet::default();
+            res.insert(index.as_u32() as u128);
+            res
+        }
         Variants::Multiple { variants, .. } => variants
             .iter_enumerated()
             .filter_map(|(idx, layout)| {
@@ -66,12 +72,12 @@ fn variant_discriminants<'tcx>(
 }
 
 impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
-    fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
-        if source.promoted.is_some() {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        if body.source.promoted.is_some() {
             return;
         }
 
-        trace!("UninhabitedEnumBranching starting for {:?}", source);
+        trace!("UninhabitedEnumBranching starting for {:?}", body.source);
 
         let basic_block_count = body.basic_blocks().len();
 
@@ -86,7 +92,7 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'
                     continue;
                 };
 
-            let layout = tcx.layout_of(tcx.param_env(source.def_id()).and(discriminant_ty));
+            let layout = tcx.layout_of(tcx.param_env(body.source.def_id()).and(discriminant_ty));
 
             let allowed_variants = if let Ok(layout) = layout {
                 variant_discriminants(&layout, discriminant_ty, tcx)
@@ -96,21 +102,15 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'
 
             trace!("allowed_variants = {:?}", allowed_variants);
 
-            if let TerminatorKind::SwitchInt { values, targets, .. } =
+            if let TerminatorKind::SwitchInt { targets, .. } =
                 &mut body.basic_blocks_mut()[bb].terminator_mut().kind
             {
-                // take otherwise out early
-                let otherwise = targets.pop().unwrap();
-                assert_eq!(targets.len(), values.len());
-                let mut i = 0;
-                targets.retain(|_| {
-                    let keep = allowed_variants.contains(&values[i]);
-                    i += 1;
-                    keep
-                });
-                targets.push(otherwise);
-
-                values.to_mut().retain(|var| allowed_variants.contains(var));
+                let new_targets = SwitchTargets::new(
+                    targets.iter().filter(|(val, _)| allowed_variants.contains(val)),
+                    targets.otherwise(),
+                );
+
+                *targets = new_targets;
             } else {
                 unreachable!()
             }