]> git.lizzy.rs Git - rust.git/blobdiff - src/librustc/ty/fold.rs
Rollup merge of #61420 - felixrabe:patch-2, r=dtolnay
[rust.git] / src / librustc / ty / fold.rs
index 7f58a436fcf0d678367ad01427e60706717e271f..dae28d51efc2ead309d4765d76a941708e3abb73 100644 (file)
@@ -32,6 +32,7 @@
 //! looking for, and does not need to visit anything else.
 
 use crate::hir::def_id::DefId;
+use crate::mir::interpret::ConstValue;
 use crate::ty::{self, Binder, Ty, TyCtxt, TypeFlags, flags::FlagComputation};
 
 use std::collections::BTreeMap;
@@ -96,7 +97,11 @@ fn needs_infer(&self) -> bool {
         )
     }
     fn has_placeholders(&self) -> bool {
-        self.has_type_flags(TypeFlags::HAS_RE_PLACEHOLDER | TypeFlags::HAS_TY_PLACEHOLDER)
+        self.has_type_flags(
+            TypeFlags::HAS_RE_PLACEHOLDER |
+            TypeFlags::HAS_TY_PLACEHOLDER |
+            TypeFlags::HAS_CT_PLACEHOLDER
+        )
     }
     fn needs_subst(&self) -> bool {
         self.has_type_flags(TypeFlags::NEEDS_SUBST)
@@ -193,29 +198,37 @@ fn visit_const(&mut self, c: &'tcx ty::Const<'tcx>) -> bool {
 ///////////////////////////////////////////////////////////////////////////
 // Some sample folders
 
-pub struct BottomUpFolder<'a, 'gcx: 'a+'tcx, 'tcx: 'a, F, G>
+pub struct BottomUpFolder<'a, 'gcx: 'a+'tcx, 'tcx: 'a, F, G, H>
     where F: FnMut(Ty<'tcx>) -> Ty<'tcx>,
           G: FnMut(ty::Region<'tcx>) -> ty::Region<'tcx>,
+          H: FnMut(&'tcx ty::Const<'tcx>) -> &'tcx ty::Const<'tcx>,
 {
     pub tcx: TyCtxt<'a, 'gcx, 'tcx>,
-    pub fldop: F,
-    pub reg_op: G,
+    pub ty_op: F,
+    pub lt_op: G,
+    pub ct_op: H,
 }
 
-impl<'a, 'gcx, 'tcx, F, G> TypeFolder<'gcx, 'tcx> for BottomUpFolder<'a, 'gcx, 'tcx, F, G>
+impl<'a, 'gcx, 'tcx, F, G, H> TypeFolder<'gcx, 'tcx> for BottomUpFolder<'a, 'gcx, 'tcx, F, G, H>
     where F: FnMut(Ty<'tcx>) -> Ty<'tcx>,
           G: FnMut(ty::Region<'tcx>) -> ty::Region<'tcx>,
+          H: FnMut(&'tcx ty::Const<'tcx>) -> &'tcx ty::Const<'tcx>,
 {
     fn tcx<'b>(&'b self) -> TyCtxt<'b, 'gcx, 'tcx> { self.tcx }
 
     fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
-        let t1 = ty.super_fold_with(self);
-        (self.fldop)(t1)
+        let t = ty.super_fold_with(self);
+        (self.ty_op)(t)
     }
 
     fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
         let r = r.super_fold_with(self);
-        (self.reg_op)(r)
+        (self.lt_op)(r)
+    }
+
+    fn fold_const(&mut self, ct: &'tcx ty::Const<'tcx>) -> &'tcx ty::Const<'tcx> {
+        let ct = ct.super_fold_with(self);
+        (self.ct_op)(ct)
     }
 }
 
@@ -422,22 +435,26 @@ struct BoundVarReplacer<'a, 'gcx: 'a + 'tcx, 'tcx: 'a> {
 
     fld_r: &'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a),
     fld_t: &'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a),
+    fld_c: &'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> &'tcx ty::Const<'tcx> + 'a),
 }
 
 impl<'a, 'gcx, 'tcx> BoundVarReplacer<'a, 'gcx, 'tcx> {
-    fn new<F, G>(
+    fn new<F, G, H>(
         tcx: TyCtxt<'a, 'gcx, 'tcx>,
         fld_r: &'a mut F,
-        fld_t: &'a mut G
+        fld_t: &'a mut G,
+        fld_c: &'a mut H,
     ) -> Self
         where F: FnMut(ty::BoundRegion) -> ty::Region<'tcx>,
-              G: FnMut(ty::BoundTy) -> Ty<'tcx>
+              G: FnMut(ty::BoundTy) -> Ty<'tcx>,
+              H: FnMut(ty::BoundVar, Ty<'tcx>) -> &'tcx ty::Const<'tcx>,
     {
         BoundVarReplacer {
             tcx,
             current_index: ty::INNERMOST,
             fld_r,
             fld_t,
+            fld_c,
         }
     }
 }
@@ -497,6 +514,32 @@ fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
             _ => r
         }
     }
+
+    fn fold_const(&mut self, ct: &'tcx ty::Const<'tcx>) -> &'tcx ty::Const<'tcx> {
+        if let ty::Const {
+            val: ConstValue::Infer(ty::InferConst::Canonical(debruijn, bound_const)),
+            ty,
+        } = *ct {
+            if debruijn == self.current_index {
+                let fld_c = &mut self.fld_c;
+                let ct = fld_c(bound_const, ty);
+                ty::fold::shift_vars(
+                    self.tcx,
+                    &ct,
+                    self.current_index.as_u32()
+                )
+            } else {
+                ct
+            }
+        } else {
+            if !ct.has_vars_bound_at_or_above(self.current_index) {
+                // Nothing more to substitute.
+                ct
+            } else {
+                ct.super_fold_with(self)
+            }
+        }
+    }
 }
 
 impl<'a, 'gcx, 'tcx> TyCtxt<'a, 'gcx, 'tcx> {
@@ -519,27 +562,34 @@ pub fn replace_late_bound_regions<T, F>(
         where F: FnMut(ty::BoundRegion) -> ty::Region<'tcx>,
               T: TypeFoldable<'tcx>
     {
-        // identity for bound types
+        // identity for bound types and consts
         let fld_t = |bound_ty| self.mk_ty(ty::Bound(ty::INNERMOST, bound_ty));
-        self.replace_escaping_bound_vars(value.skip_binder(), fld_r, fld_t)
+        let fld_c = |bound_ct, ty| {
+            self.mk_const_infer(ty::InferConst::Canonical(ty::INNERMOST, bound_ct), ty)
+        };
+        self.replace_escaping_bound_vars(value.skip_binder(), fld_r, fld_t, fld_c)
     }
 
     /// Replaces all escaping bound vars. The `fld_r` closure replaces escaping
-    /// bound regions while the `fld_t` closure replaces escaping bound types.
-    pub fn replace_escaping_bound_vars<T, F, G>(
+    /// bound regions; the `fld_t` closure replaces escaping bound types and the `fld_c`
+    /// closure replaces escaping bound consts.
+    pub fn replace_escaping_bound_vars<T, F, G, H>(
         self,
         value: &T,
         mut fld_r: F,
-        mut fld_t: G
+        mut fld_t: G,
+        mut fld_c: H,
     ) -> (T, BTreeMap<ty::BoundRegion, ty::Region<'tcx>>)
         where F: FnMut(ty::BoundRegion) -> ty::Region<'tcx>,
               G: FnMut(ty::BoundTy) -> Ty<'tcx>,
-              T: TypeFoldable<'tcx>
+              H: FnMut(ty::BoundVar, Ty<'tcx>) -> &'tcx ty::Const<'tcx>,
+              T: TypeFoldable<'tcx>,
     {
         use rustc_data_structures::fx::FxHashMap;
 
         let mut region_map = BTreeMap::new();
         let mut type_map = FxHashMap::default();
+        let mut const_map = FxHashMap::default();
 
         if !value.has_escaping_bound_vars() {
             (value.clone(), region_map)
@@ -552,7 +602,16 @@ pub fn replace_escaping_bound_vars<T, F, G>(
                 *type_map.entry(bound_ty).or_insert_with(|| fld_t(bound_ty))
             };
 
-            let mut replacer = BoundVarReplacer::new(self, &mut real_fld_r, &mut real_fld_t);
+            let mut real_fld_c = |bound_ct, ty| {
+                *const_map.entry(bound_ct).or_insert_with(|| fld_c(bound_ct, ty))
+            };
+
+            let mut replacer = BoundVarReplacer::new(
+                self,
+                &mut real_fld_r,
+                &mut real_fld_t,
+                &mut real_fld_c,
+            );
             let result = value.fold_with(&mut replacer);
             (result, region_map)
         }
@@ -561,17 +620,19 @@ pub fn replace_escaping_bound_vars<T, F, G>(
     /// Replaces all types or regions bound by the given `Binder`. The `fld_r`
     /// closure replaces bound regions while the `fld_t` closure replaces bound
     /// types.
-    pub fn replace_bound_vars<T, F, G>(
+    pub fn replace_bound_vars<T, F, G, H>(
         self,
         value: &Binder<T>,
         fld_r: F,
-        fld_t: G
+        fld_t: G,
+        fld_c: H,
     ) -> (T, BTreeMap<ty::BoundRegion, ty::Region<'tcx>>)
         where F: FnMut(ty::BoundRegion) -> ty::Region<'tcx>,
               G: FnMut(ty::BoundTy) -> Ty<'tcx>,
+              H: FnMut(ty::BoundVar, Ty<'tcx>) -> &'tcx ty::Const<'tcx>,
               T: TypeFoldable<'tcx>
     {
-        self.replace_escaping_bound_vars(value.skip_binder(), fld_r, fld_t)
+        self.replace_escaping_bound_vars(value.skip_binder(), fld_r, fld_t, fld_c)
     }
 
     /// Replaces any late-bound regions bound in `value` with
@@ -732,6 +793,28 @@ fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
             _ => ty.super_fold_with(self),
         }
     }
+
+    fn fold_const(&mut self, ct: &'tcx ty::Const<'tcx>) -> &'tcx ty::Const<'tcx> {
+        if let ty::Const {
+            val: ConstValue::Infer(ty::InferConst::Canonical(debruijn, bound_const)),
+            ty,
+        } = *ct {
+            if self.amount == 0 || debruijn < self.current_index {
+                ct
+            } else {
+                let debruijn = match self.direction {
+                    Direction::In => debruijn.shifted_in(self.amount),
+                    Direction::Out => {
+                        assert!(debruijn.as_u32() >= self.amount);
+                        debruijn.shifted_out(self.amount)
+                    }
+                };
+                self.tcx.mk_const_infer(ty::InferConst::Canonical(debruijn, bound_const), ty)
+            }
+        } else {
+            ct.super_fold_with(self)
+        }
+    }
 }
 
 pub fn shift_region<'a, 'gcx, 'tcx>(
@@ -824,6 +907,17 @@ fn visit_region(&mut self, r: ty::Region<'tcx>) -> bool {
         // visited.
         r.bound_at_or_above_binder(self.outer_index)
     }
+
+    fn visit_const(&mut self, ct: &'tcx ty::Const<'tcx>) -> bool {
+        if let ty::Const {
+            val: ConstValue::Infer(ty::InferConst::Canonical(debruijn, _)),
+            ..
+        } = *ct {
+            debruijn >= self.outer_index
+        } else {
+            false
+        }
+    }
 }
 
 struct HasTypeFlagsVisitor {
@@ -845,7 +939,7 @@ fn visit_region(&mut self, r: ty::Region<'tcx>) -> bool {
     fn visit_const(&mut self, c: &'tcx ty::Const<'tcx>) -> bool {
         let flags = FlagComputation::for_const(c);
         debug!("HasTypeFlagsVisitor: c={:?} c.flags={:?} self.flags={:?}", c, flags, self.flags);
-        flags.intersects(self.flags) || c.super_visit_with(self)
+        flags.intersects(self.flags)
     }
 }