]> git.lizzy.rs Git - rust.git/commitdiff
Don't inline virtual calls (take 2)
authorWesley Wiser <wwiser@gmail.com>
Wed, 7 Nov 2018 03:31:09 +0000 (22:31 -0500)
committerWesley Wiser <wwiser@gmail.com>
Sat, 10 Nov 2018 03:11:40 +0000 (22:11 -0500)
When I fixed the previous mis-optimizations, I didn't realize there were
actually two different places where we mutate `callsites` and both of
them should have the same behavior.

As a result, if a function was inlined and that function contained
virtual function calls, they were incorrectly being inlined. I also
added a test case which covers this.

src/librustc_mir/transform/inline.rs
src/test/mir-opt/inline-trait-method_2.rs [new file with mode: 0644]

index 2db3bbda3233bf8c895e0352a587511b4b92ceb2..1cce0de5152fda7bd649b1b92c80235f9a2958d6 100644 (file)
@@ -19,7 +19,7 @@
 
 use rustc::mir::*;
 use rustc::mir::visit::*;
-use rustc::ty::{self, Instance, InstanceDef, Ty, TyCtxt};
+use rustc::ty::{self, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
 use rustc::ty::subst::{Subst,Substs};
 
 use std::collections::VecDeque;
@@ -85,39 +85,16 @@ fn run_pass(&self, caller_mir: &mut Mir<'tcx>) {
         // Only do inlining into fn bodies.
         let id = self.tcx.hir.as_local_node_id(self.source.def_id).unwrap();
         let body_owner_kind = self.tcx.hir.body_owner_kind(id);
+
         if let (hir::BodyOwnerKind::Fn, None) = (body_owner_kind, self.source.promoted) {
 
             for (bb, bb_data) in caller_mir.basic_blocks().iter_enumerated() {
-                // Don't inline calls that are in cleanup blocks.
-                if bb_data.is_cleanup { continue; }
-
-                // Only consider direct calls to functions
-                let terminator = bb_data.terminator();
-                if let TerminatorKind::Call {
-                    func: ref op, .. } = terminator.kind {
-                        if let ty::FnDef(callee_def_id, substs) = op.ty(caller_mir, self.tcx).sty {
-                            if let Some(instance) = Instance::resolve(self.tcx,
-                                                                      param_env,
-                                                                      callee_def_id,
-                                                                      substs) {
-                                let is_virtual =
-                                    if let InstanceDef::Virtual(..) = instance.def {
-                                        true
-                                    } else {
-                                        false
-                                    };
-
-                                if !is_virtual {
-                                    callsites.push_back(CallSite {
-                                        callee: instance.def_id(),
-                                        substs: instance.substs,
-                                        bb,
-                                        location: terminator.source_info
-                                    });
-                                }
-                            }
-                        }
-                    }
+                if let Some(callsite) = self.get_valid_function_call(bb,
+                                                                     bb_data,
+                                                                     caller_mir,
+                                                                     param_env) {
+                    callsites.push_back(callsite);
+                }
             }
         } else {
             return;
@@ -163,20 +140,13 @@ fn run_pass(&self, caller_mir: &mut Mir<'tcx>) {
 
                 // Add callsites from inlined function
                 for (bb, bb_data) in caller_mir.basic_blocks().iter_enumerated().skip(start) {
-                    // Only consider direct calls to functions
-                    let terminator = bb_data.terminator();
-                    if let TerminatorKind::Call {
-                        func: Operand::Constant(ref f), .. } = terminator.kind {
-                        if let ty::FnDef(callee_def_id, substs) = f.ty.sty {
-                            // Don't inline the same function multiple times.
-                            if callsite.callee != callee_def_id {
-                                callsites.push_back(CallSite {
-                                    callee: callee_def_id,
-                                    substs,
-                                    bb,
-                                    location: terminator.source_info
-                                });
-                            }
+                    if let Some(new_callsite) = self.get_valid_function_call(bb,
+                                                                             bb_data,
+                                                                             caller_mir,
+                                                                             param_env) {
+                        // Don't inline the same function multiple times.
+                        if callsite.callee != new_callsite.callee {
+                            callsites.push_back(new_callsite);
                         }
                     }
                 }
@@ -198,6 +168,40 @@ fn run_pass(&self, caller_mir: &mut Mir<'tcx>) {
         }
     }
 
+    fn get_valid_function_call(&self,
+                               bb: BasicBlock,
+                               bb_data: &BasicBlockData<'tcx>,
+                               caller_mir: &Mir<'tcx>,
+                               param_env: ParamEnv<'tcx>,
+    ) -> Option<CallSite<'tcx>> {
+        // Don't inline calls that are in cleanup blocks.
+        if bb_data.is_cleanup { return None; }
+
+        // Only consider direct calls to functions
+        let terminator = bb_data.terminator();
+        if let TerminatorKind::Call { func: ref op, .. } = terminator.kind {
+            if let ty::FnDef(callee_def_id, substs) = op.ty(caller_mir, self.tcx).sty {
+                let instance = Instance::resolve(self.tcx,
+                                                 param_env,
+                                                 callee_def_id,
+                                                 substs)?;
+
+                if let InstanceDef::Virtual(..) = instance.def {
+                    return None;
+                }
+
+                return Some(CallSite {
+                    callee: instance.def_id(),
+                    substs: instance.substs,
+                    bb,
+                    location: terminator.source_info
+                });
+            }
+        }
+
+        None
+    }
+
     fn consider_optimizing(&self,
                            callsite: CallSite<'tcx>,
                            callee_mir: &Mir<'tcx>)
diff --git a/src/test/mir-opt/inline-trait-method_2.rs b/src/test/mir-opt/inline-trait-method_2.rs
new file mode 100644 (file)
index 0000000..aa756f4
--- /dev/null
@@ -0,0 +1,36 @@
+// compile-flags: -Z span_free_formats -Z mir-opt-level=3
+
+#[inline]
+fn test(x: &dyn X) -> bool {
+    x.y()
+}
+
+fn test2(x: &dyn X) -> bool {
+    test(x)
+}
+
+trait X {
+    fn y(&self) -> bool {
+        false
+    }
+}
+
+impl X for () {
+    fn y(&self) -> bool {
+        true
+    }
+}
+
+fn main() {
+    println!("Should be true: {}", test2(&()));
+}
+
+// END RUST SOURCE
+// START rustc.test2.Inline.after.mir
+// ...
+// bb0: {
+// ...
+//     _0 = const X::y(move _2) -> bb1;
+// }
+// ...
+// END rustc.test2.Inline.after.mir