]> git.lizzy.rs Git - rust.git/commitdiff
Make SROA expand assignments.
authorCamille GILLOT <gillot.camille@gmail.com>
Sun, 5 Feb 2023 09:31:27 +0000 (09:31 +0000)
committerCamille GILLOT <gillot.camille@gmail.com>
Sun, 5 Feb 2023 11:42:11 +0000 (11:42 +0000)
compiler/rustc_mir_transform/src/sroa.rs
tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff [new file with mode: 0644]
tests/mir-opt/sroa.escaping.ScalarReplacementOfAggregates.diff
tests/mir-opt/sroa.flat.ScalarReplacementOfAggregates.diff
tests/mir-opt/sroa.ref_copies.ScalarReplacementOfAggregates.diff [new file with mode: 0644]
tests/mir-opt/sroa.rs

index 2118e3c55222420e64db80b1cb50a51449325da6..f6609704d25d4b2c2caac1ff21a511147326341b 100644 (file)
@@ -78,10 +78,15 @@ fn visit_assign(
             rvalue: &Rvalue<'tcx>,
             location: Location,
         ) {
-            if lvalue.as_local().is_some() && let Rvalue::Aggregate(..) = rvalue {
-                // Aggregate assignments are expanded in run_pass.
-                self.visit_rvalue(rvalue, location);
-                return;
+            if lvalue.as_local().is_some() {
+                match rvalue {
+                    // Aggregate assignments are expanded in run_pass.
+                    Rvalue::Aggregate(..) | Rvalue::Use(..) => {
+                        self.visit_rvalue(rvalue, location);
+                        return;
+                    }
+                    _ => {}
+                }
             }
             self.super_assign(lvalue, rvalue, location)
         }
@@ -195,10 +200,9 @@ fn replace_flattened_locals<'tcx>(
         return;
     }
 
-    let mut fragments = IndexVec::new();
+    let mut fragments = IndexVec::<_, Option<Vec<_>>>::from_elem(None, &body.local_decls);
     for (k, v) in &replacements.fields {
-        fragments.ensure_contains_elem(k.local, || Vec::new());
-        fragments[k.local].push((k.projection, *v));
+        fragments[k.local].get_or_insert_default().push((k.projection, *v));
     }
     debug!(?fragments);
 
@@ -235,7 +239,7 @@ struct ReplacementVisitor<'tcx, 'll> {
     all_dead_locals: BitSet<Local>,
     /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
     /// and deinit statement and debuginfo.
-    fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
+    fragments: IndexVec<Local, Option<Vec<(&'tcx [PlaceElem<'tcx>], Local)>>>,
     patch: MirPatch<'tcx>,
 }
 
@@ -243,9 +247,9 @@ impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
     fn gather_debug_info_fragments(
         &self,
         place: PlaceRef<'tcx>,
-    ) -> Vec<VarDebugInfoFragment<'tcx>> {
+    ) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
         let mut fragments = Vec::new();
-        let parts = &self.fragments[place.local];
+        let Some(parts) = &self.fragments[place.local] else { return None };
         for (proj, replacement_local) in parts {
             if proj.starts_with(place.projection) {
                 fragments.push(VarDebugInfoFragment {
@@ -254,7 +258,7 @@ fn gather_debug_info_fragments(
                 });
             }
         }
-        fragments
+        Some(fragments)
     }
 
     fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
@@ -276,8 +280,7 @@ fn tcx(&self) -> TyCtxt<'tcx> {
     fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
         match statement.kind {
             StatementKind::StorageLive(l) => {
-                if self.all_dead_locals.contains(l) {
-                    let final_locals = &self.fragments[l];
+                if let Some(final_locals) = &self.fragments[l] {
                     for &(_, fl) in final_locals {
                         self.patch.add_statement(location, StatementKind::StorageLive(fl));
                     }
@@ -286,8 +289,7 @@ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Locatio
                 return;
             }
             StatementKind::StorageDead(l) => {
-                if self.all_dead_locals.contains(l) {
-                    let final_locals = &self.fragments[l];
+                if let Some(final_locals) = &self.fragments[l] {
                     for &(_, fl) in final_locals {
                         self.patch.add_statement(location, StatementKind::StorageDead(fl));
                     }
@@ -297,9 +299,8 @@ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Locatio
             }
             StatementKind::Deinit(box ref place) => {
                 if let Some(local) = place.as_local()
-                    && self.all_dead_locals.contains(local)
+                    && let Some(final_locals) = &self.fragments[local]
                 {
-                    let final_locals = &self.fragments[local];
                     for &(_, fl) in final_locals {
                         self.patch.add_statement(
                             location,
@@ -313,9 +314,8 @@ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Locatio
 
             StatementKind::Assign(box (ref place, Rvalue::Aggregate(_, ref operands))) => {
                 if let Some(local) = place.as_local()
-                    && self.all_dead_locals.contains(local)
+                    && let Some(final_locals) = &self.fragments[local]
                 {
-                    let final_locals = &self.fragments[local];
                     for &(projection, fl) in final_locals {
                         let &[PlaceElem::Field(index, _)] = projection else { bug!() };
                         let index = index.as_usize();
@@ -330,6 +330,48 @@ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Locatio
                 }
             }
 
+            StatementKind::Assign(box (ref place, Rvalue::Use(Operand::Constant(_)))) => {
+                if let Some(local) = place.as_local()
+                    && let Some(final_locals) = &self.fragments[local]
+                {
+                    for &(projection, fl) in final_locals {
+                        let rvalue = Rvalue::Use(Operand::Move(place.project_deeper(projection, self.tcx)));
+                        self.patch.add_statement(
+                            location,
+                            StatementKind::Assign(Box::new((fl.into(), rvalue))),
+                        );
+                    }
+                    self.all_dead_locals.remove(local);
+                    return;
+                }
+            }
+
+            StatementKind::Assign(box (ref lhs, Rvalue::Use(ref op))) => {
+                let (rplace, copy) = match op {
+                    Operand::Copy(rplace) => (rplace, true),
+                    Operand::Move(rplace) => (rplace, false),
+                    Operand::Constant(_) => bug!(),
+                };
+                if let Some(local) = lhs.as_local()
+                    && let Some(final_locals) = &self.fragments[local]
+                {
+                    for &(projection, fl) in final_locals {
+                        let rplace = rplace.project_deeper(projection, self.tcx);
+                        let rvalue = if copy {
+                            Rvalue::Use(Operand::Copy(rplace))
+                        } else {
+                            Rvalue::Use(Operand::Move(rplace))
+                        };
+                        self.patch.add_statement(
+                            location,
+                            StatementKind::Assign(Box::new((fl.into(), rvalue))),
+                        );
+                    }
+                    statement.make_nop();
+                    return;
+                }
+            }
+
             _ => {}
         }
         self.super_statement(statement, location)
@@ -348,9 +390,8 @@ fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
             VarDebugInfoContents::Place(ref mut place) => {
                 if let Some(repl) = self.replace_place(place.as_ref()) {
                     *place = repl;
-                } else if self.all_dead_locals.contains(place.local) {
+                } else if let Some(fragments) = self.gather_debug_info_fragments(place.as_ref()) {
                     let ty = place.ty(self.local_decls, self.tcx).ty;
-                    let fragments = self.gather_debug_info_fragments(place.as_ref());
                     var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
                 }
             }
@@ -361,8 +402,9 @@ fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
                         if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
                             fragment.contents = repl;
                             true
-                        } else if self.all_dead_locals.contains(fragment.contents.local) {
-                            let frg = self.gather_debug_info_fragments(fragment.contents.as_ref());
+                        } else if let Some(frg) =
+                            self.gather_debug_info_fragments(fragment.contents.as_ref())
+                        {
                             new_fragments.extend(frg.into_iter().map(|mut f| {
                                 f.projection.splice(0..0, fragment.projection.iter().copied());
                                 f
diff --git a/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff b/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff
new file mode 100644 (file)
index 0000000..72610de
--- /dev/null
@@ -0,0 +1,48 @@
+- // MIR for `copies` before ScalarReplacementOfAggregates
++ // MIR for `copies` after ScalarReplacementOfAggregates
+  
+  fn copies(_1: Foo) -> () {
+      debug x => _1;                       // in scope 0 at $DIR/sroa.rs:+0:11: +0:12
+      let mut _0: ();                      // return place in scope 0 at $DIR/sroa.rs:+0:19: +0:19
+      let _2: Foo;                         // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
++     let _5: u8;                          // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
++     let _6: &str;                        // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
+      scope 1 {
+-         debug y => _2;                   // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
++         debug y => Foo{ .0 => _5, .2 => _6, }; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
+          let _3: u8;                      // in scope 1 at $DIR/sroa.rs:+2:9: +2:10
+          scope 2 {
+              debug t => _3;               // in scope 2 at $DIR/sroa.rs:+2:9: +2:10
+              let _4: &str;                // in scope 2 at $DIR/sroa.rs:+3:9: +3:10
+              scope 3 {
+                  debug u => _4;           // in scope 3 at $DIR/sroa.rs:+3:9: +3:10
+              }
+          }
+      }
+  
+      bb0: {
+-         StorageLive(_2);                 // scope 0 at $DIR/sroa.rs:+1:9: +1:10
+-         _2 = _1;                         // scope 0 at $DIR/sroa.rs:+1:13: +1:14
++         StorageLive(_5);                 // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         StorageLive(_6);                 // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         nop;                             // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         _5 = (_1.0: u8);                 // scope 0 at $DIR/sroa.rs:+1:13: +1:14
++         _6 = (_1.2: &str);               // scope 0 at $DIR/sroa.rs:+1:13: +1:14
++         nop;                             // scope 0 at $DIR/sroa.rs:+1:13: +1:14
+          StorageLive(_3);                 // scope 1 at $DIR/sroa.rs:+2:9: +2:10
+-         _3 = (_2.0: u8);                 // scope 1 at $DIR/sroa.rs:+2:13: +2:16
++         _3 = _5;                         // scope 1 at $DIR/sroa.rs:+2:13: +2:16
+          StorageLive(_4);                 // scope 2 at $DIR/sroa.rs:+3:9: +3:10
+-         _4 = (_2.2: &str);               // scope 2 at $DIR/sroa.rs:+3:13: +3:16
++         _4 = _6;                         // scope 2 at $DIR/sroa.rs:+3:13: +3:16
+          _0 = const ();                   // scope 0 at $DIR/sroa.rs:+0:19: +4:2
+          StorageDead(_4);                 // scope 2 at $DIR/sroa.rs:+4:1: +4:2
+          StorageDead(_3);                 // scope 1 at $DIR/sroa.rs:+4:1: +4:2
+-         StorageDead(_2);                 // scope 0 at $DIR/sroa.rs:+4:1: +4:2
++         StorageDead(_5);                 // scope 0 at $DIR/sroa.rs:+4:1: +4:2
++         StorageDead(_6);                 // scope 0 at $DIR/sroa.rs:+4:1: +4:2
++         nop;                             // scope 0 at $DIR/sroa.rs:+4:1: +4:2
+          return;                          // scope 0 at $DIR/sroa.rs:+4:2: +4:2
+      }
+  }
+  
index b01fb6fc91538ccd3a35ccd5d50cc17d68629091..ea7f5007224519e16759e5054db50242eb790155 100644 (file)
@@ -17,7 +17,7 @@
           StorageLive(_5);                 // scope 0 at $DIR/sroa.rs:+2:34: +2:37
           _5 = g() -> bb1;                 // scope 0 at $DIR/sroa.rs:+2:34: +2:37
                                            // mir::Constant
-                                           // + span: $DIR/sroa.rs:78:34: 78:35
+                                           // + span: $DIR/sroa.rs:73:34: 73:35
                                            // + literal: Const { ty: fn() -> u32 {g}, val: Value(<ZST>) }
       }
   
@@ -28,7 +28,7 @@
           _2 = &raw const (*_3);           // scope 0 at $DIR/sroa.rs:+2:7: +2:41
           _1 = f(move _2) -> bb2;          // scope 0 at $DIR/sroa.rs:+2:5: +2:42
                                            // mir::Constant
-                                           // + span: $DIR/sroa.rs:78:5: 78:6
+                                           // + span: $DIR/sroa.rs:73:5: 73:6
                                            // + literal: Const { ty: fn(*const u32) {f}, val: Value(<ZST>) }
       }
   
index 338ce262f1ec9aeec16a2e10ca0882fe94684565..69631fc0213f8f5d31ab407509bdc3beb14adf2d 100644 (file)
@@ -45,7 +45,7 @@
 +         _9 = move _6;                    // scope 0 at $DIR/sroa.rs:+1:30: +1:70
 +         _10 = const "a";                 // scope 0 at $DIR/sroa.rs:+1:30: +1:70
                                            // mir::Constant
-                                           // + span: $DIR/sroa.rs:57:52: 57:55
+                                           // + span: $DIR/sroa.rs:53:52: 53:55
                                            // + literal: Const { ty: &str, val: Value(Slice(..)) }
 +         _11 = move _7;                   // scope 0 at $DIR/sroa.rs:+1:30: +1:70
 +         nop;                             // scope 0 at $DIR/sroa.rs:+1:30: +1:70
diff --git a/tests/mir-opt/sroa.ref_copies.ScalarReplacementOfAggregates.diff b/tests/mir-opt/sroa.ref_copies.ScalarReplacementOfAggregates.diff
new file mode 100644 (file)
index 0000000..1a561a9
--- /dev/null
@@ -0,0 +1,48 @@
+- // MIR for `ref_copies` before ScalarReplacementOfAggregates
++ // MIR for `ref_copies` after ScalarReplacementOfAggregates
+  
+  fn ref_copies(_1: &Foo) -> () {
+      debug x => _1;                       // in scope 0 at $DIR/sroa.rs:+0:15: +0:16
+      let mut _0: ();                      // return place in scope 0 at $DIR/sroa.rs:+0:24: +0:24
+      let _2: Foo;                         // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
++     let _5: u8;                          // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
++     let _6: &str;                        // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
+      scope 1 {
+-         debug y => _2;                   // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
++         debug y => Foo{ .0 => _5, .2 => _6, }; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
+          let _3: u8;                      // in scope 1 at $DIR/sroa.rs:+2:9: +2:10
+          scope 2 {
+              debug t => _3;               // in scope 2 at $DIR/sroa.rs:+2:9: +2:10
+              let _4: &str;                // in scope 2 at $DIR/sroa.rs:+3:9: +3:10
+              scope 3 {
+                  debug u => _4;           // in scope 3 at $DIR/sroa.rs:+3:9: +3:10
+              }
+          }
+      }
+  
+      bb0: {
+-         StorageLive(_2);                 // scope 0 at $DIR/sroa.rs:+1:9: +1:10
+-         _2 = (*_1);                      // scope 0 at $DIR/sroa.rs:+1:13: +1:15
++         StorageLive(_5);                 // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         StorageLive(_6);                 // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         nop;                             // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         _5 = ((*_1).0: u8);              // scope 0 at $DIR/sroa.rs:+1:13: +1:15
++         _6 = ((*_1).2: &str);            // scope 0 at $DIR/sroa.rs:+1:13: +1:15
++         nop;                             // scope 0 at $DIR/sroa.rs:+1:13: +1:15
+          StorageLive(_3);                 // scope 1 at $DIR/sroa.rs:+2:9: +2:10
+-         _3 = (_2.0: u8);                 // scope 1 at $DIR/sroa.rs:+2:13: +2:16
++         _3 = _5;                         // scope 1 at $DIR/sroa.rs:+2:13: +2:16
+          StorageLive(_4);                 // scope 2 at $DIR/sroa.rs:+3:9: +3:10
+-         _4 = (_2.2: &str);               // scope 2 at $DIR/sroa.rs:+3:13: +3:16
++         _4 = _6;                         // scope 2 at $DIR/sroa.rs:+3:13: +3:16
+          _0 = const ();                   // scope 0 at $DIR/sroa.rs:+0:24: +4:2
+          StorageDead(_4);                 // scope 2 at $DIR/sroa.rs:+4:1: +4:2
+          StorageDead(_3);                 // scope 1 at $DIR/sroa.rs:+4:1: +4:2
+-         StorageDead(_2);                 // scope 0 at $DIR/sroa.rs:+4:1: +4:2
++         StorageDead(_5);                 // scope 0 at $DIR/sroa.rs:+4:1: +4:2
++         StorageDead(_6);                 // scope 0 at $DIR/sroa.rs:+4:1: +4:2
++         nop;                             // scope 0 at $DIR/sroa.rs:+4:1: +4:2
+          return;                          // scope 0 at $DIR/sroa.rs:+4:2: +4:2
+      }
+  }
+  
index ff8deb40d7d5a388c74f71102ae489834ad3a2e3..b80f61600c2660012ca055db55e528c6d81ce843 100644 (file)
@@ -12,17 +12,14 @@ impl Drop for Tag {
     fn drop(&mut self) {}
 }
 
-// EMIT_MIR sroa.dropping.ScalarReplacementOfAggregates.diff
 pub fn dropping() {
     S(Tag(0), Tag(1), Tag(2)).1;
 }
 
-// EMIT_MIR sroa.enums.ScalarReplacementOfAggregates.diff
 pub fn enums(a: usize) -> usize {
     if let Some(a) = Some(a) { a } else { 0 }
 }
 
-// EMIT_MIR sroa.structs.ScalarReplacementOfAggregates.diff
 pub fn structs(a: f32) -> f32 {
     struct U {
         _foo: usize,
@@ -32,7 +29,6 @@ struct U {
     U { _foo: 0, a }.a
 }
 
-// EMIT_MIR sroa.unions.ScalarReplacementOfAggregates.diff
 pub fn unions(a: f32) -> u32 {
     union Repr {
         f: f32,
@@ -41,6 +37,7 @@ union Repr {
     unsafe { Repr { f: a }.u }
 }
 
+#[derive(Copy, Clone)]
 struct Foo {
     a: u8,
     b: (),
@@ -52,7 +49,6 @@ fn g() -> u32 {
     3
 }
 
-// EMIT_MIR sroa.flat.ScalarReplacementOfAggregates.diff
 pub fn flat() {
     let Foo { a, b, c, d } = Foo { a: 5, b: (), c: "a", d: Some(-4) };
     let _ = a;
@@ -72,12 +68,23 @@ fn f(a: *const u32) {
     println!("{}", unsafe { *a.add(2) });
 }
 
-// EMIT_MIR sroa.escaping.ScalarReplacementOfAggregates.diff
 pub fn escaping() {
     // Verify this struct is not flattened.
     f(&Escaping { a: 1, b: 2, c: g() }.a);
 }
 
+fn copies(x: Foo) {
+    let y = x;
+    let t = y.a;
+    let u = y.c;
+}
+
+fn ref_copies(x: &Foo) {
+    let y = *x;
+    let t = y.a;
+    let u = y.c;
+}
+
 fn main() {
     dropping();
     enums(5);
@@ -85,4 +92,15 @@ fn main() {
     unions(5.);
     flat();
     escaping();
+    copies(Foo { a: 5, b: (), c: "a", d: Some(-4) });
+    ref_copies(&Foo { a: 5, b: (), c: "a", d: Some(-4) });
 }
+
+// EMIT_MIR sroa.dropping.ScalarReplacementOfAggregates.diff
+// EMIT_MIR sroa.enums.ScalarReplacementOfAggregates.diff
+// EMIT_MIR sroa.structs.ScalarReplacementOfAggregates.diff
+// EMIT_MIR sroa.unions.ScalarReplacementOfAggregates.diff
+// EMIT_MIR sroa.flat.ScalarReplacementOfAggregates.diff
+// EMIT_MIR sroa.escaping.ScalarReplacementOfAggregates.diff
+// EMIT_MIR sroa.copies.ScalarReplacementOfAggregates.diff
+// EMIT_MIR sroa.ref_copies.ScalarReplacementOfAggregates.diff