]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/sroa.rs
Mark 'atomic_mut_ptr' methods const
[rust.git] / compiler / rustc_mir_transform / src / sroa.rs
1 use crate::MirPass;
2 use rustc_data_structures::fx::{FxIndexMap, IndexEntry};
3 use rustc_index::bit_set::BitSet;
4 use rustc_index::vec::IndexVec;
5 use rustc_middle::mir::visit::*;
6 use rustc_middle::mir::*;
7 use rustc_middle::ty::TyCtxt;
8
9 pub struct ScalarReplacementOfAggregates;
10
11 impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
12     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
13         sess.mir_opt_level() >= 3
14     }
15
16     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
17         let escaping = escaping_locals(&*body);
18         debug!(?escaping);
19         let replacements = compute_flattening(tcx, body, escaping);
20         debug!(?replacements);
21         replace_flattened_locals(tcx, body, replacements);
22     }
23 }
24
25 /// Identify all locals that are not eligible for SROA.
26 ///
27 /// There are 3 cases:
28 /// - the aggegated local is used or passed to other code (function parameters and arguments);
29 /// - the locals is a union or an enum;
30 /// - the local's address is taken, and thus the relative addresses of the fields are observable to
31 ///   client code.
32 fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
33     let mut set = BitSet::new_empty(body.local_decls.len());
34     set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
35     for (local, decl) in body.local_decls().iter_enumerated() {
36         if decl.ty.is_union() || decl.ty.is_enum() {
37             set.insert(local);
38         }
39     }
40     let mut visitor = EscapeVisitor { set };
41     visitor.visit_body(body);
42     return visitor.set;
43
44     struct EscapeVisitor {
45         set: BitSet<Local>,
46     }
47
48     impl<'tcx> Visitor<'tcx> for EscapeVisitor {
49         fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
50             self.set.insert(local);
51         }
52
53         fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
54             // Mirror the implementation in PreFlattenVisitor.
55             if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
56                 return;
57             }
58             self.super_place(place, context, location);
59         }
60
61         fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
62             if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
63                 if !place.is_indirect() {
64                     // Raw pointers may be used to access anything inside the enclosing place.
65                     self.set.insert(place.local);
66                     return;
67                 }
68             }
69             self.super_rvalue(rvalue, location)
70         }
71
72         fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
73             if let StatementKind::StorageLive(..)
74             | StatementKind::StorageDead(..)
75             | StatementKind::Deinit(..) = statement.kind
76             {
77                 // Storage statements are expanded in run_pass.
78                 return;
79             }
80             self.super_statement(statement, location)
81         }
82
83         fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
84             // Drop implicitly calls `drop_in_place`, which takes a `&mut`.
85             // This implies that `Drop` implicitly takes the address of the place.
86             if let TerminatorKind::Drop { place, .. }
87             | TerminatorKind::DropAndReplace { place, .. } = terminator.kind
88             {
89                 if !place.is_indirect() {
90                     // Raw pointers may be used to access anything inside the enclosing place.
91                     self.set.insert(place.local);
92                     return;
93                 }
94             }
95             self.super_terminator(terminator, location);
96         }
97
98         // We ignore anything that happens in debuginfo, since we expand it using
99         // `VarDebugInfoContents::Composite`.
100         fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
101     }
102 }
103
104 #[derive(Default, Debug)]
105 struct ReplacementMap<'tcx> {
106     fields: FxIndexMap<PlaceRef<'tcx>, Local>,
107 }
108
109 /// Compute the replacement of flattened places into locals.
110 ///
111 /// For each eligible place, we assign a new local to each accessed field.
112 /// The replacement will be done later in `ReplacementVisitor`.
113 fn compute_flattening<'tcx>(
114     tcx: TyCtxt<'tcx>,
115     body: &mut Body<'tcx>,
116     escaping: BitSet<Local>,
117 ) -> ReplacementMap<'tcx> {
118     let mut visitor = PreFlattenVisitor {
119         tcx,
120         escaping,
121         local_decls: &mut body.local_decls,
122         map: Default::default(),
123     };
124     for (block, bbdata) in body.basic_blocks.iter_enumerated() {
125         visitor.visit_basic_block_data(block, bbdata);
126     }
127     return visitor.map;
128
129     struct PreFlattenVisitor<'tcx, 'll> {
130         tcx: TyCtxt<'tcx>,
131         local_decls: &'ll mut LocalDecls<'tcx>,
132         escaping: BitSet<Local>,
133         map: ReplacementMap<'tcx>,
134     }
135
136     impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> {
137         fn create_place(&mut self, place: PlaceRef<'tcx>) {
138             if self.escaping.contains(place.local) {
139                 return;
140             }
141
142             match self.map.fields.entry(place) {
143                 IndexEntry::Occupied(_) => {}
144                 IndexEntry::Vacant(v) => {
145                     let ty = place.ty(&*self.local_decls, self.tcx).ty;
146                     let local = self.local_decls.push(LocalDecl {
147                         ty,
148                         user_ty: None,
149                         ..self.local_decls[place.local].clone()
150                     });
151                     v.insert(local);
152                 }
153             }
154         }
155     }
156
157     impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> {
158         fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) {
159             if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
160                 let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
161                 self.create_place(pr)
162             }
163         }
164     }
165 }
166
167 /// Perform the replacement computed by `compute_flattening`.
168 fn replace_flattened_locals<'tcx>(
169     tcx: TyCtxt<'tcx>,
170     body: &mut Body<'tcx>,
171     replacements: ReplacementMap<'tcx>,
172 ) {
173     let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
174     for p in replacements.fields.keys() {
175         all_dead_locals.insert(p.local);
176     }
177     debug!(?all_dead_locals);
178     if all_dead_locals.is_empty() {
179         return;
180     }
181
182     let mut fragments = IndexVec::new();
183     for (k, v) in &replacements.fields {
184         fragments.ensure_contains_elem(k.local, || Vec::new());
185         fragments[k.local].push((k.projection, *v));
186     }
187     debug!(?fragments);
188
189     let mut visitor = ReplacementVisitor {
190         tcx,
191         local_decls: &body.local_decls,
192         replacements,
193         all_dead_locals,
194         fragments,
195     };
196     for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
197         visitor.visit_basic_block_data(bb, data);
198     }
199     for scope in &mut body.source_scopes {
200         visitor.visit_source_scope_data(scope);
201     }
202     for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() {
203         visitor.visit_user_type_annotation(index, annotation);
204     }
205     for var_debug_info in &mut body.var_debug_info {
206         visitor.visit_var_debug_info(var_debug_info);
207     }
208 }
209
210 struct ReplacementVisitor<'tcx, 'll> {
211     tcx: TyCtxt<'tcx>,
212     /// This is only used to compute the type for `VarDebugInfoContents::Composite`.
213     local_decls: &'ll LocalDecls<'tcx>,
214     /// Work to do.
215     replacements: ReplacementMap<'tcx>,
216     /// This is used to check that we are not leaving references to replaced locals behind.
217     all_dead_locals: BitSet<Local>,
218     /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
219     /// and deinit statement and debuginfo.
220     fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
221 }
222
223 impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
224     fn gather_debug_info_fragments(
225         &self,
226         place: PlaceRef<'tcx>,
227     ) -> Vec<VarDebugInfoFragment<'tcx>> {
228         let mut fragments = Vec::new();
229         let parts = &self.fragments[place.local];
230         for (proj, replacement_local) in parts {
231             if proj.starts_with(place.projection) {
232                 fragments.push(VarDebugInfoFragment {
233                     projection: proj[place.projection.len()..].to_vec(),
234                     contents: Place::from(*replacement_local),
235                 });
236             }
237         }
238         fragments
239     }
240
241     fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
242         if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection {
243             let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
244             let local = self.replacements.fields.get(&pr)?;
245             Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) })
246         } else {
247             None
248         }
249     }
250 }
251
252 impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
253     fn tcx(&self) -> TyCtxt<'tcx> {
254         self.tcx
255     }
256
257     fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
258         if let StatementKind::StorageLive(..)
259         | StatementKind::StorageDead(..)
260         | StatementKind::Deinit(..) = statement.kind
261         {
262             // Storage statements are expanded in run_pass.
263             return;
264         }
265         self.super_statement(statement, location)
266     }
267
268     fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
269         if let Some(repl) = self.replace_place(place.as_ref()) {
270             *place = repl
271         } else {
272             self.super_place(place, context, location)
273         }
274     }
275
276     fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
277         match &mut var_debug_info.value {
278             VarDebugInfoContents::Place(ref mut place) => {
279                 if let Some(repl) = self.replace_place(place.as_ref()) {
280                     *place = repl;
281                 } else if self.all_dead_locals.contains(place.local) {
282                     let ty = place.ty(self.local_decls, self.tcx).ty;
283                     let fragments = self.gather_debug_info_fragments(place.as_ref());
284                     var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
285                 }
286             }
287             VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
288                 let mut new_fragments = Vec::new();
289                 fragments
290                     .drain_filter(|fragment| {
291                         if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
292                             fragment.contents = repl;
293                             true
294                         } else if self.all_dead_locals.contains(fragment.contents.local) {
295                             let frg = self.gather_debug_info_fragments(fragment.contents.as_ref());
296                             new_fragments.extend(frg.into_iter().map(|mut f| {
297                                 f.projection.splice(0..0, fragment.projection.iter().copied());
298                                 f
299                             }));
300                             false
301                         } else {
302                             true
303                         }
304                     })
305                     .for_each(drop);
306                 fragments.extend(new_fragments);
307             }
308             VarDebugInfoContents::Const(_) => {}
309         }
310     }
311
312     fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) {
313         self.super_basic_block_data(bb, bbdata);
314
315         #[derive(Debug)]
316         enum Stmt {
317             StorageLive,
318             StorageDead,
319             Deinit,
320         }
321
322         bbdata.expand_statements(|stmt| {
323             let source_info = stmt.source_info;
324             let (stmt, origin_local) = match &stmt.kind {
325                 StatementKind::StorageLive(l) => (Stmt::StorageLive, *l),
326                 StatementKind::StorageDead(l) => (Stmt::StorageDead, *l),
327                 StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l),
328                 _ => return None,
329             };
330             if !self.all_dead_locals.contains(origin_local) {
331                 return None;
332             }
333             let final_locals = self.fragments.get(origin_local)?;
334             Some(final_locals.iter().map(move |&(_, l)| {
335                 let kind = match stmt {
336                     Stmt::StorageLive => StatementKind::StorageLive(l),
337                     Stmt::StorageDead => StatementKind::StorageDead(l),
338                     Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())),
339                 };
340                 Statement { source_info, kind }
341             }))
342         });
343     }
344
345     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
346         assert!(!self.all_dead_locals.contains(*local));
347     }
348 }