]> git.lizzy.rs Git - rust.git/commitdiff
Simplify ReplacementMap.
authorCamille GILLOT <gillot.camille@gmail.com>
Sun, 5 Feb 2023 13:35:33 +0000 (13:35 +0000)
committerCamille GILLOT <gillot.camille@gmail.com>
Sun, 5 Feb 2023 13:41:24 +0000 (13:41 +0000)
compiler/rustc_mir_transform/src/sroa.rs

index 28963c77aa55f0a8460bde9093a2fbe067de79c4..26acd406ed8a9f81616240e55d6b533d680ed79c 100644 (file)
@@ -1,11 +1,10 @@
 use crate::MirPass;
-use rustc_data_structures::fx::FxIndexMap;
 use rustc_index::bit_set::BitSet;
 use rustc_index::vec::IndexVec;
 use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
+use rustc_middle::ty::{Ty, TyCtxt};
 use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
 
 pub struct ScalarReplacementOfAggregates;
@@ -26,13 +25,13 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
             let replacements = compute_flattening(tcx, body, escaping);
             debug!(?replacements);
             let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
-            if !all_dead_locals.is_empty() && tcx.sess.mir_opt_level() >= 4 {
+            if !all_dead_locals.is_empty() {
                 for local in excluded.indices() {
-                    excluded[local] |= all_dead_locals.contains(local) ;
+                    excluded[local] |= all_dead_locals.contains(local);
                 }
                 excluded.raw.resize(body.local_decls.len(), false);
             } else {
-                break
+                break;
             }
         }
     }
@@ -111,36 +110,29 @@ fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
 
 #[derive(Default, Debug)]
 struct ReplacementMap<'tcx> {
-    fields: FxIndexMap<PlaceRef<'tcx>, Local>,
     /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
     /// and deinit statement and debuginfo.
-    fragments: IndexVec<Local, Option<Vec<(&'tcx [PlaceElem<'tcx>], Local)>>>,
+    fragments: IndexVec<Local, Option<IndexVec<Field, Option<(Ty<'tcx>, Local)>>>>,
 }
 
 impl<'tcx> ReplacementMap<'tcx> {
-    fn gather_debug_info_fragments(
-        &self,
-        place: PlaceRef<'tcx>,
-    ) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
-        let mut fragments = Vec::new();
-        let Some(parts) = &self.fragments[place.local] else { return None };
-        for (proj, replacement_local) in parts {
-            if proj.starts_with(place.projection) {
-                fragments.push(VarDebugInfoFragment {
-                    projection: proj[place.projection.len()..].to_vec(),
-                    contents: Place::from(*replacement_local),
-                });
-            }
-        }
-        Some(fragments)
+    fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
+        let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; };
+        let fields = self.fragments[place.local].as_ref()?;
+        let (_, new_local) = fields[f]?;
+        Some(Place { local: new_local, projection: tcx.intern_place_elems(&rest) })
     }
 
     fn place_fragments(
         &self,
         place: Place<'tcx>,
-    ) -> Option<&Vec<(&'tcx [PlaceElem<'tcx>], Local)>> {
+    ) -> Option<impl Iterator<Item = (Field, Ty<'tcx>, Local)> + '_> {
         let local = place.as_local()?;
-        self.fragments[local].as_ref()
+        let fields = self.fragments[local].as_ref()?;
+        Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
+            let (ty, local) = opt_ty_local?;
+            Some((field, ty, local))
+        }))
     }
 }
 
@@ -153,8 +145,7 @@ fn compute_flattening<'tcx>(
     body: &mut Body<'tcx>,
     escaping: BitSet<Local>,
 ) -> ReplacementMap<'tcx> {
-    let mut fields = FxIndexMap::default();
-    let mut fragments = IndexVec::from_elem(None::<Vec<_>>, &body.local_decls);
+    let mut fragments = IndexVec::from_elem(None, &body.local_decls);
 
     for local in body.local_decls.indices() {
         if escaping.contains(local) {
@@ -169,14 +160,10 @@ fn compute_flattening<'tcx>(
             };
             let new_local =
                 body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
-            let place = Place::from(local)
-                .project_deeper(&[PlaceElem::Field(field, field_ty)], tcx)
-                .as_ref();
-            fields.insert(place, new_local);
-            fragments[local].get_or_insert_default().push((place.projection, new_local));
+            fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
         });
     }
-    ReplacementMap { fields, fragments }
+    ReplacementMap { fragments }
 }
 
 /// Perform the replacement computed by `compute_flattening`.
@@ -186,8 +173,10 @@ fn replace_flattened_locals<'tcx>(
     replacements: ReplacementMap<'tcx>,
 ) -> BitSet<Local> {
     let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
-    for p in replacements.fields.keys() {
-        all_dead_locals.insert(p.local);
+    for (local, replacements) in replacements.fragments.iter_enumerated() {
+        if replacements.is_some() {
+            all_dead_locals.insert(local);
+        }
     }
     debug!(?all_dead_locals);
     if all_dead_locals.is_empty() {
@@ -197,7 +186,7 @@ fn replace_flattened_locals<'tcx>(
     let mut visitor = ReplacementVisitor {
         tcx,
         local_decls: &body.local_decls,
-        replacements,
+        replacements: &replacements,
         all_dead_locals,
         patch: MirPatch::new(body),
     };
@@ -223,21 +212,23 @@ struct ReplacementVisitor<'tcx, 'll> {
     /// This is only used to compute the type for `VarDebugInfoContents::Composite`.
     local_decls: &'ll LocalDecls<'tcx>,
     /// Work to do.
-    replacements: ReplacementMap<'tcx>,
+    replacements: &'ll ReplacementMap<'tcx>,
     /// This is used to check that we are not leaving references to replaced locals behind.
     all_dead_locals: BitSet<Local>,
     patch: MirPatch<'tcx>,
 }
 
-impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
-    fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
-        if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection {
-            let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
-            let local = self.replacements.fields.get(&pr)?;
-            Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) })
-        } else {
-            None
+impl<'tcx> ReplacementVisitor<'tcx, '_> {
+    fn gather_debug_info_fragments(&self, local: Local) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
+        let mut fragments = Vec::new();
+        let parts = self.replacements.place_fragments(local.into())?;
+        for (field, ty, replacement_local) in parts {
+            fragments.push(VarDebugInfoFragment {
+                projection: vec![PlaceElem::Field(field, ty)],
+                contents: Place::from(replacement_local),
+            });
         }
+        Some(fragments)
     }
 }
 
@@ -246,12 +237,21 @@ fn tcx(&self) -> TyCtxt<'tcx> {
         self.tcx
     }
 
+    fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
+        if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
+            *place = repl
+        } else {
+            self.super_place(place, context, location)
+        }
+    }
+
     #[instrument(level = "trace", skip(self))]
     fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
         match statement.kind {
+            // Duplicate storage and deinit statements, as they pretty much apply to all fields.
             StatementKind::StorageLive(l) => {
-                if let Some(final_locals) = &self.replacements.fragments[l] {
-                    for &(_, fl) in final_locals {
+                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+                    for (_, _, fl) in final_locals {
                         self.patch.add_statement(location, StatementKind::StorageLive(fl));
                     }
                     statement.make_nop();
@@ -259,8 +259,8 @@ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Locatio
                 return;
             }
             StatementKind::StorageDead(l) => {
-                if let Some(final_locals) = &self.replacements.fragments[l] {
-                    for &(_, fl) in final_locals {
+                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+                    for (_, _, fl) in final_locals {
                         self.patch.add_statement(location, StatementKind::StorageDead(fl));
                     }
                     statement.make_nop();
@@ -269,7 +269,7 @@ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Locatio
             }
             StatementKind::Deinit(box place) => {
                 if let Some(final_locals) = self.replacements.place_fragments(place) {
-                    for &(_, fl) in final_locals {
+                    for (_, _, fl) in final_locals {
                         self.patch
                             .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
                     }
@@ -278,48 +278,80 @@ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Locatio
                 }
             }
 
-            StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref operands))) => {
-                if let Some(final_locals) = self.replacements.place_fragments(place) {
-                    for &(projection, fl) in final_locals {
-                        let &[PlaceElem::Field(index, _)] = projection else { bug!() };
-                        let index = index.as_usize();
-                        let rvalue = Rvalue::Use(operands[index].clone());
-                        self.patch.add_statement(
-                            location,
-                            StatementKind::Assign(Box::new((fl.into(), rvalue))),
-                        );
+            // We have `a = Struct { 0: x, 1: y, .. }`.
+            // We replace it by
+            // ```
+            // a_0 = x
+            // a_1 = y
+            // ...
+            // ```
+            StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
+                if let Some(local) = place.as_local()
+                    && let Some(final_locals) = &self.replacements.fragments[local]
+                {
+                    // This is ok as we delete the statement later.
+                    let operands = std::mem::take(operands);
+                    for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
+                        if let Some((_, new_local)) = opt_ty_local {
+                            // Replace mentions of SROA'd locals that appear in the operand.
+                            self.visit_operand(&mut operand, location);
+
+                            let rvalue = Rvalue::Use(operand);
+                            self.patch.add_statement(
+                                location,
+                                StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+                            );
+                        }
                     }
                     statement.make_nop();
                     return;
                 }
             }
 
+            // We have `a = some constant`
+            // We add the projections.
+            // ```
+            // a_0 = a.0
+            // a_1 = a.1
+            // ...
+            // ```
+            // ConstProp will pick up the pieces and replace them by actual constants.
             StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
                 if let Some(final_locals) = self.replacements.place_fragments(place) {
-                    for &(projection, fl) in final_locals {
-                        let rvalue =
-                            Rvalue::Use(Operand::Move(place.project_deeper(projection, self.tcx)));
+                    for (field, ty, new_local) in final_locals {
+                        let rplace = self.tcx.mk_place_field(place, field, ty);
+                        let rvalue = Rvalue::Use(Operand::Move(rplace));
                         self.patch.add_statement(
                             location,
-                            StatementKind::Assign(Box::new((fl.into(), rvalue))),
+                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                         );
                     }
-                    self.all_dead_locals.remove(place.local);
+                    // We still need `place.local` to exist, so don't make it nop.
                     return;
                 }
             }
 
+            // We have `a = move? place`
+            // We replace it by
+            // ```
+            // a_0 = move? place.0
+            // a_1 = move? place.1
+            // ...
+            // ```
             StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
-                let (rplace, copy) = match op {
+                let (rplace, copy) = match *op {
                     Operand::Copy(rplace) => (rplace, true),
                     Operand::Move(rplace) => (rplace, false),
                     Operand::Constant(_) => bug!(),
                 };
                 if let Some(final_locals) = self.replacements.place_fragments(lhs) {
-                    for &(projection, fl) in final_locals {
-                        let rplace = rplace.project_deeper(projection, self.tcx);
+                    for (field, ty, new_local) in final_locals {
+                        let rplace = self.tcx.mk_place_field(rplace, field, ty);
                         debug!(?rplace);
-                        let rplace = self.replace_place(rplace.as_ref()).unwrap_or(rplace);
+                        let rplace = self
+                            .replacements
+                            .replace_place(self.tcx, rplace.as_ref())
+                            .unwrap_or(rplace);
                         debug!(?rplace);
                         let rvalue = if copy {
                             Rvalue::Use(Operand::Copy(rplace))
@@ -328,7 +360,7 @@ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Locatio
                         };
                         self.patch.add_statement(
                             location,
-                            StatementKind::Assign(Box::new((fl.into(), rvalue))),
+                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                         );
                     }
                     statement.make_nop();
@@ -341,22 +373,14 @@ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Locatio
         self.super_statement(statement, location)
     }
 
-    fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
-        if let Some(repl) = self.replace_place(place.as_ref()) {
-            *place = repl
-        } else {
-            self.super_place(place, context, location)
-        }
-    }
-
     #[instrument(level = "trace", skip(self))]
     fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
         match &mut var_debug_info.value {
             VarDebugInfoContents::Place(ref mut place) => {
-                if let Some(repl) = self.replace_place(place.as_ref()) {
+                if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
                     *place = repl;
-                } else if let Some(fragments) =
-                    self.replacements.gather_debug_info_fragments(place.as_ref())
+                } else if let Some(local) = place.as_local()
+                    && let Some(fragments) = self.gather_debug_info_fragments(local)
                 {
                     let ty = place.ty(self.local_decls, self.tcx).ty;
                     var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
@@ -367,12 +391,13 @@ fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
                 debug!(?fragments);
                 fragments
                     .drain_filter(|fragment| {
-                        if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
+                        if let Some(repl) =
+                            self.replacements.replace_place(self.tcx, fragment.contents.as_ref())
+                        {
                             fragment.contents = repl;
                             false
-                        } else if let Some(frg) = self
-                            .replacements
-                            .gather_debug_info_fragments(fragment.contents.as_ref())
+                        } else if let Some(local) = fragment.contents.as_local()
+                            && let Some(frg) = self.gather_debug_info_fragments(local)
                         {
                             new_fragments.extend(frg.into_iter().map(|mut f| {
                                 f.projection.splice(0..0, fragment.projection.iter().copied());