]> git.lizzy.rs Git - rust.git/blob - src/librustc_mir/transform/uninhabited_enum_branching.rs
pin docs: add some forward references
[rust.git] / src / librustc_mir / transform / uninhabited_enum_branching.rs
1 //! A pass that eliminates branches on uninhabited enum variants.
2
3 use crate::transform::{MirPass, MirSource};
4 use rustc_middle::mir::{
5     BasicBlock, BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, TerminatorKind,
6 };
7 use rustc_middle::ty::layout::TyAndLayout;
8 use rustc_middle::ty::{Ty, TyCtxt};
9 use rustc_target::abi::{Abi, Variants};
10
/// MIR pass that removes `SwitchInt` arms branching on the discriminant of an enum variant
/// whose layout is uninhabited. The rewriting itself is done in the `MirPass` impl below.
pub struct UninhabitedEnumBranching;
12
13 fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> {
14     if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator {
15         p.as_local()
16     } else {
17         None
18     }
19 }
20
21 /// If the basic block terminates by switching on a discriminant, this returns the `Ty` the
22 /// discriminant is read from. Otherwise, returns None.
23 fn get_switched_on_type<'tcx>(
24     block_data: &BasicBlockData<'tcx>,
25     body: &Body<'tcx>,
26 ) -> Option<Ty<'tcx>> {
27     let terminator = block_data.terminator();
28
29     // Only bother checking blocks which terminate by switching on a local.
30     if let Some(local) = get_discriminant_local(&terminator.kind) {
31         let stmt_before_term = (!block_data.statements.is_empty())
32             .then(|| &block_data.statements[block_data.statements.len() - 1].kind);
33
34         if let Some(StatementKind::Assign(box (l, Rvalue::Discriminant(place)))) = stmt_before_term
35         {
36             if l.as_local() == Some(local) {
37                 if let Some(r_local) = place.as_local() {
38                     let ty = body.local_decls[r_local].ty;
39
40                     if ty.is_enum() {
41                         return Some(ty);
42                     }
43                 }
44             }
45         }
46     }
47
48     None
49 }
50
51 fn variant_discriminants<'tcx>(
52     layout: &TyAndLayout<'tcx>,
53     ty: Ty<'tcx>,
54     tcx: TyCtxt<'tcx>,
55 ) -> Vec<u128> {
56     match &layout.variants {
57         Variants::Single { index } => vec![index.as_u32() as u128],
58         Variants::Multiple { variants, .. } => variants
59             .iter_enumerated()
60             .filter_map(|(idx, layout)| {
61                 (layout.abi != Abi::Uninhabited)
62                     .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val)
63             })
64             .collect(),
65     }
66 }
67
68 impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
69     fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
70         if source.promoted.is_some() {
71             return;
72         }
73
74         trace!("UninhabitedEnumBranching starting for {:?}", source);
75
76         let basic_block_count = body.basic_blocks().len();
77
78         for bb in 0..basic_block_count {
79             let bb = BasicBlock::from_usize(bb);
80             trace!("processing block {:?}", bb);
81
82             let discriminant_ty =
83                 if let Some(ty) = get_switched_on_type(&body.basic_blocks()[bb], body) {
84                     ty
85                 } else {
86                     continue;
87                 };
88
89             let layout = tcx.layout_of(tcx.param_env(source.def_id()).and(discriminant_ty));
90
91             let allowed_variants = if let Ok(layout) = layout {
92                 variant_discriminants(&layout, discriminant_ty, tcx)
93             } else {
94                 continue;
95             };
96
97             trace!("allowed_variants = {:?}", allowed_variants);
98
99             if let TerminatorKind::SwitchInt { values, targets, .. } =
100                 &mut body.basic_blocks_mut()[bb].terminator_mut().kind
101             {
102                 let vals = &*values;
103                 let zipped = vals.iter().zip(targets.iter());
104
105                 let mut matched_values = Vec::with_capacity(allowed_variants.len());
106                 let mut matched_targets = Vec::with_capacity(allowed_variants.len() + 1);
107
108                 for (val, target) in zipped {
109                     if allowed_variants.contains(val) {
110                         matched_values.push(*val);
111                         matched_targets.push(*target);
112                     } else {
113                         trace!("eliminating {:?} -> {:?}", val, target);
114                     }
115                 }
116
117                 // handle the "otherwise" branch
118                 matched_targets.push(targets.pop().unwrap());
119
120                 *values = matched_values.into();
121                 *targets = matched_targets;
122             } else {
123                 unreachable!()
124             }
125         }
126     }
127 }