]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir/src/transform/simplify_try.rs
Fix miscompile in SimplifyBranchSame
[rust.git] / compiler / rustc_mir / src / transform / simplify_try.rs
1 //! The general point of the optimizations provided here is to simplify something like:
2 //!
3 //! ```rust
4 //! match x {
5 //!     Ok(x) => Ok(x),
6 //!     Err(x) => Err(x)
7 //! }
8 //! ```
9 //!
10 //! into just `x`.
11
12 use crate::transform::{simplify, MirPass};
13 use itertools::Itertools as _;
14 use rustc_index::{bit_set::BitSet, vec::IndexVec};
15 use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
16 use rustc_middle::mir::*;
17 use rustc_middle::ty::{self, List, Ty, TyCtxt};
18 use rustc_target::abi::VariantIdx;
19 use std::iter::{once, Enumerate, Peekable};
20 use std::slice::Iter;
21
22 /// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
23 ///
24 /// This is done by transforming basic blocks where the statements match:
25 ///
26 /// ```rust
27 /// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY );
28 /// _TMP_2 = _LOCAL_TMP;
29 /// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
30 /// discriminant(_LOCAL_0) = VAR_IDX;
31 /// ```
32 ///
33 /// into:
34 ///
35 /// ```rust
36 /// _LOCAL_0 = move _LOCAL_1
37 /// ```
38 pub struct SimplifyArmIdentity;
39
40 #[derive(Debug)]
41 struct ArmIdentityInfo<'tcx> {
42     /// Storage location for the variant's field
43     local_temp_0: Local,
44     /// Storage location holding the variant being read from
45     local_1: Local,
46     /// The variant field being read from
47     vf_s0: VarField<'tcx>,
48     /// Index of the statement which loads the variant being read
49     get_variant_field_stmt: usize,
50
51     /// Tracks each assignment to a temporary of the variant's field
52     field_tmp_assignments: Vec<(Local, Local)>,
53
54     /// Storage location holding the variant's field that was read from
55     local_tmp_s1: Local,
56     /// Storage location holding the enum that we are writing to
57     local_0: Local,
58     /// The variant field being written to
59     vf_s1: VarField<'tcx>,
60
61     /// Storage location that the discriminant is being written to
62     set_discr_local: Local,
63     /// The variant being written
64     set_discr_var_idx: VariantIdx,
65
66     /// Index of the statement that should be overwritten as a move
67     stmt_to_overwrite: usize,
68     /// SourceInfo for the new move
69     source_info: SourceInfo,
70
71     /// Indices of matching Storage{Live,Dead} statements encountered.
72     /// (StorageLive index,, StorageDead index, Local)
73     storage_stmts: Vec<(usize, usize, Local)>,
74
75     /// The statements that should be removed (turned into nops)
76     stmts_to_remove: Vec<usize>,
77
78     /// Indices of debug variables that need to be adjusted to point to
79     // `{local_0}.{dbg_projection}`.
80     dbg_info_to_adjust: Vec<usize>,
81
82     /// The projection used to rewrite debug info.
83     dbg_projection: &'tcx List<PlaceElem<'tcx>>,
84 }
85
86 fn get_arm_identity_info<'a, 'tcx>(
87     stmts: &'a [Statement<'tcx>],
88     locals_count: usize,
89     debug_info: &'a [VarDebugInfo<'tcx>],
90 ) -> Option<ArmIdentityInfo<'tcx>> {
91     // This can't possibly match unless there are at least 3 statements in the block
92     // so fail fast on tiny blocks.
93     if stmts.len() < 3 {
94         return None;
95     }
96
97     let mut tmp_assigns = Vec::new();
98     let mut nop_stmts = Vec::new();
99     let mut storage_stmts = Vec::new();
100     let mut storage_live_stmts = Vec::new();
101     let mut storage_dead_stmts = Vec::new();
102
103     type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>;
104
105     fn is_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool {
106         matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_))
107     }
108
109     /// Eats consecutive Statements which match `test`, performing the specified `action` for each.
110     /// The iterator `stmt_iter` is not advanced if none were matched.
111     fn try_eat<'a, 'tcx>(
112         stmt_iter: &mut StmtIter<'a, 'tcx>,
113         test: impl Fn(&'a Statement<'tcx>) -> bool,
114         mut action: impl FnMut(usize, &'a Statement<'tcx>),
115     ) {
116         while stmt_iter.peek().map(|(_, stmt)| test(stmt)).unwrap_or(false) {
117             let (idx, stmt) = stmt_iter.next().unwrap();
118
119             action(idx, stmt);
120         }
121     }
122
123     /// Eats consecutive `StorageLive` and `StorageDead` Statements.
124     /// The iterator `stmt_iter` is not advanced if none were found.
125     fn try_eat_storage_stmts<'a, 'tcx>(
126         stmt_iter: &mut StmtIter<'a, 'tcx>,
127         storage_live_stmts: &mut Vec<(usize, Local)>,
128         storage_dead_stmts: &mut Vec<(usize, Local)>,
129     ) {
130         try_eat(stmt_iter, is_storage_stmt, |idx, stmt| {
131             if let StatementKind::StorageLive(l) = stmt.kind {
132                 storage_live_stmts.push((idx, l));
133             } else if let StatementKind::StorageDead(l) = stmt.kind {
134                 storage_dead_stmts.push((idx, l));
135             }
136         })
137     }
138
139     fn is_tmp_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool {
140         use rustc_middle::mir::StatementKind::Assign;
141         if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind {
142             place.as_local().is_some() && p.as_local().is_some()
143         } else {
144             false
145         }
146     }
147
148     /// Eats consecutive `Assign` Statements.
149     // The iterator `stmt_iter` is not advanced if none were found.
150     fn try_eat_assign_tmp_stmts<'a, 'tcx>(
151         stmt_iter: &mut StmtIter<'a, 'tcx>,
152         tmp_assigns: &mut Vec<(Local, Local)>,
153         nop_stmts: &mut Vec<usize>,
154     ) {
155         try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| {
156             use rustc_middle::mir::StatementKind::Assign;
157             if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) =
158                 &stmt.kind
159             {
160                 tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap()));
161                 nop_stmts.push(idx);
162             }
163         })
164     }
165
166     fn find_storage_live_dead_stmts_for_local<'tcx>(
167         local: Local,
168         stmts: &[Statement<'tcx>],
169     ) -> Option<(usize, usize)> {
170         trace!("looking for {:?}", local);
171         let mut storage_live_stmt = None;
172         let mut storage_dead_stmt = None;
173         for (idx, stmt) in stmts.iter().enumerate() {
174             if stmt.kind == StatementKind::StorageLive(local) {
175                 storage_live_stmt = Some(idx);
176             } else if stmt.kind == StatementKind::StorageDead(local) {
177                 storage_dead_stmt = Some(idx);
178             }
179         }
180
181         Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX)))
182     }
183
184     // Try to match the expected MIR structure with the basic block we're processing.
185     // We want to see something that looks like:
186     // ```
187     // (StorageLive(_) | StorageDead(_));*
188     // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
189     // (StorageLive(_) | StorageDead(_));*
190     // (tmp_n+1 = tmp_n);*
191     // (StorageLive(_) | StorageDead(_));*
192     // (tmp_n+1 = tmp_n);*
193     // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp;
194     // discriminant(LOCAL_FROM) = VariantIdx;
195     // (StorageLive(_) | StorageDead(_));*
196     // ```
197     let mut stmt_iter = stmts.iter().enumerate().peekable();
198
199     try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
200
201     let (get_variant_field_stmt, stmt) = stmt_iter.next()?;
202     let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?;
203
204     try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
205
206     try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
207
208     try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
209
210     try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
211
212     let (idx, stmt) = stmt_iter.next()?;
213     let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?;
214     nop_stmts.push(idx);
215
216     let (idx, stmt) = stmt_iter.next()?;
217     let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?;
218     let discr_stmt_source_info = stmt.source_info;
219     nop_stmts.push(idx);
220
221     try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
222
223     for (live_idx, live_local) in storage_live_stmts {
224         if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) {
225             let (dead_idx, _) = storage_dead_stmts.swap_remove(i);
226             storage_stmts.push((live_idx, dead_idx, live_local));
227
228             if live_local == local_tmp_s0 {
229                 nop_stmts.push(get_variant_field_stmt);
230             }
231         }
232     }
233     // We sort primitive usize here so we can use unstable sort
234     nop_stmts.sort_unstable();
235
236     // Use one of the statements we're going to discard between the point
237     // where the storage location for the variant field becomes live and
238     // is killed.
239     let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?;
240     let stmt_to_overwrite =
241         nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx);
242
243     let mut tmp_assigned_vars = BitSet::new_empty(locals_count);
244     for (l, r) in &tmp_assigns {
245         tmp_assigned_vars.insert(*l);
246         tmp_assigned_vars.insert(*r);
247     }
248
249     let dbg_info_to_adjust: Vec<_> =
250         debug_info
251             .iter()
252             .enumerate()
253             .filter_map(|(i, var_info)| {
254                 if tmp_assigned_vars.contains(var_info.place.local) { Some(i) } else { None }
255             })
256             .collect();
257
258     Some(ArmIdentityInfo {
259         local_temp_0: local_tmp_s0,
260         local_1,
261         vf_s0,
262         get_variant_field_stmt,
263         field_tmp_assignments: tmp_assigns,
264         local_tmp_s1,
265         local_0,
266         vf_s1,
267         set_discr_local,
268         set_discr_var_idx,
269         stmt_to_overwrite: *stmt_to_overwrite?,
270         source_info: discr_stmt_source_info,
271         storage_stmts,
272         stmts_to_remove: nop_stmts,
273         dbg_info_to_adjust,
274         dbg_projection,
275     })
276 }
277
278 fn optimization_applies<'tcx>(
279     opt_info: &ArmIdentityInfo<'tcx>,
280     local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
281     local_uses: &IndexVec<Local, usize>,
282     var_debug_info: &[VarDebugInfo<'tcx>],
283 ) -> bool {
284     trace!("testing if optimization applies...");
285
286     // FIXME(wesleywiser): possibly relax this restriction?
287     if opt_info.local_0 == opt_info.local_1 {
288         trace!("NO: moving into ourselves");
289         return false;
290     } else if opt_info.vf_s0 != opt_info.vf_s1 {
291         trace!("NO: the field-and-variant information do not match");
292         return false;
293     } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty {
294         // FIXME(Centril,oli-obk): possibly relax to same layout?
295         trace!("NO: source and target locals have different types");
296         return false;
297     } else if (opt_info.local_0, opt_info.vf_s0.var_idx)
298         != (opt_info.set_discr_local, opt_info.set_discr_var_idx)
299     {
300         trace!("NO: the discriminants do not match");
301         return false;
302     }
303
304     // Verify the assigment chain consists of the form b = a; c = b; d = c; etc...
305     if opt_info.field_tmp_assignments.is_empty() {
306         trace!("NO: no assignments found");
307         return false;
308     }
309     let mut last_assigned_to = opt_info.field_tmp_assignments[0].1;
310     let source_local = last_assigned_to;
311     for (l, r) in &opt_info.field_tmp_assignments {
312         if *r != last_assigned_to {
313             trace!("NO: found unexpected assignment {:?} = {:?}", l, r);
314             return false;
315         }
316
317         last_assigned_to = *l;
318     }
319
320     // Check that the first and last used locals are only used twice
321     // since they are of the form:
322     //
323     // ```
324     // _first = ((_x as Variant).n: ty);
325     // _n = _first;
326     // ...
327     // ((_y as Variant).n: ty) = _n;
328     // discriminant(_y) = z;
329     // ```
330     for (l, r) in &opt_info.field_tmp_assignments {
331         if local_uses[*l] != 2 {
332             warn!("NO: FAILED assignment chain local {:?} was used more than twice", l);
333             return false;
334         } else if local_uses[*r] != 2 {
335             warn!("NO: FAILED assignment chain local {:?} was used more than twice", r);
336             return false;
337         }
338     }
339
340     // Check that debug info only points to full Locals and not projections.
341     for dbg_idx in &opt_info.dbg_info_to_adjust {
342         let dbg_info = &var_debug_info[*dbg_idx];
343         if !dbg_info.place.projection.is_empty() {
344             trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, dbg_info.place);
345             return false;
346         }
347     }
348
349     if source_local != opt_info.local_temp_0 {
350         trace!(
351             "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}",
352             source_local,
353             opt_info.local_temp_0
354         );
355         return false;
356     } else if last_assigned_to != opt_info.local_tmp_s1 {
357         trace!(
358             "NO: end of assignemnt chain does not match written enum temp: {:?} != {:?}",
359             last_assigned_to,
360             opt_info.local_tmp_s1
361         );
362         return false;
363     }
364
365     trace!("SUCCESS: optimization applies!");
366     true
367 }
368
369 impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
370     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
371         // FIXME(77359): This optimization can result in unsoundness.
372         if !tcx.sess.opts.debugging_opts.unsound_mir_opts {
373             return;
374         }
375
376         let source = body.source;
377         trace!("running SimplifyArmIdentity on {:?}", source);
378
379         let local_uses = LocalUseCounter::get_local_uses(body);
380         let (basic_blocks, local_decls, debug_info) =
381             body.basic_blocks_local_decls_mut_and_var_debug_info();
382         for bb in basic_blocks {
383             if let Some(opt_info) =
384                 get_arm_identity_info(&bb.statements, local_decls.len(), debug_info)
385             {
386                 trace!("got opt_info = {:#?}", opt_info);
387                 if !optimization_applies(&opt_info, local_decls, &local_uses, &debug_info) {
388                     debug!("optimization skipped for {:?}", source);
389                     continue;
390                 }
391
392                 // Also remove unused Storage{Live,Dead} statements which correspond
393                 // to temps used previously.
394                 for (live_idx, dead_idx, local) in &opt_info.storage_stmts {
395                     // The temporary that we've read the variant field into is scoped to this block,
396                     // so we can remove the assignment.
397                     if *local == opt_info.local_temp_0 {
398                         bb.statements[opt_info.get_variant_field_stmt].make_nop();
399                     }
400
401                     for (left, right) in &opt_info.field_tmp_assignments {
402                         if local == left || local == right {
403                             bb.statements[*live_idx].make_nop();
404                             bb.statements[*dead_idx].make_nop();
405                         }
406                     }
407                 }
408
409                 // Right shape; transform
410                 for stmt_idx in opt_info.stmts_to_remove {
411                     bb.statements[stmt_idx].make_nop();
412                 }
413
414                 let stmt = &mut bb.statements[opt_info.stmt_to_overwrite];
415                 stmt.source_info = opt_info.source_info;
416                 stmt.kind = StatementKind::Assign(box (
417                     opt_info.local_0.into(),
418                     Rvalue::Use(Operand::Move(opt_info.local_1.into())),
419                 ));
420
421                 bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop);
422
423                 // Fix the debug info to point to the right local
424                 for dbg_index in opt_info.dbg_info_to_adjust {
425                     let dbg_info = &mut debug_info[dbg_index];
426                     assert!(dbg_info.place.projection.is_empty());
427                     dbg_info.place.local = opt_info.local_0;
428                     dbg_info.place.projection = opt_info.dbg_projection;
429                 }
430
431                 trace!("block is now {:?}", bb.statements);
432             }
433         }
434     }
435 }
436
437 struct LocalUseCounter {
438     local_uses: IndexVec<Local, usize>,
439 }
440
441 impl LocalUseCounter {
442     fn get_local_uses<'tcx>(body: &Body<'tcx>) -> IndexVec<Local, usize> {
443         let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) };
444         counter.visit_body(body);
445         counter.local_uses
446     }
447 }
448
449 impl<'tcx> Visitor<'tcx> for LocalUseCounter {
450     fn visit_local(&mut self, local: &Local, context: PlaceContext, _location: Location) {
451         if context.is_storage_marker()
452             || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo)
453         {
454             return;
455         }
456
457         self.local_uses[*local] += 1;
458     }
459 }
460
461 /// Match on:
462 /// ```rust
463 /// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
464 /// ```
465 fn match_get_variant_field<'tcx>(
466     stmt: &Statement<'tcx>,
467 ) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> {
468     match &stmt.kind {
469         StatementKind::Assign(box (
470             place_into,
471             Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)),
472         )) => {
473             let local_into = place_into.as_local()?;
474             let (local_from, vf) = match_variant_field_place(*pf)?;
475             Some((local_into, local_from, vf, pf.projection))
476         }
477         _ => None,
478     }
479 }
480
481 /// Match on:
482 /// ```rust
483 /// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO;
484 /// ```
485 fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> {
486     match &stmt.kind {
487         StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => {
488             let local_into = place_into.as_local()?;
489             let (local_from, vf) = match_variant_field_place(*place_from)?;
490             Some((local_into, local_from, vf))
491         }
492         _ => None,
493     }
494 }
495
496 /// Match on:
497 /// ```rust
498 /// discriminant(_LOCAL_TO_SET) = VAR_IDX;
499 /// ```
500 fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)> {
501     match &stmt.kind {
502         StatementKind::SetDiscriminant { place, variant_index } => {
503             Some((place.as_local()?, *variant_index))
504         }
505         _ => None,
506     }
507 }
508
509 #[derive(PartialEq, Debug)]
510 struct VarField<'tcx> {
511     field: Field,
512     field_ty: Ty<'tcx>,
513     var_idx: VariantIdx,
514 }
515
516 /// Match on `((_LOCAL as Variant).FIELD: TY)`.
517 fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarField<'tcx>)> {
518     match place.as_ref() {
519         PlaceRef {
520             local,
521             projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)],
522         } => Some((local, VarField { field, field_ty: ty, var_idx })),
523         _ => None,
524     }
525 }
526
527 /// Simplifies `SwitchInt(_) -> [targets]`,
528 /// where all the `targets` have the same form,
529 /// into `goto -> target_first`.
530 pub struct SimplifyBranchSame;
531
532 impl<'tcx> MirPass<'tcx> for SimplifyBranchSame {
533     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
534         trace!("Running SimplifyBranchSame on {:?}", body.source);
535         let finder = SimplifyBranchSameOptimizationFinder { body, tcx };
536         let opts = finder.find();
537
538         let did_remove_blocks = opts.len() > 0;
539         for opt in opts.iter() {
540             trace!("SUCCESS: Applying optimization {:?}", opt);
541             // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
542             body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind =
543                 TerminatorKind::Goto { target: opt.bb_to_goto };
544         }
545
546         if did_remove_blocks {
547             // We have dead blocks now, so remove those.
548             simplify::remove_dead_blocks(body);
549         }
550     }
551 }
552
553 #[derive(Debug)]
554 struct SimplifyBranchSameOptimization {
555     /// All basic blocks are equal so go to this one
556     bb_to_goto: BasicBlock,
557     /// Basic block where the terminator can be simplified to a goto
558     bb_to_opt_terminator: BasicBlock,
559 }
560
561 struct SwitchTargetAndValue {
562     target: BasicBlock,
563     // None in case of the `otherwise` case
564     value: Option<u128>,
565 }
566
567 struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
568     body: &'a Body<'tcx>,
569     tcx: TyCtxt<'tcx>,
570 }
571
572 impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
573     fn find(&self) -> Vec<SimplifyBranchSameOptimization> {
574         self.body
575             .basic_blocks()
576             .iter_enumerated()
577             .filter_map(|(bb_idx, bb)| {
578                 let (discr_switched_on, targets_and_values) = match &bb.terminator().kind {
579                     TerminatorKind::SwitchInt { targets, discr, values, .. } => {
580                         // if values.len() == targets.len() - 1, we need to include None where no value is present
581                         // such that the zip does not throw away targets. If no `otherwise` case is in targets, the zip will simply throw away the added None
582                         let values_extended = values.iter().map(|x|Some(*x)).chain(once(None));
583                         let targets_and_values:Vec<_> = targets.iter().zip(values_extended)
584                             .map(|(target, value)| SwitchTargetAndValue{target:*target, value})
585                             .collect();
586                         assert_eq!(targets.len(), targets_and_values.len());
587                         (discr, targets_and_values)},
588                     _ => return None,
589                 };
590
591                 // find the adt that has its discriminant read
592                 // assuming this must be the last statement of the block
593                 let adt_matched_on = match &bb.statements.last()?.kind {
594                     StatementKind::Assign(box (place, rhs))
595                         if Some(*place) == discr_switched_on.place() =>
596                     {
597                         match rhs {
598                             Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place,
599                             _ => {
600                                 trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs);
601                                 return None;
602                             }
603                         }
604                     }
605                     other => {
606                         trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other);
607                         return None
608                     },
609                 };
610
611                 let mut iter_bbs_reachable = targets_and_values
612                     .iter()
613                     .map(|target_and_value| (target_and_value, &self.body.basic_blocks()[target_and_value.target]))
614                     .filter(|(_, bb)| {
615                         // Reaching `unreachable` is UB so assume it doesn't happen.
616                         bb.terminator().kind != TerminatorKind::Unreachable
617                     // But `asm!(...)` could abort the program,
618                     // so we cannot assume that the `unreachable` terminator itself is reachable.
619                     // FIXME(Centril): use a normalization pass instead of a check.
620                     || bb.statements.iter().any(|stmt| match stmt.kind {
621                         StatementKind::LlvmInlineAsm(..) => true,
622                         _ => false,
623                     })
624                     })
625                     .peekable();
626
627                 let bb_first = iter_bbs_reachable.peek().map(|(idx, _)| *idx).unwrap_or(&targets_and_values[0]);
628                 let mut all_successors_equivalent = StatementEquality::TrivialEqual;
629
630                 // All successor basic blocks must be equal or contain statements that are pairwise considered equal.
631                 for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() {
632                     let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup
633                                             && bb_l.terminator().kind == bb_r.terminator().kind
634                                             && bb_l.statements.len() == bb_r.statements.len();
635                     let statement_check = || {
636                         bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| {
637                             let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r);
638                             if matches!(stmt_equality, StatementEquality::NotEqual) {
639                                 // short circuit
640                                 None
641                             } else {
642                                 Some(acc.combine(&stmt_equality))
643                             }
644                         })
645                         .unwrap_or(StatementEquality::NotEqual)
646                     };
647                     if !trivial_checks {
648                         all_successors_equivalent = StatementEquality::NotEqual;
649                         break;
650                     }
651                     all_successors_equivalent = all_successors_equivalent.combine(&statement_check());
652                 };
653
654                 match all_successors_equivalent{
655                     StatementEquality::TrivialEqual => {
656                         // statements are trivially equal, so just take first
657                         trace!("Statements are trivially equal");
658                         Some(SimplifyBranchSameOptimization {
659                             bb_to_goto: bb_first.target,
660                             bb_to_opt_terminator: bb_idx,
661                         })
662                     }
663                     StatementEquality::ConsideredEqual(bb_to_choose) => {
664                         trace!("Statements are considered equal");
665                         Some(SimplifyBranchSameOptimization {
666                             bb_to_goto: bb_to_choose,
667                             bb_to_opt_terminator: bb_idx,
668                         })
669                     }
670                     StatementEquality::NotEqual => {
671                         trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx);
672                         None
673                     }
674                 }
675             })
676             .collect()
677     }
678
679     /// Tests if two statements can be considered equal
680     ///
681     /// Statements can be trivially equal if the kinds match.
682     /// But they can also be considered equal in the following case A:
683     /// ```
684     /// discriminant(_0) = 0;   // bb1
685     /// _0 = move _1;           // bb2
686     /// ```
687     /// In this case the two statements are equal iff
688     /// 1: _0 is an enum where the variant index 0 is fieldless, and
689     /// 2:  bb1 was targeted by a switch where the discriminant of _1 was switched on
690     fn statement_equality(
691         &self,
692         adt_matched_on: Place<'tcx>,
693         x: &Statement<'tcx>,
694         x_target_and_value: &SwitchTargetAndValue,
695         y: &Statement<'tcx>,
696         y_target_and_value: &SwitchTargetAndValue,
697     ) -> StatementEquality {
698         let helper = |rhs: &Rvalue<'tcx>,
699                       place: &Place<'tcx>,
700                       variant_index: &VariantIdx,
701                       side_to_choose| {
702             let place_type = place.ty(self.body, self.tcx).ty;
703             let adt = match *place_type.kind() {
704                 ty::Adt(adt, _) if adt.is_enum() => adt,
705                 _ => return StatementEquality::NotEqual,
706             };
707             let variant_is_fieldless = adt.variants[*variant_index].fields.is_empty();
708             if !variant_is_fieldless {
709                 trace!("NO: variant {:?} was not fieldless", variant_index);
710                 return StatementEquality::NotEqual;
711             }
712
713             match rhs {
714                 Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => {
715                     StatementEquality::ConsideredEqual(side_to_choose)
716                 }
717                 _ => {
718                     trace!(
719                         "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}",
720                         rhs,
721                         adt_matched_on
722                     );
723                     StatementEquality::NotEqual
724                 }
725             }
726         };
727         match (&x.kind, &y.kind) {
728             // trivial case
729             (x, y) if x == y => StatementEquality::TrivialEqual,
730
731             // check for case A
732             (
733                 StatementKind::Assign(box (_, rhs)),
734                 StatementKind::SetDiscriminant { place, variant_index },
735             )
736             // we need to make sure that the switch value that targets the bb with SetDiscriminant (y), is the same as the variant index
737             if Some(variant_index.index() as u128) == y_target_and_value.value => {
738                 // choose basic block of x, as that has the assign
739                 helper(rhs, place, variant_index, x_target_and_value.target)
740             }
741             (
742                 StatementKind::SetDiscriminant { place, variant_index },
743                 StatementKind::Assign(box (_, rhs)),
744             )
745             // we need to make sure that the switch value that targets the bb with SetDiscriminant (x), is the same as the variant index
746             if Some(variant_index.index() as u128) == x_target_and_value.value  => {
747                 // choose basic block of y, as that has the assign
748                 helper(rhs, place, variant_index, y_target_and_value.target)
749             }
750             _ => {
751                 trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y);
752                 StatementEquality::NotEqual
753             }
754         }
755     }
756 }
757
758 #[derive(Copy, Clone, Eq, PartialEq)]
759 enum StatementEquality {
760     /// The two statements are trivially equal; same kind
761     TrivialEqual,
762     /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization.
763     /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index
764     ConsideredEqual(BasicBlock),
765     /// The two statements are not equal
766     NotEqual,
767 }
768
769 impl StatementEquality {
770     fn combine(&self, other: &StatementEquality) -> StatementEquality {
771         use StatementEquality::*;
772         match (self, other) {
773             (TrivialEqual, TrivialEqual) => TrivialEqual,
774             (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => {
775                 ConsideredEqual(*b)
776             }
777             (ConsideredEqual(b1), ConsideredEqual(b2)) => {
778                 if b1 == b2 {
779                     ConsideredEqual(*b1)
780                 } else {
781                     NotEqual
782                 }
783             }
784             (_, NotEqual) | (NotEqual, _) => NotEqual,
785         }
786     }
787 }