]> git.lizzy.rs Git - rust.git/blobdiff - compiler/rustc_mir_transform/src/generator.rs
Rollup merge of #107125 - WaffleLapkin:expect_an_item_in_your_hir_by_the_next_morning...
[rust.git] / compiler / rustc_mir_transform / src / generator.rs
index 39c61a34afcbdab70fa189af8a82527e35993f7e..5624e312da1cbe8eeabe9cbc82d219ebea390f86 100644 (file)
@@ -54,7 +54,8 @@
 use crate::simplify;
 use crate::util::expand_aggregate;
 use crate::MirPass;
-use rustc_data_structures::fx::FxHashMap;
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_errors::pluralize;
 use rustc_hir as hir;
 use rustc_hir::lang_items::LangItem;
 use rustc_hir::GeneratorKind;
@@ -70,6 +71,9 @@
 };
 use rustc_mir_dataflow::storage::always_storage_live_locals;
 use rustc_mir_dataflow::{self, Analysis};
+use rustc_span::def_id::DefId;
+use rustc_span::symbol::sym;
+use rustc_span::Span;
 use rustc_target::abi::VariantIdx;
 use rustc_target::spec::PanicStrategy;
 use std::{iter, ops};
@@ -854,7 +858,7 @@ fn sanitize_witness<'tcx>(
     body: &Body<'tcx>,
     witness: Ty<'tcx>,
     upvars: Vec<Ty<'tcx>>,
-    saved_locals: &GeneratorSavedLocals,
+    layout: &GeneratorLayout<'tcx>,
 ) {
     let did = body.source.def_id();
     let param_env = tcx.param_env(did);
@@ -873,31 +877,36 @@ fn sanitize_witness<'tcx>(
         }
     };
 
-    for (local, decl) in body.local_decls.iter_enumerated() {
-        // Ignore locals which are internal or not saved between yields.
-        if !saved_locals.contains(local) || decl.internal {
+    let mut mismatches = Vec::new();
+    for fty in &layout.field_tys {
+        if fty.ignore_for_traits {
             continue;
         }
-        let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty);
+        let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty);
 
         // Sanity check that typeck knows about the type of locals which are
         // live across a suspension point
         if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) {
-            span_bug!(
-                body.span,
-                "Broken MIR: generator contains type {} in MIR, \
-                       but typeck only knows about {} and {:?}",
-                decl_ty,
-                allowed,
-                allowed_upvars
-            );
+            mismatches.push(decl_ty);
         }
     }
+
+    if !mismatches.is_empty() {
+        span_bug!(
+            body.span,
+            "Broken MIR: generator contains type {:?} in MIR, \
+                       but typeck only knows about {} and {:?}",
+            mismatches,
+            allowed,
+            allowed_upvars
+        );
+    }
 }
 
 fn compute_layout<'tcx>(
+    tcx: TyCtxt<'tcx>,
     liveness: LivenessInfo,
-    body: &mut Body<'tcx>,
+    body: &Body<'tcx>,
 ) -> (
     FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
     GeneratorLayout<'tcx>,
@@ -915,9 +924,33 @@ fn compute_layout<'tcx>(
     let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
     let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
     for (saved_local, local) in saved_locals.iter_enumerated() {
-        locals.push(local);
-        tys.push(body.local_decls[local].ty);
         debug!("generator saved local {:?} => {:?}", saved_local, local);
+
+        locals.push(local);
+        let decl = &body.local_decls[local];
+        debug!(?decl);
+
+        let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir {
+            match decl.local_info {
+                // Do not include raw pointers created from accessing `static` items, as those could
+                // well be re-created by another access to the same static.
+                Some(box LocalInfo::StaticRef { is_thread_local, .. }) => !is_thread_local,
+                // Fake borrows are only read by fake reads, so do not have any reality in
+                // post-analysis MIR.
+                Some(box LocalInfo::FakeBorrow) => true,
+                _ => false,
+            }
+        } else {
+            // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that
+            // MIR building may introduce. This leads to wrongly ignored types, but this is
+            // necessary for internal consistency and to avoid ICEs.
+            decl.internal
+        };
+        let decl =
+            GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
+        debug!(?decl);
+
+        tys.push(decl);
     }
 
     // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.
@@ -947,7 +980,7 @@ fn compute_layout<'tcx>(
             // just use the first one here. That's fine; fields do not move
             // around inside generators, so it doesn't matter which variant
             // index we access them by.
-            remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx));
+            remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx));
         }
         variant_fields.push(fields);
         variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
@@ -957,6 +990,7 @@ fn compute_layout<'tcx>(
 
     let layout =
         GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts };
+    debug!(?layout);
 
     (remap, layout, storage_liveness)
 }
@@ -1351,6 +1385,42 @@ fn create_cases<'tcx>(
         .collect()
 }
 
+#[instrument(level = "debug", skip(tcx), ret)]
+pub(crate) fn mir_generator_witnesses<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    def_id: DefId,
+) -> GeneratorLayout<'tcx> {
+    assert!(tcx.sess.opts.unstable_opts.drop_tracking_mir);
+    let def_id = def_id.expect_local();
+
+    let (body, _) = tcx.mir_promoted(ty::WithOptConstParam::unknown(def_id));
+    let body = body.borrow();
+    let body = &*body;
+
+    // The first argument is the generator type passed by value
+    let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
+
+    // Get the interior types and substs which typeck computed
+    let movable = match *gen_ty.kind() {
+        ty::Generator(_, _, movability) => movability == hir::Movability::Movable,
+        _ => span_bug!(body.span, "unexpected generator type {}", gen_ty),
+    };
+
+    // When first entering the generator, move the resume argument into its new local.
+    let always_live_locals = always_storage_live_locals(&body);
+
+    let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
+
+    // Extract locals which are live across suspension point into `layout`
+    // `remap` gives a mapping from local indices onto generator struct indices
+    // `storage_liveness` tells us which locals have live storage at suspension points
+    let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body);
+
+    check_suspend_tys(tcx, &generator_layout, &body);
+
+    generator_layout
+}
+
 impl<'tcx> MirPass<'tcx> for StateTransform {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         let Some(yield_ty) = body.yield_ty() else {
@@ -1363,14 +1433,14 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         // The first argument is the generator type passed by value
         let gen_ty = body.local_decls.raw[1].ty;
 
-        // Get the interior types and substs which typeck computed
-        let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() {
+        // Get the discriminant type and substs which typeck computed
+        let (discr_ty, upvars, interior, movable) = match *gen_ty.kind() {
             ty::Generator(_, substs, movability) => {
                 let substs = substs.as_generator();
                 (
-                    substs.upvar_tys().collect(),
-                    substs.witness(),
                     substs.discr_ty(tcx),
+                    substs.upvar_tys().collect::<Vec<_>>(),
+                    substs.witness(),
                     movability == hir::Movability::Movable,
                 )
             }
@@ -1434,8 +1504,6 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         let liveness_info =
             locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
 
-        sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals);
-
         if tcx.sess.opts.unstable_opts.validate_mir {
             let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias {
                 assigned_local: None,
@@ -1449,7 +1517,13 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         // Extract locals which are live across suspension point into `layout`
         // `remap` gives a mapping from local indices onto generator struct indices
         // `storage_liveness` tells us which locals have live storage at suspension points
-        let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
+        let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body);
+
+        if tcx.sess.opts.unstable_opts.validate_mir
+            && !tcx.sess.opts.unstable_opts.drop_tracking_mir
+        {
+            sanitize_witness(tcx, body, interior, upvars, &layout);
+        }
 
         let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id()));
 
@@ -1583,6 +1657,7 @@ fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
             | StatementKind::AscribeUserType(..)
             | StatementKind::Coverage(..)
             | StatementKind::Intrinsic(..)
+            | StatementKind::ConstEvalCounter
             | StatementKind::Nop => {}
         }
     }
@@ -1631,3 +1706,212 @@ fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location
         }
     }
 }
+
+fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) {
+    let mut linted_tys = FxHashSet::default();
+
+    // We want a user-facing param-env.
+    let param_env = tcx.param_env(body.source.def_id());
+
+    for (variant, yield_source_info) in
+        layout.variant_fields.iter().zip(&layout.variant_source_info)
+    {
+        debug!(?variant);
+        for &local in variant {
+            let decl = &layout.field_tys[local];
+            debug!(?decl);
+
+            if !decl.ignore_for_traits && linted_tys.insert(decl.ty) {
+                let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue };
+
+                check_must_not_suspend_ty(
+                    tcx,
+                    decl.ty,
+                    hir_id,
+                    param_env,
+                    SuspendCheckData {
+                        source_span: decl.source_info.span,
+                        yield_span: yield_source_info.span,
+                        plural_len: 1,
+                        ..Default::default()
+                    },
+                );
+            }
+        }
+    }
+}
+
+#[derive(Default)]
+struct SuspendCheckData<'a> {
+    source_span: Span,
+    yield_span: Span,
+    descr_pre: &'a str,
+    descr_post: &'a str,
+    plural_len: usize,
+}
+
+// Returns whether it emitted a diagnostic or not
+// Note that this fn and the proceeding one are based on the code
+// for creating must_use diagnostics
+//
+// Note that this technique was chosen over things like a `Suspend` marker trait
+// as it is simpler and has precedent in the compiler
+fn check_must_not_suspend_ty<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    ty: Ty<'tcx>,
+    hir_id: hir::HirId,
+    param_env: ty::ParamEnv<'tcx>,
+    data: SuspendCheckData<'_>,
+) -> bool {
+    if ty.is_unit() {
+        return false;
+    }
+
+    let plural_suffix = pluralize!(data.plural_len);
+
+    debug!("Checking must_not_suspend for {}", ty);
+
+    match *ty.kind() {
+        ty::Adt(..) if ty.is_box() => {
+            let boxed_ty = ty.boxed_ty();
+            let descr_pre = &format!("{}boxed ", data.descr_pre);
+            check_must_not_suspend_ty(
+                tcx,
+                boxed_ty,
+                hir_id,
+                param_env,
+                SuspendCheckData { descr_pre, ..data },
+            )
+        }
+        ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data),
+        // FIXME: support adding the attribute to TAITs
+        ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => {
+            let mut has_emitted = false;
+            for &(predicate, _) in tcx.explicit_item_bounds(def) {
+                // We only look at the `DefId`, so it is safe to skip the binder here.
+                if let ty::PredicateKind::Clause(ty::Clause::Trait(ref poly_trait_predicate)) =
+                    predicate.kind().skip_binder()
+                {
+                    let def_id = poly_trait_predicate.trait_ref.def_id;
+                    let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix);
+                    if check_must_not_suspend_def(
+                        tcx,
+                        def_id,
+                        hir_id,
+                        SuspendCheckData { descr_pre, ..data },
+                    ) {
+                        has_emitted = true;
+                        break;
+                    }
+                }
+            }
+            has_emitted
+        }
+        ty::Dynamic(binder, _, _) => {
+            let mut has_emitted = false;
+            for predicate in binder.iter() {
+                if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() {
+                    let def_id = trait_ref.def_id;
+                    let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post);
+                    if check_must_not_suspend_def(
+                        tcx,
+                        def_id,
+                        hir_id,
+                        SuspendCheckData { descr_post, ..data },
+                    ) {
+                        has_emitted = true;
+                        break;
+                    }
+                }
+            }
+            has_emitted
+        }
+        ty::Tuple(fields) => {
+            let mut has_emitted = false;
+            for (i, ty) in fields.iter().enumerate() {
+                let descr_post = &format!(" in tuple element {i}");
+                if check_must_not_suspend_ty(
+                    tcx,
+                    ty,
+                    hir_id,
+                    param_env,
+                    SuspendCheckData { descr_post, ..data },
+                ) {
+                    has_emitted = true;
+                }
+            }
+            has_emitted
+        }
+        ty::Array(ty, len) => {
+            let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix);
+            check_must_not_suspend_ty(
+                tcx,
+                ty,
+                hir_id,
+                param_env,
+                SuspendCheckData {
+                    descr_pre,
+                    plural_len: len.try_eval_usize(tcx, param_env).unwrap_or(0) as usize + 1,
+                    ..data
+                },
+            )
+        }
+        // If drop tracking is enabled, we want to look through references, since the referrent
+        // may not be considered live across the await point.
+        ty::Ref(_region, ty, _mutability) => {
+            let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix);
+            check_must_not_suspend_ty(
+                tcx,
+                ty,
+                hir_id,
+                param_env,
+                SuspendCheckData { descr_pre, ..data },
+            )
+        }
+        _ => false,
+    }
+}
+
+fn check_must_not_suspend_def(
+    tcx: TyCtxt<'_>,
+    def_id: DefId,
+    hir_id: hir::HirId,
+    data: SuspendCheckData<'_>,
+) -> bool {
+    if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) {
+        let msg = format!(
+            "{}`{}`{} held across a suspend point, but should not be",
+            data.descr_pre,
+            tcx.def_path_str(def_id),
+            data.descr_post,
+        );
+        tcx.struct_span_lint_hir(
+            rustc_session::lint::builtin::MUST_NOT_SUSPEND,
+            hir_id,
+            data.source_span,
+            msg,
+            |lint| {
+                // add span pointing to the offending yield/await
+                lint.span_label(data.yield_span, "the value is held across this suspend point");
+
+                // Add optional reason note
+                if let Some(note) = attr.value_str() {
+                    // FIXME(guswynn): consider formatting this better
+                    lint.span_note(data.source_span, note.as_str());
+                }
+
+                // Add some quick suggestions on what to do
+                // FIXME: can `drop` work as a suggestion here as well?
+                lint.span_help(
+                    data.source_span,
+                    "consider using a block (`{ ... }`) \
+                    to shrink the value's scope, ending before the suspend point",
+                )
+            },
+        );
+
+        true
+    } else {
+        false
+    }
+}