]> git.lizzy.rs Git - rust.git/commitdiff
when suspending, we need to specify for which lifetime to recover
authorRalf Jung <post@ralfj.de>
Fri, 21 Jul 2017 19:43:09 +0000 (12:43 -0700)
committerRalf Jung <post@ralfj.de>
Sun, 30 Jul 2017 08:11:59 +0000 (01:11 -0700)
This matters if the lvalues that is suspended involves Deref'ing a reference --
that reference's lifetime will then not be in the type any more

src/librustc/hir/mod.rs
src/librustc/ich/impls_mir.rs
src/librustc/mir/mod.rs
src/librustc/mir/visit.rs
src/librustc_mir/transform/add_validation.rs

index a3a133daa09c460bb7a3beec60612ce702a6aeb5..cc0d49c1a3630fe5b498234edb10ab3be4d1aeda 100644 (file)
@@ -684,6 +684,16 @@ pub enum Mutability {
     MutImmutable,
 }
 
+impl Mutability {
+    /// Return MutMutable only if both arguments are mutable.
+    pub fn and(self, other: Self) -> Self {
+        match self {
+            MutMutable => other,
+            MutImmutable => MutImmutable,
+        }
+    }
+}
+
 #[derive(Clone, PartialEq, Eq, RustcEncodable, RustcDecodable, Hash, Debug, Copy)]
 pub enum BinOp_ {
     /// The `+` operator (addition)
index eb0c62a1161806b0343a208b0929faa8a2cbd809..dc41f981ed57b3dbe2a5d1e41d20789087ee4872 100644 (file)
@@ -243,6 +243,8 @@ fn hash_stable<W: StableHasherResult>(&self,
     }
 }
 
+impl_stable_hash_for!(struct mir::ValidationOperand<'tcx> { lval, ty, re, mutbl });
+
 impl_stable_hash_for!(enum mir::ValidationOp { Acquire, Release, Suspend(extent) });
 
 impl<'a, 'gcx, 'tcx> HashStable<StableHashingContext<'a, 'gcx, 'tcx>> for mir::Lvalue<'tcx> {
index dcab476ec23d556cc5febc71d882ba27112d9557..4655f8a9c15ec2edc479df4b50e36eecbc4743ea 100644 (file)
@@ -25,7 +25,7 @@
 use ty::fold::{TypeFoldable, TypeFolder, TypeVisitor};
 use util::ppaux;
 use rustc_back::slice;
-use hir::InlineAsm;
+use hir::{self, InlineAsm};
 use std::ascii;
 use std::borrow::{Cow};
 use std::cell::Ref;
@@ -826,7 +826,7 @@ pub enum StatementKind<'tcx> {
     },
 
     /// Assert the given lvalues to be valid inhabitants of their type.
-    Validate(ValidationOp, Vec<(Ty<'tcx>, Lvalue<'tcx>)>),
+    Validate(ValidationOp, Vec<ValidationOperand<'tcx>>),
 
     /// Mark one terminating point of an extent (i.e. static region).
     /// (The starting point(s) arise implicitly from borrows.)
@@ -855,6 +855,28 @@ fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
     }
 }
 
+#[derive(Clone, RustcEncodable, RustcDecodable)]
+pub struct ValidationOperand<'tcx> {
+    pub lval: Lvalue<'tcx>,
+    pub ty: Ty<'tcx>,
+    pub re: Option<CodeExtent>,
+    pub mutbl: hir::Mutability,
+}
+
+impl<'tcx> Debug for ValidationOperand<'tcx> {
+    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
+        write!(fmt, "{:?}@{:?}", self.lval, self.ty)?;
+        if let Some(ce) = self.re {
+            // (reuse lifetime rendering policy from ppaux.)
+            write!(fmt, "/{}", ty::ReScope(ce))?;
+        }
+        if let hir::MutImmutable = self.mutbl {
+            write!(fmt, " (imm)")?;
+        }
+        Ok(())
+    }
+}
+
 impl<'tcx> Debug for Statement<'tcx> {
     fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
         use self::StatementKind::*;
@@ -1505,6 +1527,21 @@ fn super_visit_with<V: TypeVisitor<'tcx>>(&self, visitor: &mut V) -> bool {
     }
 }
 
+impl<'tcx> TypeFoldable<'tcx> for ValidationOperand<'tcx> {
+    fn super_fold_with<'gcx: 'tcx, F: TypeFolder<'gcx, 'tcx>>(&self, folder: &mut F) -> Self {
+        ValidationOperand {
+            lval: self.lval.fold_with(folder),
+            ty: self.ty.fold_with(folder),
+            re: self.re,
+            mutbl: self.mutbl,
+        }
+    }
+
+    fn super_visit_with<V: TypeVisitor<'tcx>>(&self, visitor: &mut V) -> bool {
+        self.lval.visit_with(visitor) || self.ty.visit_with(visitor)
+    }
+}
+
 impl<'tcx> TypeFoldable<'tcx> for Statement<'tcx> {
     fn super_fold_with<'gcx: 'tcx, F: TypeFolder<'gcx, 'tcx>>(&self, folder: &mut F) -> Self {
         use mir::StatementKind::*;
@@ -1531,7 +1568,7 @@ fn super_fold_with<'gcx: 'tcx, F: TypeFolder<'gcx, 'tcx>>(&self, folder: &mut F)
 
             Validate(ref op, ref lvals) =>
                 Validate(op.clone(),
-                         lvals.iter().map(|ty_and_lval| ty_and_lval.fold_with(folder)).collect()),
+                         lvals.iter().map(|operand| operand.fold_with(folder)).collect()),
 
             Nop => Nop,
         };
index 5284a6132396be545ff1f101f934cfd38fc5f167..a05007503cefb48008b58cfa5fd10d2913e161c3 100644 (file)
@@ -334,9 +334,10 @@ fn super_statement(&mut self,
                     }
                     StatementKind::EndRegion(_) => {}
                     StatementKind::Validate(_, ref $($mutability)* lvalues) => {
-                        for & $($mutability)* (ref $($mutability)* ty, ref $($mutability)* lvalue) in lvalues {
-                            self.visit_ty(ty, Lookup::Loc(location));
-                            self.visit_lvalue(lvalue, LvalueContext::Validate, location);
+                        for operand in lvalues {
+                            self.visit_lvalue(& $($mutability)* operand.lval,
+                                              LvalueContext::Validate, location);
+                            self.visit_ty(& $($mutability)* operand.ty, Lookup::Loc(location));
                         }
                     }
                     StatementKind::SetDiscriminant{ ref $($mutability)* lvalue, .. } => {
index b79c1a2d6fdb2674340a918e918bce75214a057a..1fe16fb98f22591cc2889898367a4628b8cad124 100644 (file)
 //! of MIR building, and only after this pass we think of the program has having the
 //! normal MIR semantics.
 
-use rustc::ty::{TyCtxt, RegionKind};
+use rustc::ty::{self, TyCtxt, RegionKind};
+use rustc::hir;
 use rustc::mir::*;
 use rustc::mir::transform::{MirPass, MirSource};
+use rustc::middle::region::CodeExtent;
 
 pub struct AddValidation;
 
-
-fn is_lvalue_shared<'a, 'tcx, D>(lval: &Lvalue<'tcx>, local_decls: &D, tcx: TyCtxt<'a, 'tcx, 'tcx>) -> bool
+/// Determine the "context" of the lval: Mutability and region.
+fn lval_context<'a, 'tcx, D>(
+    lval: &Lvalue<'tcx>,
+    local_decls: &D,
+    tcx: TyCtxt<'a, 'tcx, 'tcx>
+) -> (Option<CodeExtent>, hir::Mutability)
     where D: HasLocalDecls<'tcx>
 {
     use rustc::mir::Lvalue::*;
 
     match *lval {
-        Local { .. } => false,
-        Static(_) => true,
+        Local { .. } => (None, hir::MutMutable),
+        Static(_) => (None, hir::MutImmutable),
         Projection(ref proj) => {
-            // If the base is shared, things stay shared
-            if is_lvalue_shared(&proj.base, local_decls, tcx) {
-                return true;
-            }
-            // A Deref projection may make things shared
             match proj.elem {
                 ProjectionElem::Deref => {
-                    // Computing the inside the recursion makes this quadratic.  We don't expect deep paths though.
+                    // Computing the inside the recursion makes this quadratic.
+                    // We don't expect deep paths though.
                     let ty = proj.base.ty(local_decls, tcx).to_ty(tcx);
-                    !ty.is_mutable_pointer()
+                    // A Deref projection may restrict the context, this depends on the type
+                    // being deref'd.
+                    let context = match ty.sty {
+                        ty::TyRef(re, tam) => {
+                            let re = match re {
+                                &RegionKind::ReScope(ce) => Some(ce),
+                                &RegionKind::ReErased =>
+                                    bug!("AddValidation pass must be run before erasing lifetimes"),
+                                _ => None
+                            };
+                            (re, tam.mutbl)
+                        }
+                        ty::TyRawPtr(_) =>
+                            // There is no guarantee behind even a mutable raw pointer,
+                            // no write locks are acquired there, so we also don't want to
+                            // release any.
+                            (None, hir::MutImmutable),
+                        ty::TyAdt(adt, _) if adt.is_box() => (None, hir::MutMutable),
+                        _ => bug!("Deref on a non-pointer type {:?}", ty),
+                    };
+                    // "Intersect" this restriction with proj.base.
+                    if let (Some(_), hir::MutImmutable) = context {
+                        // This is already as restricted as it gets, no need to even recurse
+                        context
+                    } else {
+                        let base_context = lval_context(&proj.base, local_decls, tcx);
+                        // The region of the outermost Deref is always most restrictive.
+                        let re = context.0.or(base_context.0);
+                        let mutbl = context.1.and(base_context.1);
+                        (re, mutbl)
+                    }
+
                 }
-                _ => false,
+                _ => lval_context(&proj.base, local_decls, tcx),
             }
         }
     }
@@ -52,41 +85,49 @@ fn run_pass<'a, 'tcx>(&self,
                           tcx: TyCtxt<'a, 'tcx, 'tcx>,
                           _: MirSource,
                           mir: &mut Mir<'tcx>) {
+        let local_decls = mir.local_decls.clone(); // TODO: Find a way to get rid of this clone.
+
+        /// Convert an lvalue to a validation operand.
+        let lval_to_operand = |lval: Lvalue<'tcx>| -> ValidationOperand<'tcx> {
+            let (re, mutbl) = lval_context(&lval, &local_decls, tcx);
+            let ty = lval.ty(&local_decls, tcx).to_ty(tcx);
+            ValidationOperand { lval, ty, re, mutbl }
+        };
+
         // PART 1
         // Add an AcquireValid at the beginning of the start block.
         if mir.arg_count > 0 {
             let acquire_stmt = Statement {
                 source_info: SourceInfo {
                     scope: ARGUMENT_VISIBILITY_SCOPE,
-                    span: mir.span, // TODO: Consider using just the span covering the function argument declaration
+                    span: mir.span, // TODO: Consider using just the span covering the function
+                                    // argument declaration.
                 },
                 kind: StatementKind::Validate(ValidationOp::Acquire,
                     // Skip return value, go over all the arguments
                     mir.local_decls.iter_enumerated().skip(1).take(mir.arg_count)
-                    .map(|(local, local_decl)| (local_decl.ty, Lvalue::Local(local))).collect()
+                    .map(|(local, _)| lval_to_operand(Lvalue::Local(local))).collect()
                 )
             };
             mir.basic_blocks_mut()[START_BLOCK].statements.insert(0, acquire_stmt);
         }
 
         // PART 2
-        // Add ReleaseValid/AcquireValid around function call terminators.  We don't use a visitor because
-        // we need to access the block that a Call jumps to.
-        let mut returns : Vec<(SourceInfo, Lvalue<'tcx>, BasicBlock)> = Vec::new(); // Here we collect the destinations.
-        let local_decls = mir.local_decls.clone(); // TODO: Find a way to get rid of this clone.
+        // Add ReleaseValid/AcquireValid around function call terminators.  We don't use a visitor
+        // because we need to access the block that a Call jumps to.
+        let mut returns : Vec<(SourceInfo, Lvalue<'tcx>, BasicBlock)> = Vec::new();
         for block_data in mir.basic_blocks_mut() {
             match block_data.terminator {
-                Some(Terminator { kind: TerminatorKind::Call { ref args, ref destination, .. }, source_info }) => {
+                Some(Terminator { kind: TerminatorKind::Call { ref args, ref destination, .. },
+                                  source_info }) => {
                     // Before the call: Release all arguments
                     let release_stmt = Statement {
                         source_info,
                         kind: StatementKind::Validate(ValidationOp::Release,
                             args.iter().filter_map(|op| {
                                 match op {
-                                    &Operand::Consume(ref lval) => {
-                                        let ty = lval.ty(&local_decls, tcx).to_ty(tcx);
-                                        Some((ty, lval.clone()))
-                                    },
+                                    &Operand::Consume(ref lval) =>
+                                        Some(lval_to_operand(lval.clone())),
                                     &Operand::Constant(..) => { None },
                                 }
                             }).collect())
@@ -97,13 +138,15 @@ fn run_pass<'a, 'tcx>(&self,
                         returns.push((source_info, destination.0.clone(), destination.1));
                     }
                 }
-                Some(Terminator { kind: TerminatorKind::Drop { location: ref lval, .. }, source_info }) |
-                Some(Terminator { kind: TerminatorKind::DropAndReplace { location: ref lval, .. }, source_info }) => {
+                Some(Terminator { kind: TerminatorKind::Drop { location: ref lval, .. },
+                                  source_info }) |
+                Some(Terminator { kind: TerminatorKind::DropAndReplace { location: ref lval, .. },
+                                  source_info }) => {
                     // Before the call: Release all arguments
-                    let ty = lval.ty(&local_decls, tcx).to_ty(tcx);
                     let release_stmt = Statement {
                         source_info,
-                        kind: StatementKind::Validate(ValidationOp::Release, vec![(ty, lval.clone())])
+                        kind: StatementKind::Validate(ValidationOp::Release,
+                                vec![lval_to_operand(lval.clone())]),
                     };
                     block_data.statements.push(release_stmt);
                     // drop doesn't return anything, so we need no acquire.
@@ -115,20 +158,20 @@ fn run_pass<'a, 'tcx>(&self,
         }
         // Now we go over the returns we collected to acquire the return values.
         for (source_info, dest_lval, dest_block) in returns {
-            let ty = dest_lval.ty(&local_decls, tcx).to_ty(tcx);
             let acquire_stmt = Statement {
                 source_info,
-                kind: StatementKind::Validate(ValidationOp::Acquire, vec![(ty, dest_lval)])
+                kind: StatementKind::Validate(ValidationOp::Acquire,
+                        vec![lval_to_operand(dest_lval)]),
             };
             mir.basic_blocks_mut()[dest_block].statements.insert(0, acquire_stmt);
         }
 
         // PART 3
-        // Add ReleaseValid/AcquireValid around Ref.  Again an iterator does not seem very suited as
-        // we need to add new statements before and after each Ref.
+        // Add ReleaseValid/AcquireValid around Ref.  Again an iterator does not seem very suited
+        // as we need to add new statements before and after each Ref.
         for block_data in mir.basic_blocks_mut() {
-            // We want to insert statements around Ref commands as we iterate.  To this end, we iterate backwards
-            // using indices.
+            // We want to insert statements around Ref commands as we iterate.  To this end, we
+            // iterate backwards using indices.
             for i in (0..block_data.statements.len()).rev() {
                 let (dest_lval, re, src_lval) = match block_data.statements[i].kind {
                     StatementKind::Assign(ref dest_lval, Rvalue::Ref(re, _, ref src_lval)) => {
@@ -137,27 +180,25 @@ fn run_pass<'a, 'tcx>(&self,
                     _ => continue,
                 };
                 // So this is a ref, and we got all the data we wanted.
-                let dest_ty = dest_lval.ty(&local_decls, tcx).to_ty(tcx);
                 let acquire_stmt = Statement {
                     source_info: block_data.statements[i].source_info,
-                    kind: StatementKind::Validate(ValidationOp::Acquire, vec![(dest_ty, dest_lval)]),
+                    kind: StatementKind::Validate(ValidationOp::Acquire,
+                            vec![lval_to_operand(dest_lval)]),
                 };
                 block_data.statements.insert(i+1, acquire_stmt);
 
-                // The source is released until the region of the borrow ends -- but not if it is shared.
-                if !is_lvalue_shared(&src_lval, &local_decls, tcx) {
-                    let src_ty = src_lval.ty(&local_decls, tcx).to_ty(tcx);
-                    let op = match re {
-                        &RegionKind::ReScope(ce) => ValidationOp::Suspend(ce),
-                        &RegionKind::ReErased => bug!("AddValidation pass must be run before erasing lifetimes"),
-                        _ => ValidationOp::Release,
-                    };
-                    let release_stmt = Statement {
-                        source_info: block_data.statements[i].source_info,
-                        kind: StatementKind::Validate(op, vec![(src_ty, src_lval)]),
-                    };
-                    block_data.statements.insert(i, release_stmt);
-                }
+                // The source is released until the region of the borrow ends.
+                let op = match re {
+                    &RegionKind::ReScope(ce) => ValidationOp::Suspend(ce),
+                    &RegionKind::ReErased =>
+                        bug!("AddValidation pass must be run before erasing lifetimes"),
+                    _ => ValidationOp::Release,
+                };
+                let release_stmt = Statement {
+                    source_info: block_data.statements[i].source_info,
+                    kind: StatementKind::Validate(op, vec![lval_to_operand(src_lval)]),
+                };
+                block_data.statements.insert(i, release_stmt);
             }
         }
     }