]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/simplify_try.rs
Make verbose query description more useful.
[rust.git] / compiler / rustc_mir_transform / src / simplify_try.rs
1 //! The general point of the optimizations provided here is to simplify something like:
2 //!
3 //! ```rust
4 //! # fn foo<T, E>(x: Result<T, E>) -> Result<T, E> {
5 //! match x {
6 //!     Ok(x) => Ok(x),
7 //!     Err(x) => Err(x)
8 //! }
9 //! # }
10 //! ```
11 //!
12 //! into just `x`.
13
14 use crate::{simplify, MirPass};
15 use itertools::Itertools as _;
16 use rustc_index::{bit_set::BitSet, vec::IndexVec};
17 use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
18 use rustc_middle::mir::*;
19 use rustc_middle::ty::{self, List, Ty, TyCtxt};
20 use rustc_target::abi::VariantIdx;
21 use std::iter::{once, Enumerate, Peekable};
22 use std::slice::Iter;
23
/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
///
/// This is done by transforming basic blocks where the statements match:
///
/// ```ignore (MIR)
/// _LOCAL_TMP = ((_LOCAL_1 as Variant).FIELD: TY);
/// _TMP_2 = _LOCAL_TMP;
/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
/// discriminant(_LOCAL_0) = VAR_IDX;
/// ```
///
/// into:
///
/// ```ignore (MIR)
/// _LOCAL_0 = move _LOCAL_1
/// ```
///
/// The pass does nothing unless the `unsound_mir_opts` unstable option is enabled, because the
/// transformation can be unsound (FIXME 77359).
pub struct SimplifyArmIdentity;
41
42 #[derive(Debug)]
43 struct ArmIdentityInfo<'tcx> {
44     /// Storage location for the variant's field
45     local_temp_0: Local,
46     /// Storage location holding the variant being read from
47     local_1: Local,
48     /// The variant field being read from
49     vf_s0: VarField<'tcx>,
50     /// Index of the statement which loads the variant being read
51     get_variant_field_stmt: usize,
52
53     /// Tracks each assignment to a temporary of the variant's field
54     field_tmp_assignments: Vec<(Local, Local)>,
55
56     /// Storage location holding the variant's field that was read from
57     local_tmp_s1: Local,
58     /// Storage location holding the enum that we are writing to
59     local_0: Local,
60     /// The variant field being written to
61     vf_s1: VarField<'tcx>,
62
63     /// Storage location that the discriminant is being written to
64     set_discr_local: Local,
65     /// The variant being written
66     set_discr_var_idx: VariantIdx,
67
68     /// Index of the statement that should be overwritten as a move
69     stmt_to_overwrite: usize,
70     /// SourceInfo for the new move
71     source_info: SourceInfo,
72
73     /// Indices of matching Storage{Live,Dead} statements encountered.
74     /// (StorageLive index,, StorageDead index, Local)
75     storage_stmts: Vec<(usize, usize, Local)>,
76
77     /// The statements that should be removed (turned into nops)
78     stmts_to_remove: Vec<usize>,
79
80     /// Indices of debug variables that need to be adjusted to point to
81     // `{local_0}.{dbg_projection}`.
82     dbg_info_to_adjust: Vec<usize>,
83
84     /// The projection used to rewrite debug info.
85     dbg_projection: &'tcx List<PlaceElem<'tcx>>,
86 }
87
88 fn get_arm_identity_info<'a, 'tcx>(
89     stmts: &'a [Statement<'tcx>],
90     locals_count: usize,
91     debug_info: &'a [VarDebugInfo<'tcx>],
92 ) -> Option<ArmIdentityInfo<'tcx>> {
93     // This can't possibly match unless there are at least 3 statements in the block
94     // so fail fast on tiny blocks.
95     if stmts.len() < 3 {
96         return None;
97     }
98
99     let mut tmp_assigns = Vec::new();
100     let mut nop_stmts = Vec::new();
101     let mut storage_stmts = Vec::new();
102     let mut storage_live_stmts = Vec::new();
103     let mut storage_dead_stmts = Vec::new();
104
105     type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>;
106
107     fn is_storage_stmt(stmt: &Statement<'_>) -> bool {
108         matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_))
109     }
110
111     /// Eats consecutive Statements which match `test`, performing the specified `action` for each.
112     /// The iterator `stmt_iter` is not advanced if none were matched.
113     fn try_eat<'a, 'tcx>(
114         stmt_iter: &mut StmtIter<'a, 'tcx>,
115         test: impl Fn(&'a Statement<'tcx>) -> bool,
116         mut action: impl FnMut(usize, &'a Statement<'tcx>),
117     ) {
118         while stmt_iter.peek().map_or(false, |(_, stmt)| test(stmt)) {
119             let (idx, stmt) = stmt_iter.next().unwrap();
120
121             action(idx, stmt);
122         }
123     }
124
125     /// Eats consecutive `StorageLive` and `StorageDead` Statements.
126     /// The iterator `stmt_iter` is not advanced if none were found.
127     fn try_eat_storage_stmts(
128         stmt_iter: &mut StmtIter<'_, '_>,
129         storage_live_stmts: &mut Vec<(usize, Local)>,
130         storage_dead_stmts: &mut Vec<(usize, Local)>,
131     ) {
132         try_eat(stmt_iter, is_storage_stmt, |idx, stmt| {
133             if let StatementKind::StorageLive(l) = stmt.kind {
134                 storage_live_stmts.push((idx, l));
135             } else if let StatementKind::StorageDead(l) = stmt.kind {
136                 storage_dead_stmts.push((idx, l));
137             }
138         })
139     }
140
141     fn is_tmp_storage_stmt(stmt: &Statement<'_>) -> bool {
142         use rustc_middle::mir::StatementKind::Assign;
143         if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind {
144             place.as_local().is_some() && p.as_local().is_some()
145         } else {
146             false
147         }
148     }
149
150     /// Eats consecutive `Assign` Statements.
151     // The iterator `stmt_iter` is not advanced if none were found.
152     fn try_eat_assign_tmp_stmts(
153         stmt_iter: &mut StmtIter<'_, '_>,
154         tmp_assigns: &mut Vec<(Local, Local)>,
155         nop_stmts: &mut Vec<usize>,
156     ) {
157         try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| {
158             use rustc_middle::mir::StatementKind::Assign;
159             if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) =
160                 &stmt.kind
161             {
162                 tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap()));
163                 nop_stmts.push(idx);
164             }
165         })
166     }
167
168     fn find_storage_live_dead_stmts_for_local(
169         local: Local,
170         stmts: &[Statement<'_>],
171     ) -> Option<(usize, usize)> {
172         trace!("looking for {:?}", local);
173         let mut storage_live_stmt = None;
174         let mut storage_dead_stmt = None;
175         for (idx, stmt) in stmts.iter().enumerate() {
176             if stmt.kind == StatementKind::StorageLive(local) {
177                 storage_live_stmt = Some(idx);
178             } else if stmt.kind == StatementKind::StorageDead(local) {
179                 storage_dead_stmt = Some(idx);
180             }
181         }
182
183         Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX)))
184     }
185
186     // Try to match the expected MIR structure with the basic block we're processing.
187     // We want to see something that looks like:
188     // ```
189     // (StorageLive(_) | StorageDead(_));*
190     // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
191     // (StorageLive(_) | StorageDead(_));*
192     // (tmp_n+1 = tmp_n);*
193     // (StorageLive(_) | StorageDead(_));*
194     // (tmp_n+1 = tmp_n);*
195     // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp;
196     // discriminant(LOCAL_FROM) = VariantIdx;
197     // (StorageLive(_) | StorageDead(_));*
198     // ```
199     let mut stmt_iter = stmts.iter().enumerate().peekable();
200
201     try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
202
203     let (get_variant_field_stmt, stmt) = stmt_iter.next()?;
204     let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?;
205
206     try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
207
208     try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
209
210     try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
211
212     try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
213
214     let (idx, stmt) = stmt_iter.next()?;
215     let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?;
216     nop_stmts.push(idx);
217
218     let (idx, stmt) = stmt_iter.next()?;
219     let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?;
220     let discr_stmt_source_info = stmt.source_info;
221     nop_stmts.push(idx);
222
223     try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
224
225     for (live_idx, live_local) in storage_live_stmts {
226         if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) {
227             let (dead_idx, _) = storage_dead_stmts.swap_remove(i);
228             storage_stmts.push((live_idx, dead_idx, live_local));
229
230             if live_local == local_tmp_s0 {
231                 nop_stmts.push(get_variant_field_stmt);
232             }
233         }
234     }
235     // We sort primitive usize here so we can use unstable sort
236     nop_stmts.sort_unstable();
237
238     // Use one of the statements we're going to discard between the point
239     // where the storage location for the variant field becomes live and
240     // is killed.
241     let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?;
242     let stmt_to_overwrite =
243         nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx);
244
245     let mut tmp_assigned_vars = BitSet::new_empty(locals_count);
246     for (l, r) in &tmp_assigns {
247         tmp_assigned_vars.insert(*l);
248         tmp_assigned_vars.insert(*r);
249     }
250
251     let dbg_info_to_adjust: Vec<_> = debug_info
252         .iter()
253         .enumerate()
254         .filter_map(|(i, var_info)| {
255             if let VarDebugInfoContents::Place(p) = var_info.value {
256                 if tmp_assigned_vars.contains(p.local) {
257                     return Some(i);
258                 }
259             }
260
261             None
262         })
263         .collect();
264
265     Some(ArmIdentityInfo {
266         local_temp_0: local_tmp_s0,
267         local_1,
268         vf_s0,
269         get_variant_field_stmt,
270         field_tmp_assignments: tmp_assigns,
271         local_tmp_s1,
272         local_0,
273         vf_s1,
274         set_discr_local,
275         set_discr_var_idx,
276         stmt_to_overwrite: *stmt_to_overwrite?,
277         source_info: discr_stmt_source_info,
278         storage_stmts,
279         stmts_to_remove: nop_stmts,
280         dbg_info_to_adjust,
281         dbg_projection,
282     })
283 }
284
285 fn optimization_applies<'tcx>(
286     opt_info: &ArmIdentityInfo<'tcx>,
287     local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
288     local_uses: &IndexVec<Local, usize>,
289     var_debug_info: &[VarDebugInfo<'tcx>],
290 ) -> bool {
291     trace!("testing if optimization applies...");
292
293     // FIXME(wesleywiser): possibly relax this restriction?
294     if opt_info.local_0 == opt_info.local_1 {
295         trace!("NO: moving into ourselves");
296         return false;
297     } else if opt_info.vf_s0 != opt_info.vf_s1 {
298         trace!("NO: the field-and-variant information do not match");
299         return false;
300     } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty {
301         // FIXME(Centril,oli-obk): possibly relax to same layout?
302         trace!("NO: source and target locals have different types");
303         return false;
304     } else if (opt_info.local_0, opt_info.vf_s0.var_idx)
305         != (opt_info.set_discr_local, opt_info.set_discr_var_idx)
306     {
307         trace!("NO: the discriminants do not match");
308         return false;
309     }
310
311     // Verify the assignment chain consists of the form b = a; c = b; d = c; etc...
312     if opt_info.field_tmp_assignments.is_empty() {
313         trace!("NO: no assignments found");
314         return false;
315     }
316     let mut last_assigned_to = opt_info.field_tmp_assignments[0].1;
317     let source_local = last_assigned_to;
318     for (l, r) in &opt_info.field_tmp_assignments {
319         if *r != last_assigned_to {
320             trace!("NO: found unexpected assignment {:?} = {:?}", l, r);
321             return false;
322         }
323
324         last_assigned_to = *l;
325     }
326
327     // Check that the first and last used locals are only used twice
328     // since they are of the form:
329     //
330     // ```
331     // _first = ((_x as Variant).n: ty);
332     // _n = _first;
333     // ...
334     // ((_y as Variant).n: ty) = _n;
335     // discriminant(_y) = z;
336     // ```
337     for (l, r) in &opt_info.field_tmp_assignments {
338         if local_uses[*l] != 2 {
339             warn!("NO: FAILED assignment chain local {:?} was used more than twice", l);
340             return false;
341         } else if local_uses[*r] != 2 {
342             warn!("NO: FAILED assignment chain local {:?} was used more than twice", r);
343             return false;
344         }
345     }
346
347     // Check that debug info only points to full Locals and not projections.
348     for dbg_idx in &opt_info.dbg_info_to_adjust {
349         let dbg_info = &var_debug_info[*dbg_idx];
350         if let VarDebugInfoContents::Place(p) = dbg_info.value {
351             if !p.projection.is_empty() {
352                 trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, p);
353                 return false;
354             }
355         }
356     }
357
358     if source_local != opt_info.local_temp_0 {
359         trace!(
360             "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}",
361             source_local,
362             opt_info.local_temp_0
363         );
364         return false;
365     } else if last_assigned_to != opt_info.local_tmp_s1 {
366         trace!(
367             "NO: end of assignment chain does not match written enum temp: {:?} != {:?}",
368             last_assigned_to,
369             opt_info.local_tmp_s1
370         );
371         return false;
372     }
373
374     trace!("SUCCESS: optimization applies!");
375     true
376 }
377
378 impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
379     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
380         // FIXME(77359): This optimization can result in unsoundness.
381         if !tcx.sess.opts.unstable_opts.unsound_mir_opts {
382             return;
383         }
384
385         let source = body.source;
386         trace!("running SimplifyArmIdentity on {:?}", source);
387
388         let local_uses = LocalUseCounter::get_local_uses(body);
389         for bb in body.basic_blocks.as_mut() {
390             if let Some(opt_info) =
391                 get_arm_identity_info(&bb.statements, body.local_decls.len(), &body.var_debug_info)
392             {
393                 trace!("got opt_info = {:#?}", opt_info);
394                 if !optimization_applies(
395                     &opt_info,
396                     &body.local_decls,
397                     &local_uses,
398                     &body.var_debug_info,
399                 ) {
400                     debug!("optimization skipped for {:?}", source);
401                     continue;
402                 }
403
404                 // Also remove unused Storage{Live,Dead} statements which correspond
405                 // to temps used previously.
406                 for (live_idx, dead_idx, local) in &opt_info.storage_stmts {
407                     // The temporary that we've read the variant field into is scoped to this block,
408                     // so we can remove the assignment.
409                     if *local == opt_info.local_temp_0 {
410                         bb.statements[opt_info.get_variant_field_stmt].make_nop();
411                     }
412
413                     for (left, right) in &opt_info.field_tmp_assignments {
414                         if local == left || local == right {
415                             bb.statements[*live_idx].make_nop();
416                             bb.statements[*dead_idx].make_nop();
417                         }
418                     }
419                 }
420
421                 // Right shape; transform
422                 for stmt_idx in opt_info.stmts_to_remove {
423                     bb.statements[stmt_idx].make_nop();
424                 }
425
426                 let stmt = &mut bb.statements[opt_info.stmt_to_overwrite];
427                 stmt.source_info = opt_info.source_info;
428                 stmt.kind = StatementKind::Assign(Box::new((
429                     opt_info.local_0.into(),
430                     Rvalue::Use(Operand::Move(opt_info.local_1.into())),
431                 )));
432
433                 bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop);
434
435                 // Fix the debug info to point to the right local
436                 for dbg_index in opt_info.dbg_info_to_adjust {
437                     let dbg_info = &mut body.var_debug_info[dbg_index];
438                     assert!(
439                         matches!(dbg_info.value, VarDebugInfoContents::Place(_)),
440                         "value was not a Place"
441                     );
442                     if let VarDebugInfoContents::Place(p) = &mut dbg_info.value {
443                         assert!(p.projection.is_empty());
444                         p.local = opt_info.local_0;
445                         p.projection = opt_info.dbg_projection;
446                     }
447                 }
448
449                 trace!("block is now {:?}", bb.statements);
450             }
451         }
452     }
453 }
454
455 struct LocalUseCounter {
456     local_uses: IndexVec<Local, usize>,
457 }
458
459 impl LocalUseCounter {
460     fn get_local_uses(body: &Body<'_>) -> IndexVec<Local, usize> {
461         let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) };
462         counter.visit_body(body);
463         counter.local_uses
464     }
465 }
466
467 impl Visitor<'_> for LocalUseCounter {
468     fn visit_local(&mut self, local: Local, context: PlaceContext, _location: Location) {
469         if context.is_storage_marker()
470             || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo)
471         {
472             return;
473         }
474
475         self.local_uses[local] += 1;
476     }
477 }
478
479 /// Match on:
480 /// ```ignore (MIR)
481 /// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
482 /// ```
483 fn match_get_variant_field<'tcx>(
484     stmt: &Statement<'tcx>,
485 ) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> {
486     match &stmt.kind {
487         StatementKind::Assign(box (
488             place_into,
489             Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)),
490         )) => {
491             let local_into = place_into.as_local()?;
492             let (local_from, vf) = match_variant_field_place(*pf)?;
493             Some((local_into, local_from, vf, pf.projection))
494         }
495         _ => None,
496     }
497 }
498
499 /// Match on:
500 /// ```ignore (MIR)
501 /// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO;
502 /// ```
503 fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> {
504     match &stmt.kind {
505         StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => {
506             let local_into = place_into.as_local()?;
507             let (local_from, vf) = match_variant_field_place(*place_from)?;
508             Some((local_into, local_from, vf))
509         }
510         _ => None,
511     }
512 }
513
514 /// Match on:
515 /// ```ignore (MIR)
516 /// discriminant(_LOCAL_TO_SET) = VAR_IDX;
517 /// ```
518 fn match_set_discr(stmt: &Statement<'_>) -> Option<(Local, VariantIdx)> {
519     match &stmt.kind {
520         StatementKind::SetDiscriminant { place, variant_index } => {
521             Some((place.as_local()?, *variant_index))
522         }
523         _ => None,
524     }
525 }
526
527 #[derive(PartialEq, Debug)]
528 struct VarField<'tcx> {
529     field: Field,
530     field_ty: Ty<'tcx>,
531     var_idx: VariantIdx,
532 }
533
534 /// Match on `((_LOCAL as Variant).FIELD: TY)`.
535 fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarField<'tcx>)> {
536     match place.as_ref() {
537         PlaceRef {
538             local,
539             projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)],
540         } => Some((local, VarField { field, field_ty: ty, var_idx })),
541         _ => None,
542     }
543 }
544
545 /// Simplifies `SwitchInt(_) -> [targets]`,
546 /// where all the `targets` have the same form,
547 /// into `goto -> target_first`.
548 pub struct SimplifyBranchSame;
549
550 impl<'tcx> MirPass<'tcx> for SimplifyBranchSame {
551     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
552         // This optimization is disabled by default for now due to
553         // soundness concerns; see issue #89485 and PR #89489.
554         if !tcx.sess.opts.unstable_opts.unsound_mir_opts {
555             return;
556         }
557
558         trace!("Running SimplifyBranchSame on {:?}", body.source);
559         let finder = SimplifyBranchSameOptimizationFinder { body, tcx };
560         let opts = finder.find();
561
562         let did_remove_blocks = opts.len() > 0;
563         for opt in opts.iter() {
564             trace!("SUCCESS: Applying optimization {:?}", opt);
565             // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
566             body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind =
567                 TerminatorKind::Goto { target: opt.bb_to_goto };
568         }
569
570         if did_remove_blocks {
571             // We have dead blocks now, so remove those.
572             simplify::remove_dead_blocks(tcx, body);
573         }
574     }
575 }
576
577 #[derive(Debug)]
578 struct SimplifyBranchSameOptimization {
579     /// All basic blocks are equal so go to this one
580     bb_to_goto: BasicBlock,
581     /// Basic block where the terminator can be simplified to a goto
582     bb_to_opt_terminator: BasicBlock,
583 }
584
585 struct SwitchTargetAndValue {
586     target: BasicBlock,
587     // None in case of the `otherwise` case
588     value: Option<u128>,
589 }
590
591 struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
592     body: &'a Body<'tcx>,
593     tcx: TyCtxt<'tcx>,
594 }
595
596 impl<'tcx> SimplifyBranchSameOptimizationFinder<'_, 'tcx> {
597     fn find(&self) -> Vec<SimplifyBranchSameOptimization> {
598         self.body
599             .basic_blocks
600             .iter_enumerated()
601             .filter_map(|(bb_idx, bb)| {
602                 let (discr_switched_on, targets_and_values) = match &bb.terminator().kind {
603                     TerminatorKind::SwitchInt { targets, discr, .. } => {
604                         let targets_and_values: Vec<_> = targets.iter()
605                             .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) })
606                             .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None }))
607                             .collect();
608                         (discr, targets_and_values)
609                     },
610                     _ => return None,
611                 };
612
613                 // find the adt that has its discriminant read
614                 // assuming this must be the last statement of the block
615                 let adt_matched_on = match &bb.statements.last()?.kind {
616                     StatementKind::Assign(box (place, rhs))
617                         if Some(*place) == discr_switched_on.place() =>
618                     {
619                         match rhs {
620                             Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place,
621                             _ => {
622                                 trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs);
623                                 return None;
624                             }
625                         }
626                     }
627                     other => {
628                         trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other);
629                         return None
630                     },
631                 };
632
633                 let mut iter_bbs_reachable = targets_and_values
634                     .iter()
635                     .map(|target_and_value| (target_and_value, &self.body.basic_blocks[target_and_value.target]))
636                     .filter(|(_, bb)| {
637                         // Reaching `unreachable` is UB so assume it doesn't happen.
638                         bb.terminator().kind != TerminatorKind::Unreachable
639                     })
640                     .peekable();
641
642                 let bb_first = iter_bbs_reachable.peek().map_or(&targets_and_values[0], |(idx, _)| *idx);
643                 let mut all_successors_equivalent = StatementEquality::TrivialEqual;
644
645                 // All successor basic blocks must be equal or contain statements that are pairwise considered equal.
646                 for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() {
647                     let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup
648                                             && bb_l.terminator().kind == bb_r.terminator().kind
649                                             && bb_l.statements.len() == bb_r.statements.len();
650                     let statement_check = || {
651                         bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| {
652                             let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r);
653                             if matches!(stmt_equality, StatementEquality::NotEqual) {
654                                 // short circuit
655                                 None
656                             } else {
657                                 Some(acc.combine(&stmt_equality))
658                             }
659                         })
660                         .unwrap_or(StatementEquality::NotEqual)
661                     };
662                     if !trivial_checks {
663                         all_successors_equivalent = StatementEquality::NotEqual;
664                         break;
665                     }
666                     all_successors_equivalent = all_successors_equivalent.combine(&statement_check());
667                 };
668
669                 match all_successors_equivalent{
670                     StatementEquality::TrivialEqual => {
671                         // statements are trivially equal, so just take first
672                         trace!("Statements are trivially equal");
673                         Some(SimplifyBranchSameOptimization {
674                             bb_to_goto: bb_first.target,
675                             bb_to_opt_terminator: bb_idx,
676                         })
677                     }
678                     StatementEquality::ConsideredEqual(bb_to_choose) => {
679                         trace!("Statements are considered equal");
680                         Some(SimplifyBranchSameOptimization {
681                             bb_to_goto: bb_to_choose,
682                             bb_to_opt_terminator: bb_idx,
683                         })
684                     }
685                     StatementEquality::NotEqual => {
686                         trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx);
687                         None
688                     }
689                 }
690             })
691             .collect()
692     }
693
694     /// Tests if two statements can be considered equal
695     ///
696     /// Statements can be trivially equal if the kinds match.
697     /// But they can also be considered equal in the following case A:
698     /// ```ignore (MIR)
699     /// discriminant(_0) = 0;   // bb1
700     /// _0 = move _1;           // bb2
701     /// ```
702     /// In this case the two statements are equal iff
703     /// - `_0` is an enum where the variant index 0 is fieldless, and
704     /// -  bb1 was targeted by a switch where the discriminant of `_1` was switched on
705     fn statement_equality(
706         &self,
707         adt_matched_on: Place<'tcx>,
708         x: &Statement<'tcx>,
709         x_target_and_value: &SwitchTargetAndValue,
710         y: &Statement<'tcx>,
711         y_target_and_value: &SwitchTargetAndValue,
712     ) -> StatementEquality {
713         let helper = |rhs: &Rvalue<'tcx>,
714                       place: &Place<'tcx>,
715                       variant_index: VariantIdx,
716                       switch_value: u128,
717                       side_to_choose| {
718             let place_type = place.ty(self.body, self.tcx).ty;
719             let adt = match *place_type.kind() {
720                 ty::Adt(adt, _) if adt.is_enum() => adt,
721                 _ => return StatementEquality::NotEqual,
722             };
723             // We need to make sure that the switch value that targets the bb with
724             // SetDiscriminant is the same as the variant discriminant.
725             let variant_discr = adt.discriminant_for_variant(self.tcx, variant_index).val;
726             if variant_discr != switch_value {
727                 trace!(
728                     "NO: variant discriminant {} does not equal switch value {}",
729                     variant_discr,
730                     switch_value
731                 );
732                 return StatementEquality::NotEqual;
733             }
734             let variant_is_fieldless = adt.variant(variant_index).fields.is_empty();
735             if !variant_is_fieldless {
736                 trace!("NO: variant {:?} was not fieldless", variant_index);
737                 return StatementEquality::NotEqual;
738             }
739
740             match rhs {
741                 Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => {
742                     StatementEquality::ConsideredEqual(side_to_choose)
743                 }
744                 _ => {
745                     trace!(
746                         "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}",
747                         rhs,
748                         adt_matched_on
749                     );
750                     StatementEquality::NotEqual
751                 }
752             }
753         };
754         match (&x.kind, &y.kind) {
755             // trivial case
756             (x, y) if x == y => StatementEquality::TrivialEqual,
757
758             // check for case A
759             (
760                 StatementKind::Assign(box (_, rhs)),
761                 &StatementKind::SetDiscriminant { ref place, variant_index },
762             ) if y_target_and_value.value.is_some() => {
763                 // choose basic block of x, as that has the assign
764                 helper(
765                     rhs,
766                     place,
767                     variant_index,
768                     y_target_and_value.value.unwrap(),
769                     x_target_and_value.target,
770                 )
771             }
772             (
773                 &StatementKind::SetDiscriminant { ref place, variant_index },
774                 &StatementKind::Assign(box (_, ref rhs)),
775             ) if x_target_and_value.value.is_some() => {
776                 // choose basic block of y, as that has the assign
777                 helper(
778                     rhs,
779                     place,
780                     variant_index,
781                     x_target_and_value.value.unwrap(),
782                     y_target_and_value.target,
783                 )
784             }
785             _ => {
786                 trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y);
787                 StatementEquality::NotEqual
788             }
789         }
790     }
791 }
792
793 #[derive(Copy, Clone, Eq, PartialEq)]
794 enum StatementEquality {
795     /// The two statements are trivially equal; same kind
796     TrivialEqual,
797     /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization.
798     /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index
799     ConsideredEqual(BasicBlock),
800     /// The two statements are not equal
801     NotEqual,
802 }
803
804 impl StatementEquality {
805     fn combine(&self, other: &StatementEquality) -> StatementEquality {
806         use StatementEquality::*;
807         match (self, other) {
808             (TrivialEqual, TrivialEqual) => TrivialEqual,
809             (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => {
810                 ConsideredEqual(*b)
811             }
812             (ConsideredEqual(b1), ConsideredEqual(b2)) => {
813                 if b1 == b2 {
814                     ConsideredEqual(*b1)
815                 } else {
816                     NotEqual
817                 }
818             }
819             (_, NotEqual) | (NotEqual, _) => NotEqual,
820         }
821     }
822 }