1 use rustc_middle::mir::patch::MirPatch;
2 use rustc_middle::mir::*;
3 use rustc_middle::ty::{self, Ty, TyCtxt};
6 use super::simplify::simplify_cfg;
8 /// This pass optimizes something like
9 /// ```ignore (syntax-highlighting-only)
10 /// let x: Option<()>;
11 /// let y: Option<()>;
13 /// (Some(_), Some(_)) => {0},
17 /// into something like
18 /// ```ignore (syntax-highlighting-only)
19 /// let x: Option<()>;
20 /// let y: Option<()>;
21 /// let discriminant_x = std::mem::discriminant(x);
22 /// let discriminant_y = std::mem::discriminant(y);
23 /// if discriminant_x == discriminant_y {
27 /// } // | Actually the same bb
29 /// 1 // <--------------
33 /// Specifically, it looks for instances of control flow like this:
38 /// |---------------| ============================
39 /// | ... | /------> | BBC |
40 /// |---------------| | |--------------------------|
41 /// | switchInt(Q) | | | _cl = discriminant(P) |
42 /// | c | --------/ |--------------------------|
43 /// | d | -------\ | switchInt(_cl) |
44 /// | ... | | | c | ---> BBC.2
45 /// | otherwise | --\ | /--- | otherwise |
46 /// ================= | | | ============================
48 /// ================= | | |
49 /// | BBU | <-| | | ============================
50 /// |---------------| | \-------> | BBD |
51 /// |---------------| | | |--------------------------|
52 /// | unreachable | | | | _dl = discriminant(P) |
53 /// ================= | | |--------------------------|
54 /// | | | switchInt(_dl) |
55 /// ================= | | | d | ---> BBD.2
56 /// | BB9 | <--------------- | otherwise |
57 /// |---------------| ============================
61 /// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
63 /// - `BB1` is `parent` and `BBC, BBD` are children
64 /// - `P` is `child_place`
65 /// - `child_ty` is the type of `_cl`.
66 /// - `Q` is `parent_op`.
67 /// - `parent_ty` is the type of `Q`.
68 /// - `BB9` is `destination`
69 /// All this is then transformed into:
72 /// =======================
74 /// |---------------------| ============================
75 /// | ... | /------> | BBEq |
76 /// | _s = discriminant(P)| | |--------------------------|
77 /// | _t = Ne(Q, _s) | | |--------------------------|
78 /// |---------------------| | | switchInt(Q) |
79 /// | switchInt(_t) | | | c | ---> BBC.2
80 /// | false | --------/ | d | ---> BBD.2
81 /// | otherwise | ---------------- | otherwise |
82 /// ======================= | ============================
84 /// ================= |
85 /// | BB9 | <-----------/
91 /// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
92 /// The filter on which `P` are allowed (together with discussion of its correctness) is found in
94 pub struct EarlyOtherwiseBranch;
96 impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
97 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
98 sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts
101 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
102 trace!("running EarlyOtherwiseBranch on {:?}", body.source);
104 let mut should_cleanup = false;
106 // Also consider newly generated bbs in the same pass
107 for i in 0..body.basic_blocks.len() {
108 let bbs = &*body.basic_blocks;
109 let parent = BasicBlock::from_usize(i);
110 let Some(opt_data) = evaluate_candidate(tcx, body, parent) else {
114 if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_data)) {
118 trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data);
120 should_cleanup = true;
122 let TerminatorKind::SwitchInt {
124 switch_ty: parent_ty,
125 targets: parent_targets
126 } = &bbs[parent].terminator().kind else {
129 // Always correct since we can only switch on `Copy` types
130 let parent_op = match parent_op {
131 Operand::Move(x) => Operand::Copy(*x),
132 Operand::Copy(x) => Operand::Copy(*x),
133 Operand::Constant(x) => Operand::Constant(x.clone()),
135 let statements_before = bbs[parent].statements.len();
136 let parent_end = Location { block: parent, statement_index: statements_before };
138 let mut patch = MirPatch::new(body);
140 // create temp to store second discriminant in, `_s` in example above
141 let second_discriminant_temp =
142 patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
144 patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
146 // create assignment of discriminant
149 Place::from(second_discriminant_temp),
150 Rvalue::Discriminant(opt_data.child_place),
153 // create temp to store inequality comparison between the two discriminants, `_t` in
155 let nequal = BinOp::Ne;
156 let comp_res_type = nequal.ty(tcx, *parent_ty, opt_data.child_ty);
157 let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
158 patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
160 // create inequality comparison between the two discriminants
161 let comp_rvalue = Rvalue::BinaryOp(
163 Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
167 StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
170 let eq_new_targets = parent_targets.iter().map(|(value, child)| {
171 let TerminatorKind::SwitchInt{ targets, .. } = &bbs[child].terminator().kind else {
174 (value, targets.target_for_value(value))
176 let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
178 // Create `bbEq` in example above
179 let eq_switch = BasicBlockData::new(Some(Terminator {
180 source_info: bbs[parent].terminator().source_info,
181 kind: TerminatorKind::SwitchInt {
182 // switch on the first discriminant, so we can mark the second one as dead
184 switch_ty: opt_data.child_ty,
189 let eq_bb = patch.new_block(eq_switch);
191 // Jump to it on the basis of the inequality comparison
192 let true_case = opt_data.destination;
193 let false_case = eq_bb;
194 patch.patch_terminator(
198 Operand::Move(Place::from(comp_temp)),
204 // generate StorageDead for the second_discriminant_temp not in use anymore
205 patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
207 // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
209 for bb in [false_case, true_case].iter() {
211 Location { block: *bb, statement_index: 0 },
212 StatementKind::StorageDead(comp_temp),
219 // Since this optimization adds new basic blocks and invalidates others,
220 // clean up the cfg to make it nicer for other passes
222 simplify_cfg(tcx, body);
227 /// Returns true if computing the discriminant of `place` may be hoisted out of the branch
228 fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
229 // FIXME(JakobDegen): This is unsound. Someone could write code like this:
232 // if discriminant(P) == otherwise {
233 // let ptr = &mut Q as *mut _ as *mut u8;
234 // unsafe { *ptr = 10; } // Any invalid value for the type
252 // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
253 // invalid value, which is UB.
255 // In order to fix this, we would either need to show that the discriminant computation of
256 // `place` is computed in all branches, including the `otherwise` branch, or we would need
257 // another analysis pass to determine that the place is fully initialized. It might even be best
258 // to have the hoisting be performed in a different pass and just do the CFG changing in this
260 for (place, proj) in place.iter_projections() {
262 // Dereferencing in the computation of `place` might cause issues from one of two
263 // categories. First, the referent might be invalid. We protect against this by
264 // dereferencing references only (not pointers). Second, the use of a reference may
265 // invalidate other references that are used later (for aliasing reasons). Consider
266 // where such an invalidated reference may appear:
267 // - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
268 // cannot contain referenced data.
269 // - In `BBU`: Not possible since that block contains only the `unreachable` terminator
270 // - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
271 // reaching that block in the input to our transformation, and so any data
272 // invalidated by that computation could not have been used there.
273 // - In `BB9`: Not possible since control flow might have reached `BB9` via the
274 // `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
275 // have invalidated the data when computing `discriminant(P)`
276 // So dereferencing here is correct.
277 ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() {
281 // Field projections are always valid
282 ProjectionElem::Field(..) => {}
284 // downcasts either, since the correctness of the downcast may depend on the parent
285 // branch being taken. An easy example of this is
287 // Q = discriminant(_3)
288 // P = (_3 as Variant)
290 // However, checking if the child and parent place are the same and only erroring then
291 // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
292 // be replaced by another optimization pass with any other condition that can be proven
294 ProjectionElem::Downcast(..) => {
297 // We cannot allow indexing since the index may be out of bounds.
307 struct OptimizationData<'tcx> {
308 destination: BasicBlock,
309 child_place: Place<'tcx>,
311 child_source: SourceInfo,
314 fn evaluate_candidate<'tcx>(
318 ) -> Option<OptimizationData<'tcx>> {
319 let bbs = &body.basic_blocks;
320 let TerminatorKind::SwitchInt {
322 switch_ty: parent_ty,
324 } = &bbs[parent].terminator().kind else {
328 let poss = targets.otherwise();
329 // If the fallthrough on the parent is trivially unreachable, we can let the
330 // children choose the destination
331 if bbs[poss].statements.len() == 0
332 && bbs[poss].terminator().kind == TerminatorKind::Unreachable
339 let (_, child) = targets.iter().next()?;
340 let child_terminator = &bbs[child].terminator();
341 let TerminatorKind::SwitchInt {
343 targets: child_targets,
345 } = &child_terminator.kind else {
348 if child_ty != parent_ty {
351 let Some(StatementKind::Assign(boxed))
352 = &bbs[child].statements.first().map(|x| &x.kind) else {
355 let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
358 let destination = parent_dest.unwrap_or(child_targets.otherwise());
360 // Verify that the optimization is legal in general
361 // We can hoist evaluating the child discriminant out of the branch
362 if !may_hoist(tcx, body, *child_place) {
366 // Verify that the optimization is legal for each branch
367 for (value, child) in targets.iter() {
368 if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
372 Some(OptimizationData {
374 child_place: *child_place,
376 child_source: child_terminator.source_info,
380 fn verify_candidate_branch<'tcx>(
381 branch: &BasicBlockData<'tcx>,
384 destination: BasicBlock,
386 // In order for the optimization to be correct, the branch must...
387 // ...have exactly one statement
388 if branch.statements.len() != 1 {
391 // ...assign the discriminant of `place` in that statement
392 let StatementKind::Assign(boxed) = &branch.statements[0].kind else {
395 let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else {
398 if *from_place != place {
401 // ...make that assignment to a local
402 if discr_place.projection.len() != 0 {
405 // ...terminate on a `SwitchInt` that invalidates that local
406 let TerminatorKind::SwitchInt{ discr: switch_op, targets, .. } = &branch.terminator().kind else {
409 if *switch_op != Operand::Move(*discr_place) {
412 // ...fall through to `destination` if the switch misses
413 if destination != targets.otherwise() {
416 // ...have a branch for value `value`
417 let mut iter = targets.iter();
418 let Some((target_value, _)) = iter.next() else {
421 if target_value != value {
424 // ...and have no more branches
425 if let Some(_) = iter.next() {