]> git.lizzy.rs Git - rust.git/blobdiff - src/librustc_mir/transform/qualify_min_const_fn.rs
Make check_place iterate instead of recurse
[rust.git] / src / librustc_mir / transform / qualify_min_const_fn.rs
index 7826d3da4fed2a2e79587138823705f1d983998f..a1e2d0683d3800a37f3184e06299d15e80ad0d7c 100644 (file)
@@ -1,7 +1,7 @@
 use rustc::hir::def_id::DefId;
 use rustc::hir;
 use rustc::mir::*;
-use rustc::ty::{self, Predicate, TyCtxt};
+use rustc::ty::{self, Predicate, Ty, TyCtxt, adjustment::{PointerCast}};
 use rustc_target::spec::abi;
 use std::borrow::Cow;
 use syntax_pos::Span;
@@ -60,13 +60,14 @@ pub fn is_min_const_fn(
     }
 
     for local in &mir.local_decls {
-        check_ty(tcx, local.ty, local.source_info.span)?;
+        check_ty(tcx, local.ty, local.source_info.span, def_id)?;
     }
     // impl trait is gone in MIR, so check the return type manually
     check_ty(
         tcx,
         tcx.fn_sig(def_id).output().skip_binder(),
         mir.local_decls.iter().next().unwrap().source_info.span,
+        def_id,
     )?;
 
     for bb in mir.basic_blocks() {
@@ -80,8 +81,9 @@ pub fn is_min_const_fn(
 
 fn check_ty(
     tcx: TyCtxt<'a, 'tcx, 'tcx>,
-    ty: ty::Ty<'tcx>,
+    ty: Ty<'tcx>,
     span: Span,
+    fn_def_id: DefId,
 ) -> McfResult {
     for ty in ty.walk() {
         match ty.sty {
@@ -91,7 +93,9 @@ fn check_ty(
             )),
             ty::Opaque(..) => return Err((span, "`impl Trait` in const fn is unstable".into())),
             ty::FnPtr(..) => {
-                return Err((span, "function pointers in const fn are unstable".into()))
+                if !tcx.const_fn_is_allowed_fn_ptr(fn_def_id) {
+                    return Err((span, "function pointers in const fn are unstable".into()))
+                }
             }
             ty::Dynamic(preds, _) => {
                 for pred in preds.iter() {
@@ -132,10 +136,10 @@ fn check_rvalue(
 ) -> McfResult {
     match rvalue {
         Rvalue::Repeat(operand, _) | Rvalue::Use(operand) => {
-            check_operand(tcx, mir, operand, span)
+            check_operand(operand, span)
         }
         Rvalue::Len(place) | Rvalue::Discriminant(place) | Rvalue::Ref(_, _, place) => {
-            check_place(tcx, mir, place, span)
+            check_place(place, span)
         }
         Rvalue::Cast(CastKind::Misc, operand, cast_ty) => {
             use rustc::ty::cast::CastTy;
@@ -149,26 +153,26 @@ fn check_rvalue(
                 (CastTy::RPtr(_), CastTy::Float) => bug!(),
                 (CastTy::RPtr(_), CastTy::Int(_)) => bug!(),
                 (CastTy::Ptr(_), CastTy::RPtr(_)) => bug!(),
-                _ => check_operand(tcx, mir, operand, span),
+                _ => check_operand(operand, span),
             }
         }
-        Rvalue::Cast(CastKind::MutToConstPointer, operand, _) => {
-            check_operand(tcx, mir, operand, span)
+        Rvalue::Cast(CastKind::Pointer(PointerCast::MutToConstPointer), operand, _) => {
+            check_operand(operand, span)
         }
-        Rvalue::Cast(CastKind::UnsafeFnPointer, _, _) |
-        Rvalue::Cast(CastKind::ClosureFnPointer(_), _, _) |
-        Rvalue::Cast(CastKind::ReifyFnPointer, _, _) => Err((
+        Rvalue::Cast(CastKind::Pointer(PointerCast::UnsafeFnPointer), _, _) |
+        Rvalue::Cast(CastKind::Pointer(PointerCast::ClosureFnPointer(_)), _, _) |
+        Rvalue::Cast(CastKind::Pointer(PointerCast::ReifyFnPointer), _, _) => Err((
             span,
             "function pointer casts are not allowed in const fn".into(),
         )),
-        Rvalue::Cast(CastKind::Unsize, _, _) => Err((
+        Rvalue::Cast(CastKind::Pointer(PointerCast::Unsize), _, _) => Err((
             span,
             "unsizing casts are not allowed in const fn".into(),
         )),
         // binops are fine on integers
         Rvalue::BinaryOp(_, lhs, rhs) | Rvalue::CheckedBinaryOp(_, lhs, rhs) => {
-            check_operand(tcx, mir, lhs, span)?;
-            check_operand(tcx, mir, rhs, span)?;
+            check_operand(lhs, span)?;
+            check_operand(rhs, span)?;
             let ty = lhs.ty(mir, tcx);
             if ty.is_integral() || ty.is_bool() || ty.is_char() {
                 Ok(())
@@ -187,7 +191,7 @@ fn check_rvalue(
         Rvalue::UnaryOp(_, operand) => {
             let ty = operand.ty(mir, tcx);
             if ty.is_integral() || ty.is_bool() {
-                check_operand(tcx, mir, operand, span)
+                check_operand(operand, span)
             } else {
                 Err((
                     span,
@@ -197,7 +201,7 @@ fn check_rvalue(
         }
         Rvalue::Aggregate(_, operands) => {
             for operand in operands {
-                check_operand(tcx, mir, operand, span)?;
+                check_operand(operand, span)?;
             }
             Ok(())
         }
@@ -212,11 +216,11 @@ fn check_statement(
     let span = statement.source_info.span;
     match &statement.kind {
         StatementKind::Assign(place, rval) => {
-            check_place(tcx, mir, place, span)?;
+            check_place(place, span)?;
             check_rvalue(tcx, mir, rval, span)
         }
 
-        StatementKind::FakeRead(_, place) => check_place(tcx, mir, place, span),
+        StatementKind::FakeRead(_, place) => check_place(place, span),
 
         // just an assignment
         StatementKind::SetDiscriminant { .. } => Ok(()),
@@ -235,43 +239,40 @@ fn check_statement(
 }
 
 fn check_operand(
-    tcx: TyCtxt<'a, 'tcx, 'tcx>,
-    mir: &'a Mir<'tcx>,
     operand: &Operand<'tcx>,
     span: Span,
 ) -> McfResult {
     match operand {
         Operand::Move(place) | Operand::Copy(place) => {
-            check_place(tcx, mir, place, span)
+            check_place(place, span)
         }
         Operand::Constant(_) => Ok(()),
     }
 }
 
-fn check_place(
-    tcx: TyCtxt<'a, 'tcx, 'tcx>,
-    mir: &'a Mir<'tcx>,
-    place: &Place<'tcx>,
-    span: Span,
-) -> McfResult {
-    match place {
-        Place::Base(PlaceBase::Local(_)) => Ok(()),
-        // promoteds are always fine, they are essentially constants
-        Place::Base(PlaceBase::Static(box Static { kind: StaticKind::Promoted(_), .. })) => Ok(()),
-        Place::Base(PlaceBase::Static(box Static { kind: StaticKind::Static(_), .. })) =>
-            Err((span, "cannot access `static` items in const fn".into())),
-        Place::Projection(proj) => {
+fn check_place(place: &Place<'tcx>, span: Span) -> McfResult {
+    place.iterate(|place_base, place_projection| {
+        for proj in place_projection {
             match proj.elem {
-                | ProjectionElem::ConstantIndex { .. } | ProjectionElem::Subslice { .. }
-                | ProjectionElem::Deref | ProjectionElem::Field(..) | ProjectionElem::Index(_) => {
-                    check_place(tcx, mir, &proj.base, span)
-                }
-                | ProjectionElem::Downcast(..) => {
-                    Err((span, "`match` or `if let` in `const fn` is unstable".into()))
+                ProjectionElem::Downcast(..) => {
+                    return Err((span, "`match` or `if let` in `const fn` is unstable".into()));
                 }
+                ProjectionElem::ConstantIndex { .. }
+                | ProjectionElem::Subslice { .. }
+                | ProjectionElem::Deref
+                | ProjectionElem::Field(..)
+                | ProjectionElem::Index(_) => {}
             }
         }
-    }
+
+        match place_base {
+            PlaceBase::Static(box Static { kind: StaticKind::Static(_), .. }) => {
+                Err((span, "cannot access `static` items in const fn".into()))
+            }
+            PlaceBase::Local(_)
+            | PlaceBase::Static(box Static { kind: StaticKind::Promoted(_), .. }) => Ok(()),
+        }
+    })
 }
 
 fn check_terminator(
@@ -286,11 +287,11 @@ fn check_terminator(
         | TerminatorKind::Resume => Ok(()),
 
         TerminatorKind::Drop { location, .. } => {
-            check_place(tcx, mir, location, span)
+            check_place(location, span)
         }
         TerminatorKind::DropAndReplace { location, value, .. } => {
-            check_place(tcx, mir, location, span)?;
-            check_operand(tcx, mir, value, span)
+            check_place(location, span)?;
+            check_operand(value, span)
         },
 
         TerminatorKind::FalseEdges { .. } | TerminatorKind::SwitchInt { .. } => Err((
@@ -342,10 +343,10 @@ fn check_terminator(
                     )),
                 }
 
-                check_operand(tcx, mir, func, span)?;
+                check_operand(func, span)?;
 
                 for arg in args {
-                    check_operand(tcx, mir, arg, span)?;
+                    check_operand(arg, span)?;
                 }
                 Ok(())
             } else {
@@ -359,7 +360,7 @@ fn check_terminator(
             msg: _,
             target: _,
             cleanup: _,
-        } => check_operand(tcx, mir, cond, span),
+        } => check_operand(cond, span),
 
         TerminatorKind::FalseUnwind { .. } => {
             Err((span, "loops are not allowed in const fn".into()))