1 use rustc_middle::mir::patch::MirPatch;
2 use rustc_middle::mir::*;
3 use rustc_middle::ty::{self, Ty, TyCtxt};
6 use super::simplify::simplify_cfg;
8 /// This pass optimizes something like
9 /// ```ignore (syntax-highlighting-only)
10 /// let x: Option<()>;
11 /// let y: Option<()>;
13 /// (Some(_), Some(_)) => {0},
17 /// into something like
18 /// ```ignore (syntax-highlighting-only)
19 /// let x: Option<()>;
20 /// let y: Option<()>;
21 /// let discriminant_x = std::mem::discriminant(x);
22 /// let discriminant_y = std::mem::discriminant(y);
23 /// if discriminant_x == discriminant_y {
27 /// } // | Actually the same bb
29 /// 1 // <--------------
33 /// Specifically, it looks for instances of control flow like this:
38 /// |---------------| ============================
39 /// | ... | /------> | BBC |
40 /// |---------------| | |--------------------------|
41 /// | switchInt(Q) | | | _cl = discriminant(P) |
42 /// | c | --------/ |--------------------------|
43 /// | d | -------\ | switchInt(_cl) |
44 /// | ... | | | c | ---> BBC.2
45 /// | otherwise | --\ | /--- | otherwise |
46 /// ================= | | | ============================
48 /// ================= | | |
49 /// | BBU | <-| | | ============================
50 /// |---------------| | \-------> | BBD |
51 /// |---------------| | | |--------------------------|
52 /// | unreachable | | | | _dl = discriminant(P) |
53 /// ================= | | |--------------------------|
54 /// | | | switchInt(_dl) |
55 /// ================= | | | d | ---> BBD.2
56 /// | BB9 | <--------------- | otherwise |
57 /// |---------------| ============================
61 /// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
63 /// - `BB1` is `parent` and `BBC, BBD` are children
64 /// - `P` is `child_place`
65 /// - `child_ty` is the type of `_cl`.
66 /// - `Q` is `parent_op`.
67 /// - `parent_ty` is the type of `Q`.
68 /// - `BB9` is `destination`
69 /// All this is then transformed into:
72 /// =======================
74 /// |---------------------| ============================
75 /// | ... | /------> | BBEq |
76 /// | _s = discriminant(P)| | |--------------------------|
77 /// | _t = Ne(Q, _s) | | |--------------------------|
78 /// |---------------------| | | switchInt(Q) |
79 /// | switchInt(_t) | | | c | ---> BBC.2
80 /// | false | --------/ | d | ---> BBD.2
81 /// | otherwise | ---------------- | otherwise |
82 /// ======================= | ============================
84 /// ================= |
85 /// | BB9 | <-----------/
91 /// This is only correct for some `P`, since `discriminant(P)` is now evaluated in the parent block, before the original `switchInt(Q)` branch is taken.
92 /// The filter on which `P` are allowed (together with discussion of its correctness) is found in
94 pub struct EarlyOtherwiseBranch; // Unit struct: the pass holds no state of its own.
96 impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
97 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
98 sess.mir_opt_level() >= 2
101 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
102 trace!("running EarlyOtherwiseBranch on {:?}", body.source);
104 let mut should_cleanup = false;
106 // Note: the loop bound is read once at entry, so blocks created while iterating (e.g. `eq_bb` below) are not revisited in this pass
107 for i in 0..body.basic_blocks().len() {
108 let bbs = body.basic_blocks();
109 let parent = BasicBlock::from_usize(i);
110 let Some(opt_data) = evaluate_candidate(tcx, body, parent) else {
114 if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_data)) {
118 trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data);
120 should_cleanup = true;
122 let TerminatorKind::SwitchInt {
124 switch_ty: parent_ty,
125 targets: parent_targets
126 } = &bbs[parent].terminator().kind else {
129 // Always correct since we can only switch on `Copy` types
130 let parent_op = match parent_op {
131 Operand::Move(x) => Operand::Copy(*x),
132 Operand::Copy(x) => Operand::Copy(*x),
133 Operand::Constant(x) => Operand::Constant(x.clone()),
135 let statements_before = bbs[parent].statements.len();
136 let parent_end = Location { block: parent, statement_index: statements_before };
138 let mut patch = MirPatch::new(body);
140 // create temp to store second discriminant in, `_s` in example above
141 let second_discriminant_temp =
142 patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
144 patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
146 // create assignment of discriminant
149 Place::from(second_discriminant_temp),
150 Rvalue::Discriminant(opt_data.child_place),
153 // create temp to store inequality comparison between the two discriminants, `_t` in
155 let nequal = BinOp::Ne;
156 let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
157 let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
158 patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
160 // create inequality comparison between the two discriminants
161 let comp_rvalue = Rvalue::BinaryOp(
163 Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
167 StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
170 let eq_new_targets = parent_targets.iter().map(|(value, child)| {
171 let TerminatorKind::SwitchInt{ targets, .. } = &bbs[child].terminator().kind else {
174 (value, targets.target_for_value(value))
176 let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
178 // Create `bbEq` in example above
179 let eq_switch = BasicBlockData::new(Some(Terminator {
180 source_info: bbs[parent].terminator().source_info,
181 kind: TerminatorKind::SwitchInt {
182 // switch on the first discriminant, so we can mark the second one as dead
184 switch_ty: opt_data.child_ty,
189 let eq_bb = patch.new_block(eq_switch);
191 // Jump to it on the basis of the inequality comparison
192 let true_case = opt_data.destination;
193 let false_case = eq_bb;
194 patch.patch_terminator(
198 Operand::Move(Place::from(comp_temp)),
204 // generate StorageDead for the second_discriminant_temp not in use anymore
205 patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
207 // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
209 for bb in [false_case, true_case].iter() {
211 Location { block: *bb, statement_index: 0 },
212 StatementKind::StorageDead(comp_temp),
219 // Since this optimization adds new basic blocks and invalidates others,
220 // clean up the cfg to make it nicer for other passes
222 simplify_cfg(tcx, body);
227 /// Returns true if computing the discriminant of `place` may be hoisted out of the branch
228 fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
229 for (place, proj) in place.iter_projections() {
231 // Dereferencing in the computation of `place` might cause issues from one of two
232 // categories. First, the referent might be invalid. We protect against this by
233 // dereferencing references only (not pointers). Second, the use of a reference may
234 // invalidate other references that are used later (for aliasing reasons). Consider
235 // where such an invalidated reference may appear:
236 // - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
237 // cannot contain referenced data.
238 // - In `BBU`: Not possible since that block contains only the `unreachable` terminator
239 // - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
240 // reaching that block in the input to our transformation, and so any data
241 // invalidated by that computation could not have been used there.
242 // - In `BB9`: Not possible since control flow might have reached `BB9` via the
243 // `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
244 // have invalidated the data when computing `discriminant(P)`
245 // So dereferencing here is correct.
246 ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() {
250 // Field projections are always valid
251 ProjectionElem::Field(..) => {}
253 // downcasts either, since the correctness of the downcast may depend on the parent
254 // branch being taken. An easy example of this is
256 // Q = discriminant(_3)
257 // P = (_3 as Variant)
259 // However, checking if the child and parent place are the same and only erroring then
260 // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
261 // be replaced by another optimization pass with any other condition that can be proven
263 ProjectionElem::Downcast(..) => {
266 // We cannot allow indexing since the index may be out of bounds.
276 struct OptimizationData<'tcx> {
277 destination: BasicBlock,
278 child_place: Place<'tcx>,
280 child_source: SourceInfo,
283 fn evaluate_candidate<'tcx>(
287 ) -> Option<OptimizationData<'tcx>> {
288 let bbs = body.basic_blocks();
289 let TerminatorKind::SwitchInt {
291 switch_ty: parent_ty,
293 } = &bbs[parent].terminator().kind else {
297 let poss = targets.otherwise();
298 // If the fallthrough on the parent is trivially unreachable, we can let the
299 // children choose the destination
300 if bbs[poss].statements.len() == 0
301 && bbs[poss].terminator().kind == TerminatorKind::Unreachable
308 let Some((_, child)) = targets.iter().next() else {
311 let child_terminator = &bbs[child].terminator();
312 let TerminatorKind::SwitchInt {
314 targets: child_targets,
316 } = &child_terminator.kind else {
319 if child_ty != parent_ty {
322 let Some(StatementKind::Assign(boxed))
323 = &bbs[child].statements.first().map(|x| &x.kind) else {
326 let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
329 let destination = parent_dest.unwrap_or(child_targets.otherwise());
331 // Verify that the optimization is legal in general
332 // We can hoist evaluating the child discriminant out of the branch
333 if !may_hoist(tcx, body, *child_place) {
337 // Verify that the optimization is legal for each branch
338 for (value, child) in targets.iter() {
339 if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
343 Some(OptimizationData {
345 child_place: *child_place,
347 child_source: child_terminator.source_info,
351 fn verify_candidate_branch<'tcx>(
352 branch: &BasicBlockData<'tcx>,
355 destination: BasicBlock,
357 // In order for the optimization to be correct, the branch must...
358 // ...have exactly one statement
359 if branch.statements.len() != 1 {
362 // ...assign the discriminant of `place` in that statement
363 let StatementKind::Assign(boxed) = &branch.statements[0].kind else {
366 let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else {
369 if *from_place != place {
372 // ...make that assignment to a local
373 if discr_place.projection.len() != 0 {
376 // ...terminate on a `SwitchInt` that invalidates that local
377 let TerminatorKind::SwitchInt{ discr: switch_op, targets, .. } = &branch.terminator().kind else {
380 if *switch_op != Operand::Move(*discr_place) {
383 // ...fall through to `destination` if the switch misses
384 if destination != targets.otherwise() {
387 // ...have a branch for value `value`
388 let mut iter = targets.iter();
389 let Some((target_value, _)) = iter.next() else {
392 if target_value != value {
395 // ...and have no more branches
396 if let Some(_) = iter.next() {