From d327fa112b8ca56e8c310a8ec9bf458909beacfe Mon Sep 17 00:00:00 2001 From: Bastian Kauschke Date: Thu, 10 Sep 2020 09:06:30 +0200 Subject: [PATCH] initial working state --- .../rustc_middle/src/mir/abstract_const.rs | 15 + compiler/rustc_middle/src/mir/mod.rs | 1 + compiler/rustc_middle/src/query/mod.rs | 19 ++ compiler/rustc_mir/src/transform/mod.rs | 6 +- compiler/rustc_trait_selection/src/lib.rs | 1 + .../src/traits/const_evaluatable.rs | 256 +++++++++++++++++- .../rustc_trait_selection/src/traits/mod.rs | 14 + .../const_evaluatable_checked/less_than.rs | 14 + 8 files changed, 313 insertions(+), 13 deletions(-) create mode 100644 compiler/rustc_middle/src/mir/abstract_const.rs create mode 100644 src/test/ui/const-generics/const_evaluatable_checked/less_than.rs diff --git a/compiler/rustc_middle/src/mir/abstract_const.rs b/compiler/rustc_middle/src/mir/abstract_const.rs new file mode 100644 index 00000000000..48fe8dafd53 --- /dev/null +++ b/compiler/rustc_middle/src/mir/abstract_const.rs @@ -0,0 +1,15 @@ +//! A subset of a mir body used for const evaluatability checking. +use crate::mir; +use crate::ty; + +/// An index into an `AbstractConst`. +pub type NodeId = usize; + +/// A node of an `AbstractConst`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, HashStable)] +pub enum Node<'tcx> { + Leaf(&'tcx ty::Const<'tcx>), + Binop(mir::BinOp, NodeId, NodeId), + UnaryOp(mir::UnOp, NodeId), + FunctionCall(NodeId, &'tcx [NodeId]), +} diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index 29daf7e9309..be61b676807 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -40,6 +40,7 @@ use self::predecessors::{PredecessorCache, Predecessors}; pub use self::query::*; +pub mod abstract_const; pub mod coverage; pub mod interpret; pub mod mono; diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 15e3110bc85..41b8bb60ef5 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -244,6 +244,25 @@ fn describe_as_module(def_id: LocalDefId, tcx: TyCtxt<'_>) -> String { no_hash } + /// Try to build an abstract representation of the given constant. + query mir_abstract_const( + key: DefId + ) -> Option<&'tcx [mir::abstract_const::Node<'tcx>]> { + desc { + |tcx| "building an abstract representation for {}", tcx.def_path_str(key), + } + } + /// Try to build an abstract representation of the given constant. + query mir_abstract_const_of_const_arg( + key: (LocalDefId, DefId) + ) -> Option<&'tcx [mir::abstract_const::Node<'tcx>]> { + desc { + |tcx| + "building an abstract representation for the const argument {}", + tcx.def_path_str(key.0.to_def_id()), + } + } + query mir_drops_elaborated_and_const_checked( key: ty::WithOptConstParam ) -> &'tcx Steal> { diff --git a/compiler/rustc_mir/src/transform/mod.rs b/compiler/rustc_mir/src/transform/mod.rs index 8025b7c0204..226282fe426 100644 --- a/compiler/rustc_mir/src/transform/mod.rs +++ b/compiler/rustc_mir/src/transform/mod.rs @@ -329,7 +329,11 @@ fn mir_promoted( // this point, before we steal the mir-const result. // Also this means promotion can rely on all const checks having been done. let _ = tcx.mir_const_qualif_opt_const_arg(def); - + let _ = if let Some(param_did) = def.const_param_did { + tcx.mir_abstract_const_of_const_arg((def.did, param_did)) + } else { + tcx.mir_abstract_const(def.did.to_def_id()) + }; let mut body = tcx.mir_const(def).steal(); let mut required_consts = Vec::new(); diff --git a/compiler/rustc_trait_selection/src/lib.rs b/compiler/rustc_trait_selection/src/lib.rs index b5882df4729..da1996b92a6 100644 --- a/compiler/rustc_trait_selection/src/lib.rs +++ b/compiler/rustc_trait_selection/src/lib.rs @@ -12,6 +12,7 @@ #![doc(html_root_url = "https://doc.rust-lang.org/nightly/")] #![feature(bool_to_option)] +#![feature(box_patterns)] #![feature(drain_filter)] #![feature(in_band_lifetimes)] #![feature(crate_visibility_modifier)] diff --git a/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs b/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs index fdb87c085b5..9d74de44d17 100644 --- a/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs +++ b/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs @@ -1,10 +1,17 @@ +#![allow(warnings)] use rustc_hir::def::DefKind; +use rustc_index::bit_set::BitSet; +use rustc_index::vec::IndexVec; use rustc_infer::infer::InferCtxt; +use rustc_middle::mir::abstract_const::{Node, NodeId}; use rustc_middle::mir::interpret::ErrorHandled; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::{self, Rvalue, StatementKind, TerminatorKind}; +use rustc_middle::ty::subst::Subst; use rustc_middle::ty::subst::SubstsRef; -use rustc_middle::ty::{self, TypeFoldable}; +use rustc_middle::ty::{self, TyCtxt, TypeFoldable}; use rustc_session::lint; -use rustc_span::def_id::DefId; +use rustc_span::def_id::{DefId, LocalDefId}; use rustc_span::Span; pub fn is_const_evaluatable<'cx, 'tcx>( @@ -16,18 +23,23 @@ pub fn is_const_evaluatable<'cx, 'tcx>( ) -> Result<(), ErrorHandled> { debug!("is_const_evaluatable({:?}, {:?})", def, substs); if infcx.tcx.features().const_evaluatable_checked { - // FIXME(const_evaluatable_checked): Actually look into generic constants to - // implement const equality. - for pred in param_env.caller_bounds() { - match pred.skip_binders() { - ty::PredicateAtom::ConstEvaluatable(b_def, b_substs) => { - debug!("is_const_evaluatable: caller_bound={:?}, {:?}", b_def, b_substs); - if b_def == def && b_substs == substs { - debug!("is_const_evaluatable: caller_bound ~~> ok"); - return Ok(()); + if let Some(ct) = AbstractConst::new(infcx.tcx, def, substs) { + for pred in param_env.caller_bounds() { + match pred.skip_binders() { + ty::PredicateAtom::ConstEvaluatable(b_def, b_substs) => { + debug!("is_const_evaluatable: caller_bound={:?}, {:?}", b_def, b_substs); + if b_def == def && b_substs == substs { + debug!("is_const_evaluatable: caller_bound ~~> ok"); + return Ok(()); + } else if AbstractConst::new(infcx.tcx, b_def, b_substs) + .map_or(false, |b_ct| try_unify(infcx.tcx, ct, b_ct)) + { + debug!("is_const_evaluatable: abstract_const ~~> ok"); + return Ok(()); + } } + _ => {} // don't care } - _ => {} // don't care } } } @@ -76,3 +88,223 @@ pub fn is_const_evaluatable<'cx, 'tcx>( debug!(?concrete, "is_const_evaluatable"); concrete.map(drop) } + +/// A tree representing an anonymous constant. +/// +/// This is only able to represent a subset of `MIR`, +/// and should not leak any information about desugarings. +#[derive(Clone, Copy)] +pub struct AbstractConst<'tcx> { + pub inner: &'tcx [Node<'tcx>], + pub substs: SubstsRef<'tcx>, +} + +impl AbstractConst<'tcx> { + pub fn new( + tcx: TyCtxt<'tcx>, + def: ty::WithOptConstParam, + substs: SubstsRef<'tcx>, + ) -> Option> { + let inner = match (def.did.as_local(), def.const_param_did) { + (Some(did), Some(param_did)) => { + tcx.mir_abstract_const_of_const_arg((did, param_did))? + } + _ => tcx.mir_abstract_const(def.did)?, + }; + + Some(AbstractConst { inner, substs }) + } + + #[inline] + pub fn subtree(self, node: NodeId) -> AbstractConst<'tcx> { + AbstractConst { inner: &self.inner[..=node], substs: self.substs } + } + + #[inline] + pub fn root(self) -> Node<'tcx> { + self.inner.last().copied().unwrap() + } +} + +struct AbstractConstBuilder<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + body: &'a mir::Body<'tcx>, + nodes: Vec>, + locals: IndexVec, + checked_op_locals: BitSet, +} + +impl<'a, 'tcx> AbstractConstBuilder<'a, 'tcx> { + fn new(tcx: TyCtxt<'tcx>, body: &'a mir::Body<'tcx>) -> Option> { + if body.is_cfg_cyclic() { + return None; + } + + Some(AbstractConstBuilder { + tcx, + body, + nodes: vec![], + locals: IndexVec::from_elem(NodeId::MAX, &body.local_decls), + checked_op_locals: BitSet::new_empty(body.local_decls.len()), + }) + } + + fn add_node(&mut self, n: Node<'tcx>) -> NodeId { + let len = self.nodes.len(); + self.nodes.push(n); + len + } + + fn operand_to_node(&mut self, op: &mir::Operand<'tcx>) -> Option { + debug!("operand_to_node: op={:?}", op); + const ZERO_FIELD: mir::Field = mir::Field::from_usize(0); + match op { + mir::Operand::Copy(p) | mir::Operand::Move(p) => { + if let Some(p) = p.as_local() { + debug_assert!(!self.checked_op_locals.contains(p)); + Some(self.locals[p]) + } else if let &[mir::ProjectionElem::Field(ZERO_FIELD, _)] = p.projection.as_ref() { + // Only allow field accesses on the result of checked operations. + if self.checked_op_locals.contains(p.local) { + Some(self.locals[p.local]) + } else { + None + } + } else { + None + } + } + mir::Operand::Constant(ct) => Some(self.add_node(Node::Leaf(ct.literal))), + } + } + + fn check_binop(op: mir::BinOp) -> bool { + use mir::BinOp::*; + match op { + Add | Sub | Mul | Div | Rem | BitXor | BitAnd | BitOr | Shl | Shr | Eq | Lt | Le + | Ne | Ge | Gt => true, + Offset => false, + } + } + + fn build(mut self) -> Option<&'tcx [Node<'tcx>]> { + let mut block = &self.body.basic_blocks()[mir::START_BLOCK]; + loop { + debug!("AbstractConstBuilder: block={:?}", block); + for stmt in block.statements.iter() { + debug!("AbstractConstBuilder: stmt={:?}", stmt); + match stmt.kind { + StatementKind::Assign(box (ref place, ref rvalue)) => { + let local = place.as_local()?; + match *rvalue { + Rvalue::Use(ref operand) => { + self.locals[local] = self.operand_to_node(operand)?; + } + Rvalue::BinaryOp(op, ref lhs, ref rhs) if Self::check_binop(op) => { + let lhs = self.operand_to_node(lhs)?; + let rhs = self.operand_to_node(rhs)?; + self.locals[local] = self.add_node(Node::Binop(op, lhs, rhs)); + if op.is_checkable() { + bug!("unexpected unchecked checkable binary operation"); + } + } + Rvalue::CheckedBinaryOp(op, ref lhs, ref rhs) + if Self::check_binop(op) => + { + let lhs = self.operand_to_node(lhs)?; + let rhs = self.operand_to_node(rhs)?; + self.locals[local] = self.add_node(Node::Binop(op, lhs, rhs)); + self.checked_op_locals.insert(local); + } + _ => return None, + } + } + _ => return None, + } + } + + debug!("AbstractConstBuilder: terminator={:?}", block.terminator()); + match block.terminator().kind { + TerminatorKind::Goto { target } => { + block = &self.body.basic_blocks()[target]; + } + TerminatorKind::Return => { + warn!(?self.nodes); + return { Some(self.tcx.arena.alloc_from_iter(self.nodes)) }; + } + TerminatorKind::Assert { ref cond, expected: false, target, .. } => { + let p = match cond { + mir::Operand::Copy(p) | mir::Operand::Move(p) => p, + mir::Operand::Constant(_) => bug!("Unexpected assert"), + }; + + const ONE_FIELD: mir::Field = mir::Field::from_usize(1); + debug!("proj: {:?}", p.projection); + if let &[mir::ProjectionElem::Field(ONE_FIELD, _)] = p.projection.as_ref() { + // Only allow asserts checking the result of a checked operation. + if self.checked_op_locals.contains(p.local) { + block = &self.body.basic_blocks()[target]; + continue; + } + } + + return None; + } + _ => return None, + } + } + } +} + +/// Builds an abstract const, do not use this directly, but use `AbstractConst::new` instead. +pub(super) fn mir_abstract_const<'tcx>( + tcx: TyCtxt<'tcx>, + def: ty::WithOptConstParam, +) -> Option<&'tcx [Node<'tcx>]> { + if !tcx.features().const_evaluatable_checked { + None + } else { + let body = tcx.mir_const(def).borrow(); + AbstractConstBuilder::new(tcx, &body)?.build() + } +} + +pub fn try_unify<'tcx>(tcx: TyCtxt<'tcx>, a: AbstractConst<'tcx>, b: AbstractConst<'tcx>) -> bool { + match (a.root(), b.root()) { + (Node::Leaf(a_ct), Node::Leaf(b_ct)) => { + let a_ct = a_ct.subst(tcx, a.substs); + let b_ct = b_ct.subst(tcx, b.substs); + match (a_ct.val, b_ct.val) { + (ty::ConstKind::Param(a_param), ty::ConstKind::Param(b_param)) => { + a_param == b_param + } + (ty::ConstKind::Value(a_val), ty::ConstKind::Value(b_val)) => a_val == b_val, + // If we have `fn a() -> [u8; N + 1]` and `fn b() -> [u8; 1 + M]` + // we do not want to use `assert_eq!(a(), b())` to infer that `N` and `M` have to be `1`. This + // means that we can't do anything with inference variables here. + (ty::ConstKind::Infer(_), _) | (_, ty::ConstKind::Infer(_)) => false, + // FIXME(const_evaluatable_checked): We may want to either actually try + // to evaluate `a_ct` and `b_ct` if they are are fully concrete or something like + // this, for now we just return false here. + _ => false, + } + } + (Node::Binop(a_op, al, ar), Node::Binop(b_op, bl, br)) if a_op == b_op => { + try_unify(tcx, a.subtree(al), b.subtree(bl)) + && try_unify(tcx, a.subtree(ar), b.subtree(br)) + } + (Node::UnaryOp(a_op, av), Node::UnaryOp(b_op, bv)) if a_op == b_op => { + try_unify(tcx, a.subtree(av), b.subtree(bv)) + } + (Node::FunctionCall(a_f, a_args), Node::FunctionCall(b_f, b_args)) + if a_args.len() == b_args.len() => + { + try_unify(tcx, a.subtree(a_f), b.subtree(b_f)) + && a_args + .iter() + .zip(b_args) + .all(|(&an, &bn)| try_unify(tcx, a.subtree(an), b.subtree(bn))) + } + _ => false, + } +} diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs index b72e86a4cbd..2f0b66ec8c9 100644 --- a/compiler/rustc_trait_selection/src/traits/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/mod.rs @@ -552,6 +552,20 @@ pub fn provide(providers: &mut ty::query::Providers) { vtable_methods, type_implements_trait, subst_and_check_impossible_predicates, + mir_abstract_const: |tcx, def_id| { + let def_id = def_id.as_local()?; // We do not store failed AbstractConst's. + if let Some(def) = ty::WithOptConstParam::try_lookup(def_id, tcx) { + tcx.mir_abstract_const_of_const_arg(def) + } else { + const_evaluatable::mir_abstract_const(tcx, ty::WithOptConstParam::unknown(def_id)) + } + }, + mir_abstract_const_of_const_arg: |tcx, (did, param_did)| { + const_evaluatable::mir_abstract_const( + tcx, + ty::WithOptConstParam { did, const_param_did: Some(param_did) }, + ) + }, ..*providers }; } diff --git a/src/test/ui/const-generics/const_evaluatable_checked/less_than.rs b/src/test/ui/const-generics/const_evaluatable_checked/less_than.rs new file mode 100644 index 00000000000..907ea255abb --- /dev/null +++ b/src/test/ui/const-generics/const_evaluatable_checked/less_than.rs @@ -0,0 +1,14 @@ +// run-pass +#![feature(const_generics, const_evaluatable_checked)] +#![allow(incomplete_features)] + +struct Foo; + +fn test() -> Foo<{ N > 10 }> where Foo<{ N > 10 }>: Sized { + Foo +} + +fn main() { + let _: Foo = test::<12>(); + let _: Foo = test::<9>(); +} -- 2.44.0