2 use rustc_data_structures::fx::{FxIndexMap, IndexEntry};
3 use rustc_index::bit_set::BitSet;
4 use rustc_index::vec::IndexVec;
5 use rustc_middle::mir::visit::*;
6 use rustc_middle::mir::*;
7 use rustc_middle::ty::TyCtxt;
/// MIR optimization pass performing scalar replacement of aggregates (SROA): field accesses
/// on eligible aggregate locals are rewritten to use one fresh local per accessed field.
9 pub struct ScalarReplacementOfAggregates;
11 impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
/// The pass only runs when the session's MIR optimization level is at least 3.
12 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
13 sess.mir_opt_level() >= 3
/// Runs SROA in three steps:
/// 1. find the locals that must stay whole;
/// 2. allocate replacement locals for the accessed fields of all other locals;
/// 3. rewrite the body to use those replacement locals.
16 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
17 let escaping = escaping_locals(&*body);
19 let replacements = compute_flattening(tcx, body, escaping);
20 debug!(?replacements);
21 replace_flattened_locals(tcx, body, replacements);
25 /// Identify all locals that are not eligible for SROA.
27 /// There are 3 cases:
28 /// - the aggregated local is used or passed to other code (function parameters and arguments);
29 /// - the local is a union or an enum;
30 /// - the local's address is taken, and thus the relative addresses of the fields are observable to
32 fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
33 let mut set = BitSet::new_empty(body.local_decls.len());
// The return place and all arguments (locals `0..=arg_count`) are always kept whole.
34 set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
// Unions and enums are never split.
35 for (local, decl) in body.local_decls().iter_enumerated() {
36 if decl.ty.is_union() || decl.ty.is_enum() {
// Then let the visitor mark every local whose uses in the body force it to stay whole.
40 let mut visitor = EscapeVisitor { set };
41 visitor.visit_body(body);
/// Visitor collecting, in `set`, the locals that must not be split.
44 struct EscapeVisitor {
48 impl<'tcx> Visitor<'tcx> for EscapeVisitor {
// Any local reached through this callback is recorded as escaping.
49 fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
50 self.set.insert(local);
// Places whose first projection is a field are the ones SROA is able to rewrite.
53 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
54 // Mirror the implementation in PreFlattenVisitor.
55 if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
58 self.super_place(place, context, location);
// Taking a reference or raw pointer to a place directly inside a local makes the relative
// addresses of its fields observable.
61 fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
62 if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
63 if !place.is_indirect() {
64 // Raw pointers and references may be used to access anything inside the enclosing place.
65 self.set.insert(place.local);
69 self.super_rvalue(rvalue, location)
72 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
73 if let StatementKind::StorageLive(..)
74 | StatementKind::StorageDead(..)
75 | StatementKind::Deinit(..) = statement.kind
77 // Storage and deinit statements are expanded per field in `ReplacementVisitor`.
80 self.super_statement(statement, location)
83 fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
84 // Drop implicitly calls `drop_in_place`, which takes a `&mut`.
85 // This implies that `Drop` implicitly takes the address of the place.
86 if let TerminatorKind::Drop { place, .. }
87 | TerminatorKind::DropAndReplace { place, .. } = terminator.kind
89 if !place.is_indirect() {
90 // That `&mut` may be used to access anything inside the enclosing place.
91 self.set.insert(place.local);
95 self.super_terminator(terminator, location);
98 // We ignore anything that happens in debuginfo, since we expand it using
99 // `VarDebugInfoContents::Composite`.
100 fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
/// Replacement locals allocated by `compute_flattening`.
104 #[derive(Default, Debug)]
105 struct ReplacementMap<'tcx> {
/// Maps each flattened place (a local plus one leading field projection) to its new local.
106 fields: FxIndexMap<PlaceRef<'tcx>, Local>,
109 /// Compute the replacement of flattened places into locals.
111 /// For each eligible local, we assign a new local to each accessed field.
112 /// The replacement will be done later in `ReplacementVisitor`.
113 fn compute_flattening<'tcx>(
115 body: &mut Body<'tcx>,
116 escaping: BitSet<Local>,
117 ) -> ReplacementMap<'tcx> {
// `local_decls` and `basic_blocks` are borrowed as disjoint fields of `body`, so new
// locals can be pushed while the blocks are being walked.
118 let mut visitor = PreFlattenVisitor {
121 local_decls: &mut body.local_decls,
122 map: Default::default(),
// Walk every basic block, allocating replacement locals for the fields accessed there.
124 for (block, bbdata) in body.basic_blocks.iter_enumerated() {
125 visitor.visit_basic_block_data(block, bbdata);
/// Visitor that allocates a replacement local for each field of a non-escaping local
/// accessed in the body.
129 struct PreFlattenVisitor<'tcx, 'll> {
/// Declarations of the body; replacement locals are pushed here.
131 local_decls: &'ll mut LocalDecls<'tcx>,
/// Locals that must not be split, as computed by `escaping_locals`.
132 escaping: BitSet<Local>,
/// Replacements allocated so far.
133 map: ReplacementMap<'tcx>,
136 impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> {
/// Ensures `place` (a local plus one field projection) has a replacement local, unless its
/// base local escapes. A place already in the map is left untouched.
137 fn create_place(&mut self, place: PlaceRef<'tcx>) {
138 if self.escaping.contains(place.local) {
142 match self.map.fields.entry(place) {
143 IndexEntry::Occupied(_) => {}
144 IndexEntry::Vacant(v) => {
// The field's type is computed for the new local; the remaining `LocalDecl` fields
// are copied from the base local's declaration.
145 let ty = place.ty(&*self.local_decls, self.tcx).ty;
146 let local = self.local_decls.push(LocalDecl {
149 ..self.local_decls[place.local].clone()
157 impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> {
// Only the leading field projection is flattened. For `x.a.b`, a local is allocated for
// `x.a`, and `ReplacementVisitor::replace_place` later applies `.b` to that local.
158 fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) {
159 if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
160 let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
161 self.create_place(pr)
167 /// Perform the replacement computed by `compute_flattening`.
168 fn replace_flattened_locals<'tcx>(
170 body: &mut Body<'tcx>,
171 replacements: ReplacementMap<'tcx>,
// Every local with at least one replaced field; none of them may remain in the body.
173 let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
174 for p in replacements.fields.keys() {
175 all_dead_locals.insert(p.local);
177 debug!(?all_dead_locals);
// No local was split.
178 if all_dead_locals.is_empty() {
// For each split local, list its replacement locals together with the field projection
// each one stands for (used to expand storage/deinit statements and debuginfo).
182 let mut fragments = IndexVec::new();
183 for (k, v) in &replacements.fields {
184 fragments.ensure_contains_elem(k.local, || Vec::new());
185 fragments[k.local].push((&k.projection[..], *v));
189 let mut visitor = ReplacementVisitor {
191 local_decls: &body.local_decls,
// The visitor borrows `body.local_decls`, so the other parts of the body are visited
// field by field instead of through a single `visit_body` call. Statements get expanded,
// but no blocks or edges change, hence `as_mut_preserves_cfg`.
196 for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
197 visitor.visit_basic_block_data(bb, data);
199 for scope in &mut body.source_scopes {
200 visitor.visit_source_scope_data(scope);
202 for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() {
203 visitor.visit_user_type_annotation(index, annotation);
205 for var_debug_info in &mut body.var_debug_info {
206 visitor.visit_var_debug_info(var_debug_info);
/// Rewrites the body according to a `ReplacementMap`.
210 struct ReplacementVisitor<'tcx, 'll> {
212 /// This is only used to compute the type for `VarDebugInfoContents::Composite`.
213 local_decls: &'ll LocalDecls<'tcx>,
/// Replacement local for each flattened place.
215 replacements: ReplacementMap<'tcx>,
/// Locals that have been split into per-field locals.
216 /// This is used to check that we are not leaving references to replaced locals behind.
/// It also selects which storage/deinit statements and debuginfo get expanded.
217 all_dead_locals: BitSet<Local>,
218 /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
219 /// and deinit statements and debuginfo.
220 fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
223 impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
/// Builds one debuginfo fragment for each replacement local of `place.local` whose field
/// projection starts with `place.projection`. Each fragment's projection is whatever
/// remains of the field projection after `place.projection`.
224 fn gather_debug_info_fragments(
226 place: PlaceRef<'tcx>,
227 ) -> Vec<VarDebugInfoFragment<'tcx>> {
228 let mut fragments = Vec::new();
// Indexing relies on `place.local` having been split; callers check `all_dead_locals` first.
229 let parts = &self.fragments[place.local];
230 for (proj, replacement_local) in parts {
231 if proj.starts_with(place.projection) {
232 fragments.push(VarDebugInfoFragment {
233 projection: proj[place.projection.len()..].to_vec(),
234 contents: Place::from(*replacement_local),
/// If `place` begins with a flattened field projection, returns the same place rooted at
/// that field's replacement local, keeping the rest of the projection.
241 fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
242 if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection {
243 let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
244 let local = self.replacements.fields.get(&pr)?;
245 Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) })
252 impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
253 fn tcx(&self) -> TyCtxt<'tcx> {
// NOTE(review): `visit_basic_block_data` only expands a `Deinit` whose place is a bare
// local (`p.as_local()`). A `Deinit` of a field place such as `_1.0` does not appear to be
// rewritten anywhere; confirm it cannot be left referring to a split local.
257 fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
258 if let StatementKind::StorageLive(..)
259 | StatementKind::StorageDead(..)
260 | StatementKind::Deinit(..) = statement.kind
262 // Storage and deinit statements are expanded in `visit_basic_block_data`.
265 self.super_statement(statement, location)
// Redirect places rooted at a flattened field to the corresponding replacement local.
268 fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
269 if let Some(repl) = self.replace_place(place.as_ref()) {
272 self.super_place(place, context, location)
// Debuginfo that refers to a split local is rewritten into a `Composite` built from that
// local's fragments instead of being left pointing at the dead local.
276 fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
277 match &mut var_debug_info.value {
278 VarDebugInfoContents::Place(ref mut place) => {
279 if let Some(repl) = self.replace_place(place.as_ref()) {
281 } else if self.all_dead_locals.contains(place.local) {
282 let ty = place.ty(self.local_decls, self.tcx).ty;
283 let fragments = self.gather_debug_info_fragments(place.as_ref());
284 var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
287 VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
// A fragment pointing into a split local is expanded into that local's own fragments,
// each with the outer fragment's projection prepended. The expansions are collected in
// `new_fragments` and appended at the end.
288 let mut new_fragments = Vec::new();
290 .drain_filter(|fragment| {
291 if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
292 fragment.contents = repl;
294 } else if self.all_dead_locals.contains(fragment.contents.local) {
295 let frg = self.gather_debug_info_fragments(fragment.contents.as_ref());
296 new_fragments.extend(frg.into_iter().map(|mut f| {
297 f.projection.splice(0..0, fragment.projection.iter().copied());
306 fragments.extend(new_fragments);
308 VarDebugInfoContents::Const(_) => {}
// After places are rewritten, each `StorageLive`, `StorageDead` and `Deinit` of a split
// local is replaced by one such statement per replacement local.
312 fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) {
313 self.super_basic_block_data(bb, bbdata);
322 bbdata.expand_statements(|stmt| {
323 let source_info = stmt.source_info;
324 let (stmt, origin_local) = match &stmt.kind {
325 StatementKind::StorageLive(l) => (Stmt::StorageLive, *l),
326 StatementKind::StorageDead(l) => (Stmt::StorageDead, *l),
327 StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l),
330 if !self.all_dead_locals.contains(origin_local) {
333 let final_locals = self.fragments.get(origin_local)?;
334 Some(final_locals.iter().map(move |&(_, l)| {
335 let kind = match stmt {
336 Stmt::StorageLive => StatementKind::StorageLive(l),
337 Stmt::StorageDead => StatementKind::StorageDead(l),
338 Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())),
340 Statement { source_info, kind }
// Sanity check: every use of a split local must already have been rewritten.
345 fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
346 assert!(!self.all_dead_locals.contains(*local));