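1 //! The general point of the optimizations provided here is to simplify `match` arms of the
//! form `Variant(x) => Variant(x)` (e.g. `Ok(x) => Ok(x)`, `Err(x) => Err(x)`) into a plain
//! move, and to turn a `SwitchInt` whose reachable targets are all identical into a `goto`.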
12 use crate::transform::{simplify, MirPass, MirSource};
13 use itertools::Itertools as _;
14 use rustc_index::{bit_set::BitSet, vec::IndexVec};
15 use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
16 use rustc_middle::mir::*;
17 use rustc_middle::ty::{List, Ty, TyCtxt};
18 use rustc_target::abi::VariantIdx;
19 use std::iter::{Enumerate, Peekable};
22 /// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
///
24 /// This is done by transforming basic blocks where the statements match:
///
/// ```text
27 /// _LOCAL_TMP = ((_LOCAL_1 as Variant).FIELD: TY);
28 /// _TMP_2 = _LOCAL_TMP;
29 /// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
30 /// discriminant(_LOCAL_0) = VAR_IDX;
/// ```
///
/// into:
///
/// ```text
36 /// _LOCAL_0 = move _LOCAL_1
/// ```
///
/// Interleaved `StorageLive`/`StorageDead` statements and a chain of several
/// `_TMP_{n+1} = _TMP_n` copies are also accepted; see `get_arm_identity_info` and
/// `optimization_applies` for the exact conditions.
38 pub struct SimplifyArmIdentity;
/// What `get_arm_identity_info` extracted from a basic block that has the arm-identity shape.
///
/// `optimization_applies` decides from this whether the rewrite is legal, and
/// `SimplifyArmIdentity::run_pass` uses it to perform the rewrite. Fields suffixed `_s0`
/// describe the read side (`_LOCAL_TMP = ((_LOCAL_1 as Variant).FIELD: TY)`); fields suffixed
/// `_s1` describe the write side (`((_LOCAL_0 as Variant).FIELD: TY) = move _TMP`).
///
/// NOTE(review): this type is constructed and read with fields `local_temp_0`, `local_1`,
/// `local_tmp_s1` and `local_0` that are not declared in the visible lines, and its closing
/// brace is not visible either. As written, the stacked doc comments below attach to
/// `vf_s0`/`vf_s1`. Confirm the declaration is complete.
41 struct ArmIdentityInfo<'tcx> {
42 /// Storage location for the variant's field
44 /// Storage location holding the variant being read from
46 /// The variant field being read from
47 vf_s0: VarField<'tcx>,
48 /// Index of the statement which loads the variant being read
49 get_variant_field_stmt: usize,
51 /// Tracks each assignment to a temporary of the variant's field,
/// as `(assigned_to, assigned_from)` pairs in statement order.
52 field_tmp_assignments: Vec<(Local, Local)>,
54 /// Storage location holding the variant's field that was read from
56 /// Storage location holding the enum that we are writing to
58 /// The variant field being written to
59 vf_s1: VarField<'tcx>,
61 /// Storage location that the discriminant is being written to
62 set_discr_local: Local,
63 /// The variant being written
64 set_discr_var_idx: VariantIdx,
66 /// Index of the statement that should be overwritten as a move
67 stmt_to_overwrite: usize,
68 /// `SourceInfo` for the new move, taken from the `SetDiscriminant` statement.
69 source_info: SourceInfo,
71 /// Indices of matching Storage{Live,Dead} statements encountered.
72 /// (StorageLive index, StorageDead index, Local)
73 storage_stmts: Vec<(usize, usize, Local)>,
75 /// The statements that should be removed (turned into nops)
76 stmts_to_remove: Vec<usize>,
78 /// Indices of debug variables that need to be adjusted to point to
79 /// `{local_0}.{dbg_projection}`.
80 dbg_info_to_adjust: Vec<usize>,
82 /// The projection used to rewrite debug info: the `Downcast` + `Field` projection of the
/// place the variant field was read from.
83 dbg_projection: &'tcx List<PlaceElem<'tcx>>,
/// Tries to match `stmts` (the statements of one basic block) against the shape documented on
/// `SimplifyArmIdentity`. Returns everything needed to rewrite the block, or `None` if the block
/// does not have that shape.
///
/// `debug_info` is scanned for debug variables that live in the temporaries of the copy chain,
/// so that they can later be redirected into the destination enum.
///
/// NOTE(review): the body uses `locals_count` and callers pass a locals count as the second
/// argument, but no such parameter is visible in this signature. Confirm it is declared.
86 fn get_arm_identity_info<'a, 'tcx>(
87 stmts: &'a [Statement<'tcx>],
89 debug_info: &'a [VarDebugInfo<'tcx>],
90 ) -> Option<ArmIdentityInfo<'tcx>> {
91 // This can't possibly match unless there are at least 3 statements in the block
92 // so fail fast on tiny blocks.
// NOTE(review): the length check announced above is not visible here. Confirm it exists.
// `(assigned_to, assigned_from)` for every `tmp_{n+1} = tmp_n` copy in the chain.
97 let mut tmp_assigns = Vec::new();
// Indices of statements that will be turned into nops.
98 let mut nop_stmts = Vec::new();
// Paired `(StorageLive index, StorageDead index, local)` triples.
99 let mut storage_stmts = Vec::new();
// `(statement index, local)` of every `StorageLive` / `StorageDead` eaten below.
100 let mut storage_live_stmts = Vec::new();
101 let mut storage_dead_stmts = Vec::new();
// Peekable so that the `try_eat*` helpers can look at the next statement without consuming it;
// enumerated so that every eaten statement keeps its index into `stmts`.
// NOTE(review): `Iter` is presumably `std::slice::Iter`; its `use` is not among the visible imports.
103 type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>;
/// Checks whether `stmt` is a `StorageLive` or `StorageDead` statement.
105 fn is_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool {
106 matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_))
109 /// Eats consecutive Statements which match `test`, performing the specified `action` for each.
110 /// The iterator `stmt_iter` is not advanced if none were matched.
///
/// `action` receives each eaten statement together with its index into the block.
111 fn try_eat<'a, 'tcx>(
112 stmt_iter: &mut StmtIter<'a, 'tcx>,
113 test: impl Fn(&'a Statement<'tcx>) -> bool,
114 mut action: impl FnMut(usize, &'a Statement<'tcx>),
// Peek first, so that the first non-matching statement stays in the iterator for whichever
// matcher runs next.
116 while stmt_iter.peek().map(|(_, stmt)| test(stmt)).unwrap_or(false) {
// Cannot panic: `peek` just returned `Some`.
117 let (idx, stmt) = stmt_iter.next().unwrap();
123 /// Eats consecutive `StorageLive` and `StorageDead` Statements.
124 /// The iterator `stmt_iter` is not advanced if none were found.
///
/// Each eaten statement is recorded as `(statement index, local)` in either
/// `storage_live_stmts` or `storage_dead_stmts`, depending on its kind.
125 fn try_eat_storage_stmts<'a, 'tcx>(
126 stmt_iter: &mut StmtIter<'a, 'tcx>,
127 storage_live_stmts: &mut Vec<(usize, Local)>,
128 storage_dead_stmts: &mut Vec<(usize, Local)>,
130 try_eat(stmt_iter, is_storage_stmt, |idx, stmt| {
131 if let StatementKind::StorageLive(l) = stmt.kind {
132 storage_live_stmts.push((idx, l));
133 } else if let StatementKind::StorageDead(l) = stmt.kind {
134 storage_dead_stmts.push((idx, l));
/// Checks whether `stmt` is a plain local-to-local copy or move (`_a = _b` / `_a = move _b`).
///
/// Despite the name, this does not look at storage statements. It recognises the
/// `tmp_n+1 = tmp_n` links of the copy chain.
139 fn is_tmp_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool {
140 use rustc_middle::mir::StatementKind::Assign;
141 if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind {
// Both sides must be bare locals; any projection disqualifies the statement.
142 place.as_local().is_some() && p.as_local().is_some()
148 /// Eats consecutive `Assign` Statements.
149 /// The iterator `stmt_iter` is not advanced if none were found.
///
/// Only local-to-local copies/moves (see `is_tmp_storage_stmt`) are eaten. Each one is recorded
/// as `(assigned_to, assigned_from)` in `tmp_assigns`.
///
/// NOTE(review): `nop_stmts` is presumably meant to collect the indices of the eaten
/// assignments, but no write to it is visible here. Confirm.
150 fn try_eat_assign_tmp_stmts<'a, 'tcx>(
151 stmt_iter: &mut StmtIter<'a, 'tcx>,
152 tmp_assigns: &mut Vec<(Local, Local)>,
153 nop_stmts: &mut Vec<usize>,
155 try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| {
156 use rustc_middle::mir::StatementKind::Assign;
// The `unwrap`s below cannot fail: `is_tmp_storage_stmt` only accepted statements whose two
// sides are both bare locals.
157 if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) =
160 tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap()));
/// Finds the last `StorageLive(local)` and the last `StorageDead(local)` in `stmts`.
///
/// Returns `None` if there is no `StorageLive(local)`. A missing `StorageDead` is reported as
/// `usize::MAX`, so that an exclusive `live < idx < dead` range check extends to the end of
/// the block.
///
/// NOTE(review): the `local` parameter used by the body is not visible in this signature.
166 fn find_storage_live_dead_stmts_for_local<'tcx>(
168 stmts: &[Statement<'tcx>],
169 ) -> Option<(usize, usize)> {
170 trace!("looking for {:?}", local);
171 let mut storage_live_stmt = None;
172 let mut storage_dead_stmt = None;
// Later matches overwrite earlier ones, so the last occurrence of each marker wins.
173 for (idx, stmt) in stmts.iter().enumerate() {
174 if stmt.kind == StatementKind::StorageLive(local) {
175 storage_live_stmt = Some(idx);
176 } else if stmt.kind == StatementKind::StorageDead(local) {
177 storage_dead_stmt = Some(idx);
181 Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX)))
184 // Try to match the expected MIR structure with the basic block we're processing.
185 // We want to see something that looks like:
187 // (StorageLive(_) | StorageDead(_));*
188 // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
189 // (StorageLive(_) | StorageDead(_));*
190 // (tmp_n+1 = tmp_n);*
191 // (StorageLive(_) | StorageDead(_));*
192 // (tmp_n+1 = tmp_n);*
193 // ((_LOCAL_TO as Variant).FIELD: TY) = move tmp;
194 // discriminant(_LOCAL_TO) = VariantIdx;
195 // (StorageLive(_) | StorageDead(_));*
//
// Any mismatch in one of the mandatory statements returns `None` through `?`.
197 let mut stmt_iter = stmts.iter().enumerate().peekable();
199 try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
201 let (get_variant_field_stmt, stmt) = stmt_iter.next()?;
202 let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?;
204 try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
// (tmp_n+1 = tmp_n);*, possibly split by storage markers.
206 try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
208 try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
210 try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
// ((_LOCAL_TO as Variant).FIELD: TY) = move tmp;
// NOTE(review): neither this `idx` nor the one after it is used in the visible lines.
212 let (idx, stmt) = stmt_iter.next()?;
213 let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?;
// discriminant(_LOCAL_TO) = VariantIdx;
216 let (idx, stmt) = stmt_iter.next()?;
217 let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?;
// The replacement move reuses the source location of the `SetDiscriminant`.
218 let discr_stmt_source_info = stmt.source_info;
221 try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
// Pair every `StorageLive` with the last not-yet-paired `StorageDead` of the same local.
// NOTE(review): `swap_remove` reorders the remaining entries, so a later `rposition` finds the
// last match in vector order, which is not necessarily statement order.
223 for (live_idx, live_local) in storage_live_stmts {
224 if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) {
225 let (dead_idx, _) = storage_dead_stmts.swap_remove(i);
226 storage_stmts.push((live_idx, dead_idx, live_local));
// The field was read into a temporary whose storage is both started and ended in this
// block, so the read itself is scheduled for removal.
228 if live_local == local_tmp_s0 {
229 nop_stmts.push(get_variant_field_stmt);
236 // Use one of the statements we're going to discard between the point
237 // where the storage location for the variant field becomes live and
// the point where it becomes dead (or the end of the block if it never does).
239 let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?;
240 let stmt_to_overwrite =
241 nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx);
// Every local on either side of a chain copy. Debug info pointing at any of them must be
// redirected once the chain is removed.
243 let mut tmp_assigned_vars = BitSet::new_empty(locals_count);
244 for (l, r) in &tmp_assigns {
245 tmp_assigned_vars.insert(*l);
246 tmp_assigned_vars.insert(*r);
249 let mut dbg_info_to_adjust = Vec::new();
250 for (i, var_info) in debug_info.iter().enumerate() {
251 if tmp_assigned_vars.contains(var_info.place.local) {
252 dbg_info_to_adjust.push(i);
256 Some(ArmIdentityInfo {
257 local_temp_0: local_tmp_s0,
260 get_variant_field_stmt,
261 field_tmp_assignments: tmp_assigns,
// Give up (`None`) if no discarded statement lies inside the temporary's live range:
// there would be no slot to put the replacement move in.
267 stmt_to_overwrite: *stmt_to_overwrite?,
268 source_info: discr_stmt_source_info,
270 stmts_to_remove: nop_stmts,
/// Decides whether a block matched by `get_arm_identity_info` may actually be rewritten into
/// `_LOCAL_0 = move _LOCAL_1`.
///
/// It rejects the block unless all of the following hold:
///
/// - source and destination enums are different locals of the same type;
/// - the field/variant read equals the one written;
/// - the discriminant is set on the destination to that same variant;
/// - the temporaries form one non-empty copy chain where each link reads the previous one;
/// - every chain local is used exactly twice according to `local_uses`;
/// - no affected debug info already carries a projection;
/// - the chain starts at the read temporary and ends at the written temporary.
///
/// NOTE(review): each `NO:` trace marks a rejection. The corresponding early returns are not
/// visible in this excerpt.
276 fn optimization_applies<'tcx>(
277 opt_info: &ArmIdentityInfo<'tcx>,
278 local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
279 local_uses: &IndexVec<Local, usize>,
280 var_debug_info: &[VarDebugInfo<'tcx>],
282 trace!("testing if optimization applies...");
284 // FIXME(wesleywiser): possibly relax this restriction?
285 if opt_info.local_0 == opt_info.local_1 {
286 trace!("NO: moving into ourselves");
288 } else if opt_info.vf_s0 != opt_info.vf_s1 {
289 trace!("NO: the field-and-variant information do not match");
291 } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty {
292 // FIXME(Centril,oli-obk): possibly relax to same layout?
293 trace!("NO: source and target locals have different types");
295 } else if (opt_info.local_0, opt_info.vf_s0.var_idx)
296 != (opt_info.set_discr_local, opt_info.set_discr_var_idx)
298 trace!("NO: the discriminants do not match");
302 // Verify the assignment chain consists of the form b = a; c = b; d = c; etc...
303 if opt_info.field_tmp_assignments.is_empty() {
304 trace!("NO: no assignments found");
// Entries are `(assigned_to, assigned_from)`; the chain starts at the first entry's source.
307 let mut last_assigned_to = opt_info.field_tmp_assignments[0].1;
308 let source_local = last_assigned_to;
309 for (l, r) in &opt_info.field_tmp_assignments {
310 if *r != last_assigned_to {
311 trace!("NO: found unexpected assignment {:?} = {:?}", l, r);
315 last_assigned_to = *l;
318 // Check that every local in the assignment chain is used exactly twice
319 // since they are of the form:
//
322 // _first = ((_x as Variant).n: ty);
// _second = _first;
// ...
325 // ((_y as Variant).n: ty) = _n;
326 // discriminant(_y) = z;
//
// NOTE(review): the condition is `!= 2`, so the "more than twice" wording of the messages
// below also covers locals used fewer than two times.
328 for (l, r) in &opt_info.field_tmp_assignments {
329 if local_uses[*l] != 2 {
330 warn!("NO: FAILED assignment chain local {:?} was used more than twice", l);
332 } else if local_uses[*r] != 2 {
333 warn!("NO: FAILED assignment chain local {:?} was used more than twice", r);
338 // Check that debug info only points to full Locals and not projections.
339 for dbg_idx in &opt_info.dbg_info_to_adjust {
340 let dbg_info = &var_debug_info[*dbg_idx];
341 if !dbg_info.place.projection.is_empty() {
342 trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, dbg_info.place);
347 if source_local != opt_info.local_temp_0 {
349 "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}",
351 opt_info.local_temp_0
// NOTE(review): the message below misspells "assignment" as "assignemnt".
354 } else if last_assigned_to != opt_info.local_tmp_s1 {
356 "NO: end of assignemnt chain does not match written enum temp: {:?} != {:?}",
358 opt_info.local_tmp_s1
363 trace!("SUCCESS: optimization applies!");
/// Rewrites every basic block of `body` that has the arm-identity shape and passes
/// `optimization_applies` (see `SimplifyArmIdentity`).
367 impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
368 fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
// NOTE(review): the body of this `mir_opt_level < 2` branch is not visible here (presumably
// an early return).
369 if tcx.sess.opts.debugging_opts.mir_opt_level < 2 {
373 trace!("running SimplifyArmIdentity on {:?}", source);
// Use counts are computed once for the whole body, before any block is rewritten.
374 let local_uses = LocalUseCounter::get_local_uses(body);
375 let (basic_blocks, local_decls, debug_info) =
376 body.basic_blocks_local_decls_mut_and_var_debug_info();
377 for bb in basic_blocks {
378 if let Some(opt_info) =
379 get_arm_identity_info(&bb.statements, local_decls.len(), debug_info)
381 trace!("got opt_info = {:#?}", opt_info);
382 if !optimization_applies(&opt_info, local_decls, &local_uses, &debug_info) {
383 debug!("optimization skipped for {:?}", source);
387 // Also remove unused Storage{Live,Dead} statements which correspond
388 // to temps used previously.
389 for (live_idx, dead_idx, local) in &opt_info.storage_stmts {
390 // The temporary that we've read the variant field into is scoped to this block,
391 // so we can remove the assignment.
392 if *local == opt_info.local_temp_0 {
393 bb.statements[opt_info.get_variant_field_stmt].make_nop();
396 for (left, right) in &opt_info.field_tmp_assignments {
397 if local == left || local == right {
398 bb.statements[*live_idx].make_nop();
399 bb.statements[*dead_idx].make_nop();
404 // Right shape; transform
405 for stmt_idx in opt_info.stmts_to_remove {
406 bb.statements[stmt_idx].make_nop();
// Reuse one of the discarded statement slots for `_LOCAL_0 = move _LOCAL_1`.
409 let stmt = &mut bb.statements[opt_info.stmt_to_overwrite];
410 stmt.source_info = opt_info.source_info;
411 stmt.kind = StatementKind::Assign(box (
412 opt_info.local_0.into(),
413 Rvalue::Use(Operand::Move(opt_info.local_1.into())),
// Drop every statement that was turned into a nop above.
416 bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop);
418 // Fix the debug info to point to the right local
419 for dbg_index in opt_info.dbg_info_to_adjust {
420 let dbg_info = &mut debug_info[dbg_index];
// `optimization_applies` rejects debug info with a projection, so this is expected to hold.
421 assert!(dbg_info.place.projection.is_empty());
422 dbg_info.place.local = opt_info.local_0;
423 dbg_info.place.projection = opt_info.dbg_projection;
426 trace!("block is now {:?}", bb.statements);
/// MIR visitor that counts, per local, how many times each local is visited in a body.
/// Storage markers and debug-info references are intended to be excluded; this should be
/// confirmed against the `Visitor` impl.
432 struct LocalUseCounter {
// One counter per declared local, indexed by `Local`.
433 local_uses: IndexVec<Local, usize>,
436 impl LocalUseCounter {
/// Returns the use count of every local in `body`, indexed by `Local`.
/// NOTE(review): the final `counter.local_uses` return is not visible in this excerpt.
437 fn get_local_uses<'tcx>(body: &Body<'tcx>) -> IndexVec<Local, usize> {
// Start every declared local at zero, then let the visitor bump the counts.
438 let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) };
439 counter.visit_body(body);
444 impl<'tcx> Visitor<'tcx> for LocalUseCounter {
/// Bumps the counter for `local`. Storage-marker and `VarDebugInfo` contexts are singled out
/// by the condition below.
445 fn visit_local(&mut self, local: &Local, context: PlaceContext, _location: Location) {
// NOTE(review): the branch taken for these contexts is not visible here. It is presumably an
// early return, so that storage markers and debug info are not counted as uses.
446 if context.is_storage_marker()
447 || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo)
452 self.local_uses[*local] += 1;
/// Match on:
///
/// ```text
458 /// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
/// ```
///
/// The right-hand side may be a copy or a move. Returns
/// `(_LOCAL_INTO, _LOCAL_FROM, variant/field info, projection of the right-hand side)`; the
/// projection is later reused to rewrite debug info.
460 fn match_get_variant_field<'tcx>(
461 stmt: &Statement<'tcx>,
462 ) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> {
464 StatementKind::Assign(box (place_into, rvalue_from)) => match rvalue_from {
465 Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)) => {
466 let local_into = place_into.as_local()?;
467 let (local_from, vf) = match_variant_field_place(*pf)?;
468 Some((local_into, local_from, vf, pf.projection))
/// Match on:
///
/// ```text
478 /// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO;
/// ```
///
/// Returns `(_LOCAL_INTO, _LOCAL_FROM, variant/field info)`. Only `Operand::Move` is matched
/// here, unlike `match_get_variant_field`, which also accepts copies. Note that the naming is
/// inverted relative to the data flow: `_LOCAL_FROM` is the enum being written to, and
/// `_LOCAL_INTO` is the temporary being moved out of.
480 fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> {
482 StatementKind::Assign(box (place_from, rvalue_into)) => match rvalue_into {
483 Rvalue::Use(Operand::Move(place_into)) => {
484 let local_into = place_into.as_local()?;
485 let (local_from, vf) = match_variant_field_place(*place_from)?;
486 Some((local_into, local_from, vf))
/// Match on:
///
/// ```text
496 /// discriminant(_LOCAL_TO_SET) = VAR_IDX;
/// ```
///
/// Returns `(_LOCAL_TO_SET, VAR_IDX)`. The place must be a bare local.
498 fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)> {
500 StatementKind::SetDiscriminant { place, variant_index } => {
501 Some((place.as_local()?, *variant_index))
/// Identifies one field of one enum variant, together with the field's type.
/// Two `VarField`s compare equal only if variant, field and type all match.
/// NOTE(review): the field declarations (`field`, `field_ty`, `var_idx`, as used where this is
/// constructed) are not visible in this excerpt.
507 #[derive(PartialEq, Debug)]
508 struct VarField<'tcx> {
514 /// Match on `((_LOCAL as Variant).FIELD: TY)`.
///
/// The place's projection must be exactly a `Downcast` followed by a `Field`, with nothing
/// before or after. Returns `_LOCAL` together with the downcast's variant index, the field and
/// its type, packed into a `VarField`.
515 fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarField<'tcx>)> {
516 match place.as_ref() {
519 projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)],
520 } => Some((local, VarField { field, field_ty: ty, var_idx })),
525 /// Simplifies `SwitchInt(_) -> [targets]`,
526 /// where all the `targets` have the same form,
527 /// into `goto -> target_first`.
///
/// Targets that end in `unreachable` and contain no inline assembly are ignored when comparing.
/// Blocks that become dead after the rewrite are removed at the end of the pass.
528 pub struct SimplifyBranchSame;
530 impl<'tcx> MirPass<'tcx> for SimplifyBranchSame {
531 fn run_pass(&self, _: TyCtxt<'tcx>, _: MirSource<'tcx>, body: &mut Body<'tcx>) {
532 let mut did_remove_blocks = false;
533 let bbs = body.basic_blocks_mut();
534 for bb_idx in bbs.indices() {
// Only blocks terminated by a `SwitchInt` are candidates.
535 let targets = match &bbs[bb_idx].terminator().kind {
536 TerminatorKind::SwitchInt { targets, .. } => targets,
// Pair each target index with its block, then keep only the targets that may be reached.
540 let mut iter_bbs_reachable = targets
542 .map(|idx| (*idx, &bbs[*idx]))
544 // Reaching `unreachable` is UB so assume it doesn't happen.
545 bb.terminator().kind != TerminatorKind::Unreachable
546 // But `asm!(...)` could abort the program,
547 // so a block containing one cannot be assumed unreachable just because it ends in `unreachable`.
548 // FIXME(Centril): use a normalization pass instead of a check.
549 || bb.statements.iter().any(|stmt| match stmt.kind {
550 StatementKind::LlvmInlineAsm(..) => true,
556 // We want to `goto -> bb_first`.
// If no target is considered reachable, fall back to the first listed target.
557 let bb_first = iter_bbs_reachable.peek().map(|(idx, _)| *idx).unwrap_or(targets[0]);
559 // All successor basic blocks should have the exact same form.
// Blocks are compared by cleanup flag, terminator kind and statement kinds; source info is
// ignored. With fewer than two reachable targets, `tuple_windows` yields no pairs, so `all`
// is `true`.
560 let all_successors_equivalent =
561 iter_bbs_reachable.map(|(_, bb)| bb).tuple_windows().all(|(bb_l, bb_r)| {
562 bb_l.is_cleanup == bb_r.is_cleanup
563 && bb_l.terminator().kind == bb_r.terminator().kind
564 && bb_l.statements.iter().eq_by(&bb_r.statements, |x, y| x.kind == y.kind)
567 if all_successors_equivalent {
568 // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
569 bbs[bb_idx].terminator_mut().kind = TerminatorKind::Goto { target: bb_first };
570 did_remove_blocks = true;
574 if did_remove_blocks {
575 // We have dead blocks now, so remove those.
576 simplify::remove_dead_blocks(body);