use crate::simplify;
use crate::util::expand_aggregate;
use crate::MirPass;
-use rustc_data_structures::fx::FxHashMap;
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_errors::pluralize;
use rustc_hir as hir;
use rustc_hir::lang_items::LangItem;
use rustc_hir::GeneratorKind;
};
use rustc_mir_dataflow::storage::always_storage_live_locals;
use rustc_mir_dataflow::{self, Analysis};
+use rustc_span::def_id::DefId;
+use rustc_span::symbol::sym;
+use rustc_span::Span;
use rustc_target::abi::VariantIdx;
use rustc_target::spec::PanicStrategy;
use std::{iter, ops};
body: &Body<'tcx>,
witness: Ty<'tcx>,
upvars: Vec<Ty<'tcx>>,
- saved_locals: &GeneratorSavedLocals,
+ layout: &GeneratorLayout<'tcx>,
) {
let did = body.source.def_id();
let param_env = tcx.param_env(did);
}
};
- for (local, decl) in body.local_decls.iter_enumerated() {
- // Ignore locals which are internal or not saved between yields.
- if !saved_locals.contains(local) || decl.internal {
+ let mut mismatches = Vec::new();
+ for fty in &layout.field_tys {
+ if fty.ignore_for_traits {
continue;
}
- let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty);
+ let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty);
// Sanity check that typeck knows about the type of locals which are
// live across a suspension point
if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) {
- span_bug!(
- body.span,
- "Broken MIR: generator contains type {} in MIR, \
- but typeck only knows about {} and {:?}",
- decl_ty,
- allowed,
- allowed_upvars
- );
+ mismatches.push(decl_ty);
}
}
+
+ if !mismatches.is_empty() {
+ span_bug!(
+ body.span,
+ "Broken MIR: generator contains type {:?} in MIR, \
+ but typeck only knows about {} and {:?}",
+ mismatches,
+ allowed,
+ allowed_upvars
+ );
+ }
}
fn compute_layout<'tcx>(
+ tcx: TyCtxt<'tcx>,
liveness: LivenessInfo,
- body: &mut Body<'tcx>,
+ body: &Body<'tcx>,
) -> (
FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
GeneratorLayout<'tcx>,
let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
for (saved_local, local) in saved_locals.iter_enumerated() {
- locals.push(local);
- tys.push(body.local_decls[local].ty);
debug!("generator saved local {:?} => {:?}", saved_local, local);
+
+ locals.push(local);
+ let decl = &body.local_decls[local];
+ debug!(?decl);
+
+ let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir {
+ match decl.local_info {
+ // Do not include raw pointers created from accessing `static` items, as those could
+ // well be re-created by another access to the same static.
+ Some(box LocalInfo::StaticRef { is_thread_local, .. }) => !is_thread_local,
+ // Fake borrows are only read by fake reads, so do not have any reality in
+ // post-analysis MIR.
+ Some(box LocalInfo::FakeBorrow) => true,
+ _ => false,
+ }
+ } else {
+ // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that
+ // MIR building may introduce. This leads to wrongly ignored types, but this is
+ // necessary for internal consistency and to avoid ICEs.
+ decl.internal
+ };
+ let decl =
+ GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
+ debug!(?decl);
+
+ tys.push(decl);
}
// Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.
// just use the first one here. That's fine; fields do not move
// around inside generators, so it doesn't matter which variant
// index we access them by.
- remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx));
+ remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx));
}
variant_fields.push(fields);
variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
let layout =
GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts };
+ debug!(?layout);
(remap, layout, storage_liveness)
}
.collect()
}
+#[instrument(level = "debug", skip(tcx), ret)]
+pub(crate) fn mir_generator_witnesses<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ def_id: DefId,
+) -> GeneratorLayout<'tcx> {
+ assert!(tcx.sess.opts.unstable_opts.drop_tracking_mir);
+ let def_id = def_id.expect_local();
+
+ let (body, _) = tcx.mir_promoted(ty::WithOptConstParam::unknown(def_id));
+ let body = body.borrow();
+ let body = &*body;
+
+ // The first argument is the generator type passed by value
+ let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
+
+ // Get the interior types and substs which typeck computed
+ let movable = match *gen_ty.kind() {
+ ty::Generator(_, _, movability) => movability == hir::Movability::Movable,
+ _ => span_bug!(body.span, "unexpected generator type {}", gen_ty),
+ };
+
+ // When first entering the generator, move the resume argument into its new local.
+ let always_live_locals = always_storage_live_locals(&body);
+
+ let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
+
+ // Extract locals which are live across suspension point into `layout`
+ // `remap` gives a mapping from local indices onto generator struct indices
+ // `storage_liveness` tells us which locals have live storage at suspension points
+ let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body);
+
+ check_suspend_tys(tcx, &generator_layout, &body);
+
+ generator_layout
+}
+
impl<'tcx> MirPass<'tcx> for StateTransform {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let Some(yield_ty) = body.yield_ty() else {
// The first argument is the generator type passed by value
let gen_ty = body.local_decls.raw[1].ty;
- // Get the interior types and substs which typeck computed
- let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() {
+ // Get the discriminant type and substs which typeck computed
+ let (discr_ty, upvars, interior, movable) = match *gen_ty.kind() {
ty::Generator(_, substs, movability) => {
let substs = substs.as_generator();
(
- substs.upvar_tys().collect(),
- substs.witness(),
substs.discr_ty(tcx),
+ substs.upvar_tys().collect::<Vec<_>>(),
+ substs.witness(),
movability == hir::Movability::Movable,
)
}
let liveness_info =
locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
- sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals);
-
if tcx.sess.opts.unstable_opts.validate_mir {
let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias {
assigned_local: None,
// Extract locals which are live across suspension point into `layout`
// `remap` gives a mapping from local indices onto generator struct indices
// `storage_liveness` tells us which locals have live storage at suspension points
- let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
+ let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body);
+
+ if tcx.sess.opts.unstable_opts.validate_mir
+ && !tcx.sess.opts.unstable_opts.drop_tracking_mir
+ {
+ sanitize_witness(tcx, body, interior, upvars, &layout);
+ }
let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id()));
}
}
}
+
+fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) {
+ let mut linted_tys = FxHashSet::default();
+
+ // We want a user-facing param-env.
+ let param_env = tcx.param_env(body.source.def_id());
+
+ for (variant, yield_source_info) in
+ layout.variant_fields.iter().zip(&layout.variant_source_info)
+ {
+ debug!(?variant);
+ for &local in variant {
+ let decl = &layout.field_tys[local];
+ debug!(?decl);
+
+ if !decl.ignore_for_traits && linted_tys.insert(decl.ty) {
+ let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue };
+
+ check_must_not_suspend_ty(
+ tcx,
+ decl.ty,
+ hir_id,
+ param_env,
+ SuspendCheckData {
+ source_span: decl.source_info.span,
+ yield_span: yield_source_info.span,
+ plural_len: 1,
+ ..Default::default()
+ },
+ );
+ }
+ }
+ }
+}
+
+#[derive(Default)]
+struct SuspendCheckData<'a> {
+ source_span: Span,
+ yield_span: Span,
+ descr_pre: &'a str,
+ descr_post: &'a str,
+ plural_len: usize,
+}
+
+// Returns whether it emitted a diagnostic or not
+// Note that this fn and the proceeding one are based on the code
+// for creating must_use diagnostics
+//
+// Note that this technique was chosen over things like a `Suspend` marker trait
+// as it is simpler and has precedent in the compiler
+fn check_must_not_suspend_ty<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ ty: Ty<'tcx>,
+ hir_id: hir::HirId,
+ param_env: ty::ParamEnv<'tcx>,
+ data: SuspendCheckData<'_>,
+) -> bool {
+ if ty.is_unit() {
+ return false;
+ }
+
+ let plural_suffix = pluralize!(data.plural_len);
+
+ debug!("Checking must_not_suspend for {}", ty);
+
+ match *ty.kind() {
+ ty::Adt(..) if ty.is_box() => {
+ let boxed_ty = ty.boxed_ty();
+ let descr_pre = &format!("{}boxed ", data.descr_pre);
+ check_must_not_suspend_ty(
+ tcx,
+ boxed_ty,
+ hir_id,
+ param_env,
+ SuspendCheckData { descr_pre, ..data },
+ )
+ }
+ ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data),
+ // FIXME: support adding the attribute to TAITs
+ ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => {
+ let mut has_emitted = false;
+ for &(predicate, _) in tcx.explicit_item_bounds(def) {
+ // We only look at the `DefId`, so it is safe to skip the binder here.
+ if let ty::PredicateKind::Clause(ty::Clause::Trait(ref poly_trait_predicate)) =
+ predicate.kind().skip_binder()
+ {
+ let def_id = poly_trait_predicate.trait_ref.def_id;
+ let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix);
+ if check_must_not_suspend_def(
+ tcx,
+ def_id,
+ hir_id,
+ SuspendCheckData { descr_pre, ..data },
+ ) {
+ has_emitted = true;
+ break;
+ }
+ }
+ }
+ has_emitted
+ }
+ ty::Dynamic(binder, _, _) => {
+ let mut has_emitted = false;
+ for predicate in binder.iter() {
+ if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() {
+ let def_id = trait_ref.def_id;
+ let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post);
+ if check_must_not_suspend_def(
+ tcx,
+ def_id,
+ hir_id,
+ SuspendCheckData { descr_post, ..data },
+ ) {
+ has_emitted = true;
+ break;
+ }
+ }
+ }
+ has_emitted
+ }
+ ty::Tuple(fields) => {
+ let mut has_emitted = false;
+ for (i, ty) in fields.iter().enumerate() {
+ let descr_post = &format!(" in tuple element {i}");
+ if check_must_not_suspend_ty(
+ tcx,
+ ty,
+ hir_id,
+ param_env,
+ SuspendCheckData { descr_post, ..data },
+ ) {
+ has_emitted = true;
+ }
+ }
+ has_emitted
+ }
+ ty::Array(ty, len) => {
+ let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix);
+ check_must_not_suspend_ty(
+ tcx,
+ ty,
+ hir_id,
+ param_env,
+ SuspendCheckData {
+ descr_pre,
+ plural_len: len.try_eval_usize(tcx, param_env).unwrap_or(0) as usize + 1,
+ ..data
+ },
+ )
+ }
+ // If drop tracking is enabled, we want to look through references, since the referrent
+ // may not be considered live across the await point.
+ ty::Ref(_region, ty, _mutability) => {
+ let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix);
+ check_must_not_suspend_ty(
+ tcx,
+ ty,
+ hir_id,
+ param_env,
+ SuspendCheckData { descr_pre, ..data },
+ )
+ }
+ _ => false,
+ }
+}
+
+fn check_must_not_suspend_def(
+ tcx: TyCtxt<'_>,
+ def_id: DefId,
+ hir_id: hir::HirId,
+ data: SuspendCheckData<'_>,
+) -> bool {
+ if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) {
+ let msg = format!(
+ "{}`{}`{} held across a suspend point, but should not be",
+ data.descr_pre,
+ tcx.def_path_str(def_id),
+ data.descr_post,
+ );
+ tcx.struct_span_lint_hir(
+ rustc_session::lint::builtin::MUST_NOT_SUSPEND,
+ hir_id,
+ data.source_span,
+ msg,
+ |lint| {
+ // add span pointing to the offending yield/await
+ lint.span_label(data.yield_span, "the value is held across this suspend point");
+
+ // Add optional reason note
+ if let Some(note) = attr.value_str() {
+ // FIXME(guswynn): consider formatting this better
+ lint.span_note(data.source_span, note.as_str());
+ }
+
+ // Add some quick suggestions on what to do
+ // FIXME: can `drop` work as a suggestion here as well?
+ lint.span_help(
+ data.source_span,
+ "consider using a block (`{ ... }`) \
+ to shrink the value's scope, ending before the suspend point",
+ )
+ },
+ );
+
+ true
+ } else {
+ false
+ }
+}