]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Auto merge of #106910 - aliemjay:alias-ty-in-regionck, r=oli-obk
[rust.git] / compiler / rustc_mir_transform / src / early_otherwise_branch.rs
1 use rustc_middle::mir::patch::MirPatch;
2 use rustc_middle::mir::*;
3 use rustc_middle::ty::{self, Ty, TyCtxt};
4 use std::fmt::Debug;
5
6 use super::simplify::simplify_cfg;
7
8 /// This pass optimizes something like
9 /// ```ignore (syntax-highlighting-only)
10 /// let x: Option<()>;
11 /// let y: Option<()>;
12 /// match (x,y) {
13 ///     (Some(_), Some(_)) => {0},
14 ///     _ => {1}
15 /// }
16 /// ```
17 /// into something like
18 /// ```ignore (syntax-highlighting-only)
19 /// let x: Option<()>;
20 /// let y: Option<()>;
21 /// let discriminant_x = std::mem::discriminant(x);
22 /// let discriminant_y = std::mem::discriminant(y);
23 /// if discriminant_x == discriminant_y {
24 ///     match x {
25 ///         Some(_) => 0,
26 ///         _ => 1, // <----
27 ///     } //               | Actually the same bb
28 /// } else { //            |
29 ///     1 // <--------------
30 /// }
31 /// ```
32 ///
33 /// Specifically, it looks for instances of control flow like this:
34 /// ```text
35 ///
36 ///     =================
37 ///     |      BB1      |
38 ///     |---------------|                  ============================
39 ///     |     ...       |         /------> |            BBC           |
40 ///     |---------------|         |        |--------------------------|
41 ///     |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
42 ///     |       c       | --------/        |--------------------------|
43 ///     |       d       | -------\         |       switchInt(_cl)     |
44 ///     |      ...      |        |         |            c             | ---> BBC.2
45 ///     |    otherwise  | --\    |    /--- |         otherwise        |
46 ///     =================   |    |    |    ============================
47 ///                         |    |    |
48 ///     =================   |    |    |
49 ///     |      BBU      | <-|    |    |    ============================
50 ///     |---------------|   |    \-------> |            BBD           |
51 ///     |---------------|   |         |    |--------------------------|
52 ///     |  unreachable  |   |         |    |   _dl = discriminant(P)  |
53 ///     =================   |         |    |--------------------------|
54 ///                         |         |    |       switchInt(_dl)     |
55 ///     =================   |         |    |            d             | ---> BBD.2
56 ///     |      BB9      | <--------------- |         otherwise        |
57 ///     |---------------|                  ============================
58 ///     |      ...      |
59 ///     =================
60 /// ```
61 /// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
62 /// code:
63 ///  - `BB1` is `parent` and `BBC, BBD` are children
64 ///  - `P` is `child_place`
65 ///  - `child_ty` is the type of `_cl`.
66 ///  - `Q` is `parent_op`.
67 ///  - `parent_ty` is the type of `Q`.
68 ///  - `BB9` is `destination`
69 /// All this is then transformed into:
70 /// ```text
71 ///
72 ///     =======================
73 ///     |          BB1        |
74 ///     |---------------------|                  ============================
75 ///     |          ...        |         /------> |           BBEq           |
76 ///     | _s = discriminant(P)|         |        |--------------------------|
77 ///     | _t = Ne(Q, _s)      |         |        |--------------------------|
78 ///     |---------------------|         |        |       switchInt(Q)       |
79 ///     |     switchInt(_t)   |         |        |            c             | ---> BBC.2
80 ///     |        false        | --------/        |            d             | ---> BBD.2
81 ///     |       otherwise     | ---------------- |         otherwise        |
82 ///     =======================       |          ============================
83 ///                                   |
84 ///     =================             |
85 ///     |      BB9      | <-----------/
86 ///     |---------------|
87 ///     |      ...      |
88 ///     =================
89 /// ```
90 ///
91 /// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
92 /// The filter on which `P` are allowed (together with discussion of its correctness) is found in
93 /// `may_hoist`.
94 pub struct EarlyOtherwiseBranch;
95
96 impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
97     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
98         sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts
99     }
100
101     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
102         trace!("running EarlyOtherwiseBranch on {:?}", body.source);
103
104         let mut should_cleanup = false;
105
106         // Also consider newly generated bbs in the same pass
107         for i in 0..body.basic_blocks.len() {
108             let bbs = &*body.basic_blocks;
109             let parent = BasicBlock::from_usize(i);
110             let Some(opt_data) = evaluate_candidate(tcx, body, parent) else {
111                 continue
112             };
113
114             if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_data)) {
115                 break;
116             }
117
118             trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data);
119
120             should_cleanup = true;
121
122             let TerminatorKind::SwitchInt {
123                 discr: parent_op,
124                 targets: parent_targets
125             } = &bbs[parent].terminator().kind else {
126                 unreachable!()
127             };
128             // Always correct since we can only switch on `Copy` types
129             let parent_op = match parent_op {
130                 Operand::Move(x) => Operand::Copy(*x),
131                 Operand::Copy(x) => Operand::Copy(*x),
132                 Operand::Constant(x) => Operand::Constant(x.clone()),
133             };
134             let parent_ty = parent_op.ty(body.local_decls(), tcx);
135             let statements_before = bbs[parent].statements.len();
136             let parent_end = Location { block: parent, statement_index: statements_before };
137
138             let mut patch = MirPatch::new(body);
139
140             // create temp to store second discriminant in, `_s` in example above
141             let second_discriminant_temp =
142                 patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
143
144             patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
145
146             // create assignment of discriminant
147             patch.add_assign(
148                 parent_end,
149                 Place::from(second_discriminant_temp),
150                 Rvalue::Discriminant(opt_data.child_place),
151             );
152
153             // create temp to store inequality comparison between the two discriminants, `_t` in
154             // example above
155             let nequal = BinOp::Ne;
156             let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
157             let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
158             patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
159
160             // create inequality comparison between the two discriminants
161             let comp_rvalue = Rvalue::BinaryOp(
162                 nequal,
163                 Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
164             );
165             patch.add_statement(
166                 parent_end,
167                 StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
168             );
169
170             let eq_new_targets = parent_targets.iter().map(|(value, child)| {
171                 let TerminatorKind::SwitchInt{ targets, .. } = &bbs[child].terminator().kind else {
172                     unreachable!()
173                 };
174                 (value, targets.target_for_value(value))
175             });
176             let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
177
178             // Create `bbEq` in example above
179             let eq_switch = BasicBlockData::new(Some(Terminator {
180                 source_info: bbs[parent].terminator().source_info,
181                 kind: TerminatorKind::SwitchInt {
182                     // switch on the first discriminant, so we can mark the second one as dead
183                     discr: parent_op,
184                     targets: eq_targets,
185                 },
186             }));
187
188             let eq_bb = patch.new_block(eq_switch);
189
190             // Jump to it on the basis of the inequality comparison
191             let true_case = opt_data.destination;
192             let false_case = eq_bb;
193             patch.patch_terminator(
194                 parent,
195                 TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
196             );
197
198             // generate StorageDead for the second_discriminant_temp not in use anymore
199             patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
200
201             // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
202             // the switch
203             for bb in [false_case, true_case].iter() {
204                 patch.add_statement(
205                     Location { block: *bb, statement_index: 0 },
206                     StatementKind::StorageDead(comp_temp),
207                 );
208             }
209
210             patch.apply(body);
211         }
212
213         // Since this optimization adds new basic blocks and invalidates others,
214         // clean up the cfg to make it nicer for other passes
215         if should_cleanup {
216             simplify_cfg(tcx, body);
217         }
218     }
219 }
220
221 /// Returns true if computing the discriminant of `place` may be hoisted out of the branch
222 fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
223     // FIXME(JakobDegen): This is unsound. Someone could write code like this:
224     // ```rust
225     // let Q = val;
226     // if discriminant(P) == otherwise {
227     //     let ptr = &mut Q as *mut _ as *mut u8;
228     //     unsafe { *ptr = 10; } // Any invalid value for the type
229     // }
230     //
231     // match P {
232     //    A => match Q {
233     //        A => {
234     //            // code
235     //        }
236     //        _ => {
237     //            // don't use Q
238     //        }
239     //    }
240     //    _ => {
241     //        // don't use Q
242     //    }
243     // };
244     // ```
245     //
246     // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
247     // invalid value, which is UB.
248     //
249     // In order to fix this, we would either need to show that the discriminant computation of
250     // `place` is computed in all branches, including the `otherwise` branch, or we would need
251     // another analysis pass to determine that the place is fully initialized. It might even be best
252     // to have the hoisting be performed in a different pass and just do the CFG changing in this
253     // pass.
254     for (place, proj) in place.iter_projections() {
255         match proj {
256             // Dereferencing in the computation of `place` might cause issues from one of two
257             // categories. First, the referent might be invalid. We protect against this by
258             // dereferencing references only (not pointers). Second, the use of a reference may
259             // invalidate other references that are used later (for aliasing reasons). Consider
260             // where such an invalidated reference may appear:
261             //  - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
262             //    cannot contain referenced data.
263             //  - In `BBU`: Not possible since that block contains only the `unreachable` terminator
264             //  - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
265             //    reaching that block in the input to our transformation, and so any data
266             //    invalidated by that computation could not have been used there.
267             //  - In `BB9`: Not possible since control flow might have reached `BB9` via the
268             //    `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
269             //    have invalidated the data when computing `discriminant(P)`
270             // So dereferencing here is correct.
271             ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() {
272                 ty::Ref(..) => {}
273                 _ => return false,
274             },
275             // Field projections are always valid
276             ProjectionElem::Field(..) => {}
277             // We cannot allow
278             // downcasts either, since the correctness of the downcast may depend on the parent
279             // branch being taken. An easy example of this is
280             // ```
281             // Q = discriminant(_3)
282             // P = (_3 as Variant)
283             // ```
284             // However, checking if the child and parent place are the same and only erroring then
285             // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
286             // be replaced by another optimization pass with any other condition that can be proven
287             // equivalent.
288             ProjectionElem::Downcast(..) => {
289                 return false;
290             }
291             // We cannot allow indexing since the index may be out of bounds.
292             _ => {
293                 return false;
294             }
295         }
296     }
297     true
298 }
299
300 #[derive(Debug)]
301 struct OptimizationData<'tcx> {
302     destination: BasicBlock,
303     child_place: Place<'tcx>,
304     child_ty: Ty<'tcx>,
305     child_source: SourceInfo,
306 }
307
308 fn evaluate_candidate<'tcx>(
309     tcx: TyCtxt<'tcx>,
310     body: &Body<'tcx>,
311     parent: BasicBlock,
312 ) -> Option<OptimizationData<'tcx>> {
313     let bbs = &body.basic_blocks;
314     let TerminatorKind::SwitchInt {
315         targets,
316         discr: parent_discr,
317     } = &bbs[parent].terminator().kind else {
318         return None
319     };
320     let parent_ty = parent_discr.ty(body.local_decls(), tcx);
321     let parent_dest = {
322         let poss = targets.otherwise();
323         // If the fallthrough on the parent is trivially unreachable, we can let the
324         // children choose the destination
325         if bbs[poss].statements.len() == 0
326             && bbs[poss].terminator().kind == TerminatorKind::Unreachable
327         {
328             None
329         } else {
330             Some(poss)
331         }
332     };
333     let (_, child) = targets.iter().next()?;
334     let child_terminator = &bbs[child].terminator();
335     let TerminatorKind::SwitchInt {
336         targets: child_targets,
337         discr: child_discr,
338     } = &child_terminator.kind else {
339         return None
340     };
341     let child_ty = child_discr.ty(body.local_decls(), tcx);
342     if child_ty != parent_ty {
343         return None;
344     }
345     let Some(StatementKind::Assign(boxed))
346         = &bbs[child].statements.first().map(|x| &x.kind) else {
347         return None;
348     };
349     let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
350         return None;
351     };
352     let destination = parent_dest.unwrap_or(child_targets.otherwise());
353
354     // Verify that the optimization is legal in general
355     // We can hoist evaluating the child discriminant out of the branch
356     if !may_hoist(tcx, body, *child_place) {
357         return None;
358     }
359
360     // Verify that the optimization is legal for each branch
361     for (value, child) in targets.iter() {
362         if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
363             return None;
364         }
365     }
366     Some(OptimizationData {
367         destination,
368         child_place: *child_place,
369         child_ty,
370         child_source: child_terminator.source_info,
371     })
372 }
373
374 fn verify_candidate_branch<'tcx>(
375     branch: &BasicBlockData<'tcx>,
376     value: u128,
377     place: Place<'tcx>,
378     destination: BasicBlock,
379 ) -> bool {
380     // In order for the optimization to be correct, the branch must...
381     // ...have exactly one statement
382     if branch.statements.len() != 1 {
383         return false;
384     }
385     // ...assign the discriminant of `place` in that statement
386     let StatementKind::Assign(boxed) = &branch.statements[0].kind else {
387         return false
388     };
389     let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else {
390         return false
391     };
392     if *from_place != place {
393         return false;
394     }
395     // ...make that assignment to a local
396     if discr_place.projection.len() != 0 {
397         return false;
398     }
399     // ...terminate on a `SwitchInt` that invalidates that local
400     let TerminatorKind::SwitchInt{ discr: switch_op, targets, .. } = &branch.terminator().kind else {
401         return false
402     };
403     if *switch_op != Operand::Move(*discr_place) {
404         return false;
405     }
406     // ...fall through to `destination` if the switch misses
407     if destination != targets.otherwise() {
408         return false;
409     }
410     // ...have a branch for value `value`
411     let mut iter = targets.iter();
412     let Some((target_value, _)) = iter.next() else {
413         return false;
414     };
415     if target_value != value {
416         return false;
417     }
418     // ...and have no more branches
419     if let Some(_) = iter.next() {
420         return false;
421     }
422     return true;
423 }