]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/generator.rs
0df732aa22bad7b23724909ed58311f081e96d86
[rust.git] / compiler / rustc_mir_transform / src / generator.rs
1 //! This is the implementation of the pass which transforms generators into state machines.
2 //!
3 //! MIR generation for generators creates a function which has a self argument which
4 //! passes by value. This argument is effectively a generator type which only contains upvars and
5 //! is only used for this argument inside the MIR for the generator.
6 //! It is passed by value to enable upvars to be moved out of it. Drop elaboration runs on that
7 //! MIR before this pass and creates drop flags for MIR locals.
8 //! It will also drop the generator argument (which only consists of upvars) if any of the upvars
9 //! are moved out of. This pass elaborates the drops of upvars / generator argument in the case
10 //! that none of the upvars were moved out of. This is because we cannot have any drops of this
11 //! generator in the MIR, since it is used to create the drop glue for the generator. We'd get
12 //! infinite recursion otherwise.
13 //!
14 //! This pass creates the implementation for either the `Generator::resume` or `Future::poll`
15 //! function and the drop shim for the generator based on the MIR input.
16 //! It converts the generator argument from Self to &mut Self adding derefs in the MIR as needed.
17 //! It computes the final layout of the generator struct which looks like this:
18 //!     First upvars are stored
19 //!     It is followed by the generator state field.
20 //!     Then finally the MIR locals which are live across a suspension point are stored.
21 //!     ```ignore (illustrative)
22 //!     struct Generator {
23 //!         upvars...,
24 //!         state: u32,
25 //!         mir_locals...,
26 //!     }
27 //!     ```
28 //! This pass computes the meaning of the state field and the MIR locals which are live
29 //! across a suspension point. There are however three hardcoded generator states:
30 //!     0 - Generator have not been resumed yet
31 //!     1 - Generator has returned / is completed
32 //!     2 - Generator has been poisoned
33 //!
34 //! It also rewrites `return x` and `yield y` as setting a new generator state and returning
35 //! `GeneratorState::Complete(x)` and `GeneratorState::Yielded(y)`,
36 //! or `Poll::Ready(x)` and `Poll::Pending` respectively.
37 //! MIR locals which are live across a suspension point are moved to the generator struct
38 //! with references to them being updated with references to the generator struct.
39 //!
40 //! The pass creates two functions which have a switch on the generator state giving
41 //! the action to take.
42 //!
43 //! One of them is the implementation of `Generator::resume` / `Future::poll`.
44 //! For generators with state 0 (unresumed) it starts the execution of the generator.
45 //! For generators with state 1 (returned) and state 2 (poisoned) it panics.
46 //! Otherwise it continues the execution from the last suspension point.
47 //!
48 //! The other function is the drop glue for the generator.
49 //! For generators with state 0 (unresumed) it drops the upvars of the generator.
50 //! For generators with state 1 (returned) and state 2 (poisoned) it does nothing.
51 //! Otherwise it drops all the values in scope at the last suspension point.
52
53 use crate::deref_separator::deref_finder;
54 use crate::simplify;
55 use crate::util::expand_aggregate;
56 use crate::MirPass;
57 use rustc_data_structures::fx::FxHashMap;
58 use rustc_hir as hir;
59 use rustc_hir::lang_items::LangItem;
60 use rustc_hir::GeneratorKind;
61 use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
62 use rustc_index::vec::{Idx, IndexVec};
63 use rustc_middle::mir::dump_mir;
64 use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
65 use rustc_middle::mir::*;
66 use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
67 use rustc_middle::ty::{GeneratorSubsts, SubstsRef};
68 use rustc_mir_dataflow::impls::{
69     MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
70 };
71 use rustc_mir_dataflow::storage::always_storage_live_locals;
72 use rustc_mir_dataflow::{self, Analysis};
73 use rustc_target::abi::VariantIdx;
74 use rustc_target::spec::PanicStrategy;
75 use std::{iter, ops};
76
77 pub struct StateTransform;
78
79 struct RenameLocalVisitor<'tcx> {
80     from: Local,
81     to: Local,
82     tcx: TyCtxt<'tcx>,
83 }
84
85 impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
86     fn tcx(&self) -> TyCtxt<'tcx> {
87         self.tcx
88     }
89
90     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
91         if *local == self.from {
92             *local = self.to;
93         }
94     }
95
96     fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
97         match terminator.kind {
98             TerminatorKind::Return => {
99                 // Do not replace the implicit `_0` access here, as that's not possible. The
100                 // transform already handles `return` correctly.
101             }
102             _ => self.super_terminator(terminator, location),
103         }
104     }
105 }
106
107 struct DerefArgVisitor<'tcx> {
108     tcx: TyCtxt<'tcx>,
109 }
110
111 impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> {
112     fn tcx(&self) -> TyCtxt<'tcx> {
113         self.tcx
114     }
115
116     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
117         assert_ne!(*local, SELF_ARG);
118     }
119
120     fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
121         if place.local == SELF_ARG {
122             replace_base(
123                 place,
124                 Place {
125                     local: SELF_ARG,
126                     projection: self.tcx().intern_place_elems(&[ProjectionElem::Deref]),
127                 },
128                 self.tcx,
129             );
130         } else {
131             self.visit_local(&mut place.local, context, location);
132
133             for elem in place.projection.iter() {
134                 if let PlaceElem::Index(local) = elem {
135                     assert_ne!(local, SELF_ARG);
136                 }
137             }
138         }
139     }
140 }
141
142 struct PinArgVisitor<'tcx> {
143     ref_gen_ty: Ty<'tcx>,
144     tcx: TyCtxt<'tcx>,
145 }
146
147 impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> {
148     fn tcx(&self) -> TyCtxt<'tcx> {
149         self.tcx
150     }
151
152     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
153         assert_ne!(*local, SELF_ARG);
154     }
155
156     fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
157         if place.local == SELF_ARG {
158             replace_base(
159                 place,
160                 Place {
161                     local: SELF_ARG,
162                     projection: self.tcx().intern_place_elems(&[ProjectionElem::Field(
163                         Field::new(0),
164                         self.ref_gen_ty,
165                     )]),
166                 },
167                 self.tcx,
168             );
169         } else {
170             self.visit_local(&mut place.local, context, location);
171
172             for elem in place.projection.iter() {
173                 if let PlaceElem::Index(local) = elem {
174                     assert_ne!(local, SELF_ARG);
175                 }
176             }
177         }
178     }
179 }
180
181 fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
182     place.local = new_base.local;
183
184     let mut new_projection = new_base.projection.to_vec();
185     new_projection.append(&mut place.projection.to_vec());
186
187     place.projection = tcx.intern_place_elems(&new_projection);
188 }
189
190 const SELF_ARG: Local = Local::from_u32(1);
191
192 /// Generator has not been resumed yet.
193 const UNRESUMED: usize = GeneratorSubsts::UNRESUMED;
194 /// Generator has returned / is completed.
195 const RETURNED: usize = GeneratorSubsts::RETURNED;
196 /// Generator has panicked and is poisoned.
197 const POISONED: usize = GeneratorSubsts::POISONED;
198
199 /// Number of variants to reserve in generator state. Corresponds to
200 /// `UNRESUMED` (beginning of a generator) and `RETURNED`/`POISONED`
201 /// (end of a generator) states.
202 const RESERVED_VARIANTS: usize = 3;
203
204 /// A `yield` point in the generator.
205 struct SuspensionPoint<'tcx> {
206     /// State discriminant used when suspending or resuming at this point.
207     state: usize,
208     /// The block to jump to after resumption.
209     resume: BasicBlock,
210     /// Where to move the resume argument after resumption.
211     resume_arg: Place<'tcx>,
212     /// Which block to jump to if the generator is dropped in this state.
213     drop: Option<BasicBlock>,
214     /// Set of locals that have live storage while at this suspension point.
215     storage_liveness: GrowableBitSet<Local>,
216 }
217
218 struct TransformVisitor<'tcx> {
219     tcx: TyCtxt<'tcx>,
220     is_async_kind: bool,
221     state_adt_ref: AdtDef<'tcx>,
222     state_substs: SubstsRef<'tcx>,
223
224     // The type of the discriminant in the generator struct
225     discr_ty: Ty<'tcx>,
226
227     // Mapping from Local to (type of local, generator struct index)
228     // FIXME(eddyb) This should use `IndexVec<Local, Option<_>>`.
229     remap: FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
230
231     // A map from a suspension point in a block to the locals which have live storage at that point
232     storage_liveness: IndexVec<BasicBlock, Option<BitSet<Local>>>,
233
234     // A list of suspension points, generated during the transform
235     suspension_points: Vec<SuspensionPoint<'tcx>>,
236
237     // The set of locals that have no `StorageLive`/`StorageDead` annotations.
238     always_live_locals: BitSet<Local>,
239
240     // The original RETURN_PLACE local
241     new_ret_local: Local,
242 }
243
244 impl<'tcx> TransformVisitor<'tcx> {
245     // Make a `GeneratorState` or `Poll` variant assignment.
246     //
247     // `core::ops::GeneratorState` only has single element tuple variants,
248     // so we can just write to the downcasted first field and then set the
249     // discriminant to the appropriate variant.
250     fn make_state(
251         &self,
252         val: Operand<'tcx>,
253         source_info: SourceInfo,
254         is_return: bool,
255         statements: &mut Vec<Statement<'tcx>>,
256     ) {
257         let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
258             (true, false) => 1,  // GeneratorState::Complete
259             (false, false) => 0, // GeneratorState::Yielded
260             (true, true) => 0,   // Poll::Ready
261             (false, true) => 1,  // Poll::Pending
262         });
263
264         let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_substs, None, None);
265
266         // `Poll::Pending`
267         if self.is_async_kind && idx == VariantIdx::new(1) {
268             assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
269
270             // FIXME(swatinem): assert that `val` is indeed unit?
271             statements.extend(expand_aggregate(
272                 Place::return_place(),
273                 std::iter::empty(),
274                 kind,
275                 source_info,
276                 self.tcx,
277             ));
278             return;
279         }
280
281         // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
282         assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
283
284         let ty = self
285             .tcx
286             .bound_type_of(self.state_adt_ref.variant(idx).fields[0].did)
287             .subst(self.tcx, self.state_substs);
288
289         statements.extend(expand_aggregate(
290             Place::return_place(),
291             std::iter::once((val, ty)),
292             kind,
293             source_info,
294             self.tcx,
295         ));
296     }
297
298     // Create a Place referencing a generator struct field
299     fn make_field(&self, variant_index: VariantIdx, idx: usize, ty: Ty<'tcx>) -> Place<'tcx> {
300         let self_place = Place::from(SELF_ARG);
301         let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
302         let mut projection = base.projection.to_vec();
303         projection.push(ProjectionElem::Field(Field::new(idx), ty));
304
305         Place { local: base.local, projection: self.tcx.intern_place_elems(&projection) }
306     }
307
308     // Create a statement which changes the discriminant
309     fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
310         let self_place = Place::from(SELF_ARG);
311         Statement {
312             source_info,
313             kind: StatementKind::SetDiscriminant {
314                 place: Box::new(self_place),
315                 variant_index: state_disc,
316             },
317         }
318     }
319
320     // Create a statement which reads the discriminant into a temporary
321     fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
322         let temp_decl = LocalDecl::new(self.discr_ty, body.span).internal();
323         let local_decls_len = body.local_decls.push(temp_decl);
324         let temp = Place::from(local_decls_len);
325
326         let self_place = Place::from(SELF_ARG);
327         let assign = Statement {
328             source_info: SourceInfo::outermost(body.span),
329             kind: StatementKind::Assign(Box::new((temp, Rvalue::Discriminant(self_place)))),
330         };
331         (assign, temp)
332     }
333 }
334
335 impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
336     fn tcx(&self) -> TyCtxt<'tcx> {
337         self.tcx
338     }
339
340     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
341         assert_eq!(self.remap.get(local), None);
342     }
343
344     fn visit_place(
345         &mut self,
346         place: &mut Place<'tcx>,
347         _context: PlaceContext,
348         _location: Location,
349     ) {
350         // Replace an Local in the remap with a generator struct access
351         if let Some(&(ty, variant_index, idx)) = self.remap.get(&place.local) {
352             replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
353         }
354     }
355
356     fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
357         // Remove StorageLive and StorageDead statements for remapped locals
358         data.retain_statements(|s| match s.kind {
359             StatementKind::StorageLive(l) | StatementKind::StorageDead(l) => {
360                 !self.remap.contains_key(&l)
361             }
362             _ => true,
363         });
364
365         let ret_val = match data.terminator().kind {
366             TerminatorKind::Return => {
367                 Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None))
368             }
369             TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
370                 Some((false, Some((resume, resume_arg)), value.clone(), drop))
371             }
372             _ => None,
373         };
374
375         if let Some((is_return, resume, v, drop)) = ret_val {
376             let source_info = data.terminator().source_info;
377             // We must assign the value first in case it gets declared dead below
378             self.make_state(v, source_info, is_return, &mut data.statements);
379             let state = if let Some((resume, mut resume_arg)) = resume {
380                 // Yield
381                 let state = RESERVED_VARIANTS + self.suspension_points.len();
382
383                 // The resume arg target location might itself be remapped if its base local is
384                 // live across a yield.
385                 let resume_arg =
386                     if let Some(&(ty, variant, idx)) = self.remap.get(&resume_arg.local) {
387                         replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
388                         resume_arg
389                     } else {
390                         resume_arg
391                     };
392
393                 self.suspension_points.push(SuspensionPoint {
394                     state,
395                     resume,
396                     resume_arg,
397                     drop,
398                     storage_liveness: self.storage_liveness[block].clone().unwrap().into(),
399                 });
400
401                 VariantIdx::new(state)
402             } else {
403                 // Return
404                 VariantIdx::new(RETURNED) // state for returned
405             };
406             data.statements.push(self.set_discr(state, source_info));
407             data.terminator_mut().kind = TerminatorKind::Return;
408         }
409
410         self.super_basic_block_data(block, data);
411     }
412 }
413
414 fn make_generator_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
415     let gen_ty = body.local_decls.raw[1].ty;
416
417     let ref_gen_ty =
418         tcx.mk_ref(tcx.lifetimes.re_erased, ty::TypeAndMut { ty: gen_ty, mutbl: Mutability::Mut });
419
420     // Replace the by value generator argument
421     body.local_decls.raw[1].ty = ref_gen_ty;
422
423     // Add a deref to accesses of the generator state
424     DerefArgVisitor { tcx }.visit_body(body);
425 }
426
427 fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
428     let ref_gen_ty = body.local_decls.raw[1].ty;
429
430     let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span));
431     let pin_adt_ref = tcx.adt_def(pin_did);
432     let substs = tcx.intern_substs(&[ref_gen_ty.into()]);
433     let pin_ref_gen_ty = tcx.mk_adt(pin_adt_ref, substs);
434
435     // Replace the by ref generator argument
436     body.local_decls.raw[1].ty = pin_ref_gen_ty;
437
438     // Add the Pin field access to accesses of the generator state
439     PinArgVisitor { ref_gen_ty, tcx }.visit_body(body);
440 }
441
442 /// Allocates a new local and replaces all references of `local` with it. Returns the new local.
443 ///
444 /// `local` will be changed to a new local decl with type `ty`.
445 ///
446 /// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
447 /// valid value to it before its first use.
448 fn replace_local<'tcx>(
449     local: Local,
450     ty: Ty<'tcx>,
451     body: &mut Body<'tcx>,
452     tcx: TyCtxt<'tcx>,
453 ) -> Local {
454     let new_decl = LocalDecl::new(ty, body.span);
455     let new_local = body.local_decls.push(new_decl);
456     body.local_decls.swap(local, new_local);
457
458     RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body);
459
460     new_local
461 }
462
463 /// Transforms the `body` of the generator applying the following transforms:
464 ///
465 /// - Eliminates all the `get_context` calls that async lowering created.
466 /// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
467 ///
468 /// The `Local`s that have their types replaced are:
469 /// - The `resume` argument itself.
470 /// - The argument to `get_context`.
471 /// - The yielded value of a `yield`.
472 ///
473 /// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
474 /// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
475 ///
476 /// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
477 /// but rather directly use `&mut Context<'_>`, however that would currently
478 /// lead to higher-kinded lifetime errors.
479 /// See <https://github.com/rust-lang/rust/issues/105501>.
480 ///
481 /// The async lowering step and the type / lifetime inference / checking are
482 /// still using the `ResumeTy` indirection for the time being, and that indirection
483 /// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
484 fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
485     let context_mut_ref = tcx.mk_task_context();
486
487     // replace the type of the `resume` argument
488     replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref);
489
490     let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
491
492     for bb in BasicBlock::new(0)..body.basic_blocks.next_index() {
493         let bb_data = &body[bb];
494         if bb_data.is_cleanup {
495             continue;
496         }
497
498         match &bb_data.terminator().kind {
499             TerminatorKind::Call { func, .. } => {
500                 let func_ty = func.ty(body, tcx);
501                 if let ty::FnDef(def_id, _) = *func_ty.kind() {
502                     if def_id == get_context_def_id {
503                         let local = eliminate_get_context_call(&mut body[bb]);
504                         replace_resume_ty_local(tcx, body, local, context_mut_ref);
505                     }
506                 } else {
507                     continue;
508                 }
509             }
510             TerminatorKind::Yield { resume_arg, .. } => {
511                 replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
512             }
513             _ => {}
514         }
515     }
516 }
517
518 fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
519     let terminator = bb_data.terminator.take().unwrap();
520     if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind {
521         let arg = args.pop().unwrap();
522         let local = arg.place().unwrap().local;
523
524         let arg = Rvalue::Use(arg);
525         let assign = Statement {
526             source_info: terminator.source_info,
527             kind: StatementKind::Assign(Box::new((destination, arg))),
528         };
529         bb_data.statements.push(assign);
530         bb_data.terminator = Some(Terminator {
531             source_info: terminator.source_info,
532             kind: TerminatorKind::Goto { target: target.unwrap() },
533         });
534         local
535     } else {
536         bug!();
537     }
538 }
539
540 #[cfg_attr(not(debug_assertions), allow(unused))]
541 fn replace_resume_ty_local<'tcx>(
542     tcx: TyCtxt<'tcx>,
543     body: &mut Body<'tcx>,
544     local: Local,
545     context_mut_ref: Ty<'tcx>,
546 ) {
547     let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
548     // We have to replace the `ResumeTy` that is used for type and borrow checking
549     // with `&mut Context<'_>` in MIR.
550     #[cfg(debug_assertions)]
551     {
552         if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
553             let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
554             assert_eq!(*resume_ty_adt, expected_adt);
555         } else {
556             panic!("expected `ResumeTy`, found `{:?}`", local_ty);
557         };
558     }
559 }
560
561 struct LivenessInfo {
562     /// Which locals are live across any suspension point.
563     saved_locals: GeneratorSavedLocals,
564
565     /// The set of saved locals live at each suspension point.
566     live_locals_at_suspension_points: Vec<BitSet<GeneratorSavedLocal>>,
567
568     /// Parallel vec to the above with SourceInfo for each yield terminator.
569     source_info_at_suspension_points: Vec<SourceInfo>,
570
571     /// For every saved local, the set of other saved locals that are
572     /// storage-live at the same time as this local. We cannot overlap locals in
573     /// the layout which have conflicting storage.
574     storage_conflicts: BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>,
575
576     /// For every suspending block, the locals which are storage-live across
577     /// that suspension point.
578     storage_liveness: IndexVec<BasicBlock, Option<BitSet<Local>>>,
579 }
580
581 fn locals_live_across_suspend_points<'tcx>(
582     tcx: TyCtxt<'tcx>,
583     body: &Body<'tcx>,
584     always_live_locals: &BitSet<Local>,
585     movable: bool,
586 ) -> LivenessInfo {
587     let body_ref: &Body<'_> = &body;
588
589     // Calculate when MIR locals have live storage. This gives us an upper bound of their
590     // lifetimes.
591     let mut storage_live = MaybeStorageLive::new(std::borrow::Cow::Borrowed(always_live_locals))
592         .into_engine(tcx, body_ref)
593         .iterate_to_fixpoint()
594         .into_results_cursor(body_ref);
595
596     // Calculate the MIR locals which have been previously
597     // borrowed (even if they are still active).
598     let borrowed_locals_results =
599         MaybeBorrowedLocals.into_engine(tcx, body_ref).pass_name("generator").iterate_to_fixpoint();
600
601     let mut borrowed_locals_cursor =
602         rustc_mir_dataflow::ResultsCursor::new(body_ref, &borrowed_locals_results);
603
604     // Calculate the MIR locals that we actually need to keep storage around
605     // for.
606     let requires_storage_results = MaybeRequiresStorage::new(body, &borrowed_locals_results)
607         .into_engine(tcx, body_ref)
608         .iterate_to_fixpoint();
609     let mut requires_storage_cursor =
610         rustc_mir_dataflow::ResultsCursor::new(body_ref, &requires_storage_results);
611
612     // Calculate the liveness of MIR locals ignoring borrows.
613     let mut liveness = MaybeLiveLocals
614         .into_engine(tcx, body_ref)
615         .pass_name("generator")
616         .iterate_to_fixpoint()
617         .into_results_cursor(body_ref);
618
619     let mut storage_liveness_map = IndexVec::from_elem(None, &body.basic_blocks);
620     let mut live_locals_at_suspension_points = Vec::new();
621     let mut source_info_at_suspension_points = Vec::new();
622     let mut live_locals_at_any_suspension_point = BitSet::new_empty(body.local_decls.len());
623
624     for (block, data) in body.basic_blocks.iter_enumerated() {
625         if let TerminatorKind::Yield { .. } = data.terminator().kind {
626             let loc = Location { block, statement_index: data.statements.len() };
627
628             liveness.seek_to_block_end(block);
629             let mut live_locals: BitSet<_> = BitSet::new_empty(body.local_decls.len());
630             live_locals.union(liveness.get());
631
632             if !movable {
633                 // The `liveness` variable contains the liveness of MIR locals ignoring borrows.
634                 // This is correct for movable generators since borrows cannot live across
635                 // suspension points. However for immovable generators we need to account for
636                 // borrows, so we conservatively assume that all borrowed locals are live until
637                 // we find a StorageDead statement referencing the locals.
638                 // To do this we just union our `liveness` result with `borrowed_locals`, which
639                 // contains all the locals which has been borrowed before this suspension point.
640                 // If a borrow is converted to a raw reference, we must also assume that it lives
641                 // forever. Note that the final liveness is still bounded by the storage liveness
642                 // of the local, which happens using the `intersect` operation below.
643                 borrowed_locals_cursor.seek_before_primary_effect(loc);
644                 live_locals.union(borrowed_locals_cursor.get());
645             }
646
647             // Store the storage liveness for later use so we can restore the state
648             // after a suspension point
649             storage_live.seek_before_primary_effect(loc);
650             storage_liveness_map[block] = Some(storage_live.get().clone());
651
652             // Locals live are live at this point only if they are used across
653             // suspension points (the `liveness` variable)
654             // and their storage is required (the `storage_required` variable)
655             requires_storage_cursor.seek_before_primary_effect(loc);
656             live_locals.intersect(requires_storage_cursor.get());
657
658             // The generator argument is ignored.
659             live_locals.remove(SELF_ARG);
660
661             debug!("loc = {:?}, live_locals = {:?}", loc, live_locals);
662
663             // Add the locals live at this suspension point to the set of locals which live across
664             // any suspension points
665             live_locals_at_any_suspension_point.union(&live_locals);
666
667             live_locals_at_suspension_points.push(live_locals);
668             source_info_at_suspension_points.push(data.terminator().source_info);
669         }
670     }
671
672     debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
673     let saved_locals = GeneratorSavedLocals(live_locals_at_any_suspension_point);
674
675     // Renumber our liveness_map bitsets to include only the locals we are
676     // saving.
677     let live_locals_at_suspension_points = live_locals_at_suspension_points
678         .iter()
679         .map(|live_here| saved_locals.renumber_bitset(&live_here))
680         .collect();
681
682     let storage_conflicts = compute_storage_conflicts(
683         body_ref,
684         &saved_locals,
685         always_live_locals.clone(),
686         requires_storage_results,
687     );
688
689     LivenessInfo {
690         saved_locals,
691         live_locals_at_suspension_points,
692         source_info_at_suspension_points,
693         storage_conflicts,
694         storage_liveness: storage_liveness_map,
695     }
696 }
697
698 /// The set of `Local`s that must be saved across yield points.
699 ///
700 /// `GeneratorSavedLocal` is indexed in terms of the elements in this set;
701 /// i.e. `GeneratorSavedLocal::new(1)` corresponds to the second local
702 /// included in this set.
703 struct GeneratorSavedLocals(BitSet<Local>);
704
705 impl GeneratorSavedLocals {
706     /// Returns an iterator over each `GeneratorSavedLocal` along with the `Local` it corresponds
707     /// to.
708     fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (GeneratorSavedLocal, Local)> {
709         self.iter().enumerate().map(|(i, l)| (GeneratorSavedLocal::from(i), l))
710     }
711
712     /// Transforms a `BitSet<Local>` that contains only locals saved across yield points to the
713     /// equivalent `BitSet<GeneratorSavedLocal>`.
714     fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<GeneratorSavedLocal> {
715         assert!(self.superset(&input), "{:?} not a superset of {:?}", self.0, input);
716         let mut out = BitSet::new_empty(self.count());
717         for (saved_local, local) in self.iter_enumerated() {
718             if input.contains(local) {
719                 out.insert(saved_local);
720             }
721         }
722         out
723     }
724
725     fn get(&self, local: Local) -> Option<GeneratorSavedLocal> {
726         if !self.contains(local) {
727             return None;
728         }
729
730         let idx = self.iter().take_while(|&l| l < local).count();
731         Some(GeneratorSavedLocal::new(idx))
732     }
733 }
734
735 impl ops::Deref for GeneratorSavedLocals {
736     type Target = BitSet<Local>;
737
738     fn deref(&self) -> &Self::Target {
739         &self.0
740     }
741 }
742
743 /// For every saved local, looks for which locals are StorageLive at the same
744 /// time. Generates a bitset for every local of all the other locals that may be
745 /// StorageLive simultaneously with that local. This is used in the layout
746 /// computation; see `GeneratorLayout` for more.
747 fn compute_storage_conflicts<'mir, 'tcx>(
748     body: &'mir Body<'tcx>,
749     saved_locals: &GeneratorSavedLocals,
750     always_live_locals: BitSet<Local>,
751     requires_storage: rustc_mir_dataflow::Results<'tcx, MaybeRequiresStorage<'mir, 'tcx>>,
752 ) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal> {
753     assert_eq!(body.local_decls.len(), saved_locals.domain_size());
754
755     debug!("compute_storage_conflicts({:?})", body.span);
756     debug!("always_live = {:?}", always_live_locals);
757
758     // Locals that are always live or ones that need to be stored across
759     // suspension points are not eligible for overlap.
760     let mut ineligible_locals = always_live_locals;
761     ineligible_locals.intersect(&**saved_locals);
762
763     // Compute the storage conflicts for all eligible locals.
764     let mut visitor = StorageConflictVisitor {
765         body,
766         saved_locals: &saved_locals,
767         local_conflicts: BitMatrix::from_row_n(&ineligible_locals, body.local_decls.len()),
768     };
769
770     requires_storage.visit_reachable_with(body, &mut visitor);
771
772     let local_conflicts = visitor.local_conflicts;
773
774     // Compress the matrix using only stored locals (Local -> GeneratorSavedLocal).
775     //
776     // NOTE: Today we store a full conflict bitset for every local. Technically
777     // this is twice as many bits as we need, since the relation is symmetric.
778     // However, in practice these bitsets are not usually large. The layout code
779     // also needs to keep track of how many conflicts each local has, so it's
780     // simpler to keep it this way for now.
781     let mut storage_conflicts = BitMatrix::new(saved_locals.count(), saved_locals.count());
782     for (saved_local_a, local_a) in saved_locals.iter_enumerated() {
783         if ineligible_locals.contains(local_a) {
784             // Conflicts with everything.
785             storage_conflicts.insert_all_into_row(saved_local_a);
786         } else {
787             // Keep overlap information only for stored locals.
788             for (saved_local_b, local_b) in saved_locals.iter_enumerated() {
789                 if local_conflicts.contains(local_a, local_b) {
790                     storage_conflicts.insert(saved_local_a, saved_local_b);
791                 }
792             }
793         }
794     }
795     storage_conflicts
796 }
797
798 struct StorageConflictVisitor<'mir, 'tcx, 's> {
799     body: &'mir Body<'tcx>,
800     saved_locals: &'s GeneratorSavedLocals,
801     // FIXME(tmandry): Consider using sparse bitsets here once we have good
802     // benchmarks for generators.
803     local_conflicts: BitMatrix<Local, Local>,
804 }
805
806 impl<'mir, 'tcx> rustc_mir_dataflow::ResultsVisitor<'mir, 'tcx>
807     for StorageConflictVisitor<'mir, 'tcx, '_>
808 {
809     type FlowState = BitSet<Local>;
810
811     fn visit_statement_before_primary_effect(
812         &mut self,
813         state: &Self::FlowState,
814         _statement: &'mir Statement<'tcx>,
815         loc: Location,
816     ) {
817         self.apply_state(state, loc);
818     }
819
820     fn visit_terminator_before_primary_effect(
821         &mut self,
822         state: &Self::FlowState,
823         _terminator: &'mir Terminator<'tcx>,
824         loc: Location,
825     ) {
826         self.apply_state(state, loc);
827     }
828 }
829
830 impl StorageConflictVisitor<'_, '_, '_> {
831     fn apply_state(&mut self, flow_state: &BitSet<Local>, loc: Location) {
832         // Ignore unreachable blocks.
833         if self.body.basic_blocks[loc.block].terminator().kind == TerminatorKind::Unreachable {
834             return;
835         }
836
837         let mut eligible_storage_live = flow_state.clone();
838         eligible_storage_live.intersect(&**self.saved_locals);
839
840         for local in eligible_storage_live.iter() {
841             self.local_conflicts.union_row_with(&eligible_storage_live, local);
842         }
843
844         if eligible_storage_live.count() > 1 {
845             trace!("at {:?}, eligible_storage_live={:?}", loc, eligible_storage_live);
846         }
847     }
848 }
849
850 /// Validates the typeck view of the generator against the actual set of types saved between
851 /// yield points.
852 fn sanitize_witness<'tcx>(
853     tcx: TyCtxt<'tcx>,
854     body: &Body<'tcx>,
855     witness: Ty<'tcx>,
856     upvars: Vec<Ty<'tcx>>,
857     saved_locals: &GeneratorSavedLocals,
858 ) {
859     let did = body.source.def_id();
860     let param_env = tcx.param_env(did);
861
862     let allowed_upvars = tcx.normalize_erasing_regions(param_env, upvars);
863     let allowed = match witness.kind() {
864         &ty::GeneratorWitness(interior_tys) => {
865             tcx.normalize_erasing_late_bound_regions(param_env, interior_tys)
866         }
867         _ => {
868             tcx.sess.delay_span_bug(
869                 body.span,
870                 &format!("unexpected generator witness type {:?}", witness.kind()),
871             );
872             return;
873         }
874     };
875
876     for (local, decl) in body.local_decls.iter_enumerated() {
877         // Ignore locals which are internal or not saved between yields.
878         if !saved_locals.contains(local) || decl.internal {
879             continue;
880         }
881         let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty);
882
883         // Sanity check that typeck knows about the type of locals which are
884         // live across a suspension point
885         if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) {
886             span_bug!(
887                 body.span,
888                 "Broken MIR: generator contains type {} in MIR, \
889                        but typeck only knows about {} and {:?}",
890                 decl_ty,
891                 allowed,
892                 allowed_upvars
893             );
894         }
895     }
896 }
897
898 fn compute_layout<'tcx>(
899     liveness: LivenessInfo,
900     body: &mut Body<'tcx>,
901 ) -> (
902     FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
903     GeneratorLayout<'tcx>,
904     IndexVec<BasicBlock, Option<BitSet<Local>>>,
905 ) {
906     let LivenessInfo {
907         saved_locals,
908         live_locals_at_suspension_points,
909         source_info_at_suspension_points,
910         storage_conflicts,
911         storage_liveness,
912     } = liveness;
913
914     // Gather live local types and their indices.
915     let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
916     let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
917     for (saved_local, local) in saved_locals.iter_enumerated() {
918         locals.push(local);
919         tys.push(body.local_decls[local].ty);
920         debug!("generator saved local {:?} => {:?}", saved_local, local);
921     }
922
923     // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.
924     // In debuginfo, these will correspond to the beginning (UNRESUMED) or end
925     // (RETURNED, POISONED) of the function.
926     let body_span = body.source_scopes[OUTERMOST_SOURCE_SCOPE].span;
927     let mut variant_source_info: IndexVec<VariantIdx, SourceInfo> = [
928         SourceInfo::outermost(body_span.shrink_to_lo()),
929         SourceInfo::outermost(body_span.shrink_to_hi()),
930         SourceInfo::outermost(body_span.shrink_to_hi()),
931     ]
932     .iter()
933     .copied()
934     .collect();
935
936     // Build the generator variant field list.
937     // Create a map from local indices to generator struct indices.
938     let mut variant_fields: IndexVec<VariantIdx, IndexVec<Field, GeneratorSavedLocal>> =
939         iter::repeat(IndexVec::new()).take(RESERVED_VARIANTS).collect();
940     let mut remap = FxHashMap::default();
941     for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() {
942         let variant_index = VariantIdx::from(RESERVED_VARIANTS + suspension_point_idx);
943         let mut fields = IndexVec::new();
944         for (idx, saved_local) in live_locals.iter().enumerate() {
945             fields.push(saved_local);
946             // Note that if a field is included in multiple variants, we will
947             // just use the first one here. That's fine; fields do not move
948             // around inside generators, so it doesn't matter which variant
949             // index we access them by.
950             remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx));
951         }
952         variant_fields.push(fields);
953         variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
954     }
955     debug!("generator variant_fields = {:?}", variant_fields);
956     debug!("generator storage_conflicts = {:#?}", storage_conflicts);
957
958     let layout =
959         GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts };
960
961     (remap, layout, storage_liveness)
962 }
963
964 /// Replaces the entry point of `body` with a block that switches on the generator discriminant and
965 /// dispatches to blocks according to `cases`.
966 ///
967 /// After this function, the former entry point of the function will be bb1.
968 fn insert_switch<'tcx>(
969     body: &mut Body<'tcx>,
970     cases: Vec<(usize, BasicBlock)>,
971     transform: &TransformVisitor<'tcx>,
972     default: TerminatorKind<'tcx>,
973 ) {
974     let default_block = insert_term_block(body, default);
975     let (assign, discr) = transform.get_discr(body);
976     let switch_targets =
977         SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block);
978     let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets };
979
980     let source_info = SourceInfo::outermost(body.span);
981     body.basic_blocks_mut().raw.insert(
982         0,
983         BasicBlockData {
984             statements: vec![assign],
985             terminator: Some(Terminator { source_info, kind: switch }),
986             is_cleanup: false,
987         },
988     );
989
990     let blocks = body.basic_blocks_mut().iter_mut();
991
992     for target in blocks.flat_map(|b| b.terminator_mut().successors_mut()) {
993         *target = BasicBlock::new(target.index() + 1);
994     }
995 }
996
997 fn elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
998     use crate::shim::DropShimElaborator;
999     use rustc_middle::mir::patch::MirPatch;
1000     use rustc_mir_dataflow::elaborate_drops::{elaborate_drop, Unwind};
1001
1002     // Note that `elaborate_drops` only drops the upvars of a generator, and
1003     // this is ok because `open_drop` can only be reached within that own
1004     // generator's resume function.
1005
1006     let def_id = body.source.def_id();
1007     let param_env = tcx.param_env(def_id);
1008
1009     let mut elaborator = DropShimElaborator { body, patch: MirPatch::new(body), tcx, param_env };
1010
1011     for (block, block_data) in body.basic_blocks.iter_enumerated() {
1012         let (target, unwind, source_info) = match block_data.terminator() {
1013             Terminator { source_info, kind: TerminatorKind::Drop { place, target, unwind } } => {
1014                 if let Some(local) = place.as_local() {
1015                     if local == SELF_ARG {
1016                         (target, unwind, source_info)
1017                     } else {
1018                         continue;
1019                     }
1020                 } else {
1021                     continue;
1022                 }
1023             }
1024             _ => continue,
1025         };
1026         let unwind = if block_data.is_cleanup {
1027             Unwind::InCleanup
1028         } else {
1029             Unwind::To(unwind.unwrap_or_else(|| elaborator.patch.resume_block()))
1030         };
1031         elaborate_drop(
1032             &mut elaborator,
1033             *source_info,
1034             Place::from(SELF_ARG),
1035             (),
1036             *target,
1037             unwind,
1038             block,
1039         );
1040     }
1041     elaborator.patch.apply(body);
1042 }
1043
1044 fn create_generator_drop_shim<'tcx>(
1045     tcx: TyCtxt<'tcx>,
1046     transform: &TransformVisitor<'tcx>,
1047     gen_ty: Ty<'tcx>,
1048     body: &mut Body<'tcx>,
1049     drop_clean: BasicBlock,
1050 ) -> Body<'tcx> {
1051     let mut body = body.clone();
1052     body.arg_count = 1; // make sure the resume argument is not included here
1053
1054     let source_info = SourceInfo::outermost(body.span);
1055
1056     let mut cases = create_cases(&mut body, transform, Operation::Drop);
1057
1058     cases.insert(0, (UNRESUMED, drop_clean));
1059
1060     // The returned state and the poisoned state fall through to the default
1061     // case which is just to return
1062
1063     insert_switch(&mut body, cases, &transform, TerminatorKind::Return);
1064
1065     for block in body.basic_blocks_mut() {
1066         let kind = &mut block.terminator_mut().kind;
1067         if let TerminatorKind::GeneratorDrop = *kind {
1068             *kind = TerminatorKind::Return;
1069         }
1070     }
1071
1072     // Replace the return variable
1073     body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.mk_unit(), source_info);
1074
1075     make_generator_state_argument_indirect(tcx, &mut body);
1076
1077     // Change the generator argument from &mut to *mut
1078     body.local_decls[SELF_ARG] = LocalDecl::with_source_info(
1079         tcx.mk_ptr(ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }),
1080         source_info,
1081     );
1082
1083     // Make sure we remove dead blocks to remove
1084     // unrelated code from the resume part of the function
1085     simplify::remove_dead_blocks(tcx, &mut body);
1086
1087     dump_mir(tcx, false, "generator_drop", &0, &body, |_, _| Ok(()));
1088
1089     body
1090 }
1091
1092 fn insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock {
1093     let source_info = SourceInfo::outermost(body.span);
1094     body.basic_blocks_mut().push(BasicBlockData {
1095         statements: Vec::new(),
1096         terminator: Some(Terminator { source_info, kind }),
1097         is_cleanup: false,
1098     })
1099 }
1100
1101 fn insert_panic_block<'tcx>(
1102     tcx: TyCtxt<'tcx>,
1103     body: &mut Body<'tcx>,
1104     message: AssertMessage<'tcx>,
1105 ) -> BasicBlock {
1106     let assert_block = BasicBlock::new(body.basic_blocks.len());
1107     let term = TerminatorKind::Assert {
1108         cond: Operand::Constant(Box::new(Constant {
1109             span: body.span,
1110             user_ty: None,
1111             literal: ConstantKind::from_bool(tcx, false),
1112         })),
1113         expected: true,
1114         msg: message,
1115         target: assert_block,
1116         cleanup: None,
1117     };
1118
1119     let source_info = SourceInfo::outermost(body.span);
1120     body.basic_blocks_mut().push(BasicBlockData {
1121         statements: Vec::new(),
1122         terminator: Some(Terminator { source_info, kind: term }),
1123         is_cleanup: false,
1124     });
1125
1126     assert_block
1127 }
1128
1129 fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, param_env: ty::ParamEnv<'tcx>) -> bool {
1130     // Returning from a function with an uninhabited return type is undefined behavior.
1131     if body.return_ty().is_privately_uninhabited(tcx, param_env) {
1132         return false;
1133     }
1134
1135     // If there's a return terminator the function may return.
1136     for block in body.basic_blocks.iter() {
1137         if let TerminatorKind::Return = block.terminator().kind {
1138             return true;
1139         }
1140     }
1141
1142     // Otherwise the function can't return.
1143     false
1144 }
1145
1146 fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
1147     // Nothing can unwind when landing pads are off.
1148     if tcx.sess.panic_strategy() == PanicStrategy::Abort {
1149         return false;
1150     }
1151
1152     // Unwinds can only start at certain terminators.
1153     for block in body.basic_blocks.iter() {
1154         match block.terminator().kind {
1155             // These never unwind.
1156             TerminatorKind::Goto { .. }
1157             | TerminatorKind::SwitchInt { .. }
1158             | TerminatorKind::Abort
1159             | TerminatorKind::Return
1160             | TerminatorKind::Unreachable
1161             | TerminatorKind::GeneratorDrop
1162             | TerminatorKind::FalseEdge { .. }
1163             | TerminatorKind::FalseUnwind { .. } => {}
1164
1165             // Resume will *continue* unwinding, but if there's no other unwinding terminator it
1166             // will never be reached.
1167             TerminatorKind::Resume => {}
1168
1169             TerminatorKind::Yield { .. } => {
1170                 unreachable!("`can_unwind` called before generator transform")
1171             }
1172
1173             // These may unwind.
1174             TerminatorKind::Drop { .. }
1175             | TerminatorKind::DropAndReplace { .. }
1176             | TerminatorKind::Call { .. }
1177             | TerminatorKind::InlineAsm { .. }
1178             | TerminatorKind::Assert { .. } => return true,
1179         }
1180     }
1181
1182     // If we didn't find an unwinding terminator, the function cannot unwind.
1183     false
1184 }
1185
1186 fn create_generator_resume_function<'tcx>(
1187     tcx: TyCtxt<'tcx>,
1188     transform: TransformVisitor<'tcx>,
1189     body: &mut Body<'tcx>,
1190     can_return: bool,
1191 ) {
1192     let can_unwind = can_unwind(tcx, body);
1193
1194     // Poison the generator when it unwinds
1195     if can_unwind {
1196         let source_info = SourceInfo::outermost(body.span);
1197         let poison_block = body.basic_blocks_mut().push(BasicBlockData {
1198             statements: vec![transform.set_discr(VariantIdx::new(POISONED), source_info)],
1199             terminator: Some(Terminator { source_info, kind: TerminatorKind::Resume }),
1200             is_cleanup: true,
1201         });
1202
1203         for (idx, block) in body.basic_blocks_mut().iter_enumerated_mut() {
1204             let source_info = block.terminator().source_info;
1205
1206             if let TerminatorKind::Resume = block.terminator().kind {
1207                 // An existing `Resume` terminator is redirected to jump to our dedicated
1208                 // "poisoning block" above.
1209                 if idx != poison_block {
1210                     *block.terminator_mut() = Terminator {
1211                         source_info,
1212                         kind: TerminatorKind::Goto { target: poison_block },
1213                     };
1214                 }
1215             } else if !block.is_cleanup {
1216                 // Any terminators that *can* unwind but don't have an unwind target set are also
1217                 // pointed at our poisoning block (unless they're part of the cleanup path).
1218                 if let Some(unwind @ None) = block.terminator_mut().unwind_mut() {
1219                     *unwind = Some(poison_block);
1220                 }
1221             }
1222         }
1223     }
1224
1225     let mut cases = create_cases(body, &transform, Operation::Resume);
1226
1227     use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};
1228
1229     // Jump to the entry point on the unresumed
1230     cases.insert(0, (UNRESUMED, BasicBlock::new(0)));
1231
1232     // Panic when resumed on the returned or poisoned state
1233     let generator_kind = body.generator_kind().unwrap();
1234
1235     if can_unwind {
1236         cases.insert(
1237             1,
1238             (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(generator_kind))),
1239         );
1240     }
1241
1242     if can_return {
1243         cases.insert(
1244             1,
1245             (RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(generator_kind))),
1246         );
1247     }
1248
1249     insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
1250
1251     make_generator_state_argument_indirect(tcx, body);
1252     make_generator_state_argument_pinned(tcx, body);
1253
1254     // Make sure we remove dead blocks to remove
1255     // unrelated code from the drop part of the function
1256     simplify::remove_dead_blocks(tcx, body);
1257
1258     dump_mir(tcx, false, "generator_resume", &0, body, |_, _| Ok(()));
1259 }
1260
1261 fn insert_clean_drop(body: &mut Body<'_>) -> BasicBlock {
1262     let return_block = insert_term_block(body, TerminatorKind::Return);
1263
1264     let term =
1265         TerminatorKind::Drop { place: Place::from(SELF_ARG), target: return_block, unwind: None };
1266     let source_info = SourceInfo::outermost(body.span);
1267
1268     // Create a block to destroy an unresumed generators. This can only destroy upvars.
1269     body.basic_blocks_mut().push(BasicBlockData {
1270         statements: Vec::new(),
1271         terminator: Some(Terminator { source_info, kind: term }),
1272         is_cleanup: false,
1273     })
1274 }
1275
1276 /// An operation that can be performed on a generator.
1277 #[derive(PartialEq, Copy, Clone)]
1278 enum Operation {
1279     Resume,
1280     Drop,
1281 }
1282
1283 impl Operation {
1284     fn target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock> {
1285         match self {
1286             Operation::Resume => Some(point.resume),
1287             Operation::Drop => point.drop,
1288         }
1289     }
1290 }
1291
1292 fn create_cases<'tcx>(
1293     body: &mut Body<'tcx>,
1294     transform: &TransformVisitor<'tcx>,
1295     operation: Operation,
1296 ) -> Vec<(usize, BasicBlock)> {
1297     let source_info = SourceInfo::outermost(body.span);
1298
1299     transform
1300         .suspension_points
1301         .iter()
1302         .filter_map(|point| {
1303             // Find the target for this suspension point, if applicable
1304             operation.target_block(point).map(|target| {
1305                 let mut statements = Vec::new();
1306
1307                 // Create StorageLive instructions for locals with live storage
1308                 for i in 0..(body.local_decls.len()) {
1309                     if i == 2 {
1310                         // The resume argument is live on function entry. Don't insert a
1311                         // `StorageLive`, or the following `Assign` will read from uninitialized
1312                         // memory.
1313                         continue;
1314                     }
1315
1316                     let l = Local::new(i);
1317                     let needs_storage_live = point.storage_liveness.contains(l)
1318                         && !transform.remap.contains_key(&l)
1319                         && !transform.always_live_locals.contains(l);
1320                     if needs_storage_live {
1321                         statements
1322                             .push(Statement { source_info, kind: StatementKind::StorageLive(l) });
1323                     }
1324                 }
1325
1326                 if operation == Operation::Resume {
1327                     // Move the resume argument to the destination place of the `Yield` terminator
1328                     let resume_arg = Local::new(2); // 0 = return, 1 = self
1329                     statements.push(Statement {
1330                         source_info,
1331                         kind: StatementKind::Assign(Box::new((
1332                             point.resume_arg,
1333                             Rvalue::Use(Operand::Move(resume_arg.into())),
1334                         ))),
1335                     });
1336                 }
1337
1338                 // Then jump to the real target
1339                 let block = body.basic_blocks_mut().push(BasicBlockData {
1340                     statements,
1341                     terminator: Some(Terminator {
1342                         source_info,
1343                         kind: TerminatorKind::Goto { target },
1344                     }),
1345                     is_cleanup: false,
1346                 });
1347
1348                 (point.state, block)
1349             })
1350         })
1351         .collect()
1352 }
1353
1354 impl<'tcx> MirPass<'tcx> for StateTransform {
1355     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
1356         let Some(yield_ty) = body.yield_ty() else {
1357             // This only applies to generators
1358             return;
1359         };
1360
1361         assert!(body.generator_drop().is_none());
1362
1363         // The first argument is the generator type passed by value
1364         let gen_ty = body.local_decls.raw[1].ty;
1365
1366         // Get the interior types and substs which typeck computed
1367         let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() {
1368             ty::Generator(_, substs, movability) => {
1369                 let substs = substs.as_generator();
1370                 (
1371                     substs.upvar_tys().collect(),
1372                     substs.witness(),
1373                     substs.discr_ty(tcx),
1374                     movability == hir::Movability::Movable,
1375                 )
1376             }
1377             _ => {
1378                 tcx.sess
1379                     .delay_span_bug(body.span, &format!("unexpected generator type {}", gen_ty));
1380                 return;
1381             }
1382         };
1383
1384         let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_)));
1385         let (state_adt_ref, state_substs) = if is_async_kind {
1386             // Compute Poll<return_ty>
1387             let poll_did = tcx.require_lang_item(LangItem::Poll, None);
1388             let poll_adt_ref = tcx.adt_def(poll_did);
1389             let poll_substs = tcx.intern_substs(&[body.return_ty().into()]);
1390             (poll_adt_ref, poll_substs)
1391         } else {
1392             // Compute GeneratorState<yield_ty, return_ty>
1393             let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
1394             let state_adt_ref = tcx.adt_def(state_did);
1395             let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]);
1396             (state_adt_ref, state_substs)
1397         };
1398         let ret_ty = tcx.mk_adt(state_adt_ref, state_substs);
1399
1400         // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1401         // RETURN_PLACE then is a fresh unused local with type ret_ty.
1402         let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
1403
1404         // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1405         if is_async_kind {
1406             transform_async_context(tcx, body);
1407         }
1408
1409         // We also replace the resume argument and insert an `Assign`.
1410         // This is needed because the resume argument `_2` might be live across a `yield`, in which
1411         // case there is no `Assign` to it that the transform can turn into a store to the generator
1412         // state. After the yield the slot in the generator state would then be uninitialized.
1413         let resume_local = Local::new(2);
1414         let resume_ty =
1415             if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty };
1416         let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
1417
1418         // When first entering the generator, move the resume argument into its new local.
1419         let source_info = SourceInfo::outermost(body.span);
1420         let stmts = &mut body.basic_blocks_mut()[BasicBlock::new(0)].statements;
1421         stmts.insert(
1422             0,
1423             Statement {
1424                 source_info,
1425                 kind: StatementKind::Assign(Box::new((
1426                     new_resume_local.into(),
1427                     Rvalue::Use(Operand::Move(resume_local.into())),
1428                 ))),
1429             },
1430         );
1431
1432         let always_live_locals = always_storage_live_locals(&body);
1433
1434         let liveness_info =
1435             locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
1436
1437         sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals);
1438
1439         if tcx.sess.opts.unstable_opts.validate_mir {
1440             let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias {
1441                 assigned_local: None,
1442                 saved_locals: &liveness_info.saved_locals,
1443                 storage_conflicts: &liveness_info.storage_conflicts,
1444             };
1445
1446             vis.visit_body(body);
1447         }
1448
1449         // Extract locals which are live across suspension point into `layout`
1450         // `remap` gives a mapping from local indices onto generator struct indices
1451         // `storage_liveness` tells us which locals have live storage at suspension points
1452         let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
1453
1454         let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id()));
1455
1456         // Run the transformation which converts Places from Local to generator struct
1457         // accesses for locals in `remap`.
1458         // It also rewrites `return x` and `yield y` as writing a new generator state and returning
1459         // either GeneratorState::Complete(x) and GeneratorState::Yielded(y),
1460         // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
1461         let mut transform = TransformVisitor {
1462             tcx,
1463             is_async_kind,
1464             state_adt_ref,
1465             state_substs,
1466             remap,
1467             storage_liveness,
1468             always_live_locals,
1469             suspension_points: Vec::new(),
1470             new_ret_local,
1471             discr_ty,
1472         };
1473         transform.visit_body(body);
1474
1475         // Update our MIR struct to reflect the changes we've made
1476         body.arg_count = 2; // self, resume arg
1477         body.spread_arg = None;
1478
1479         body.generator.as_mut().unwrap().yield_ty = None;
1480         body.generator.as_mut().unwrap().generator_layout = Some(layout);
1481
1482         // Insert `drop(generator_struct)` which is used to drop upvars for generators in
1483         // the unresumed state.
1484         // This is expanded to a drop ladder in `elaborate_generator_drops`.
1485         let drop_clean = insert_clean_drop(body);
1486
1487         dump_mir(tcx, false, "generator_pre-elab", &0, body, |_, _| Ok(()));
1488
1489         // Expand `drop(generator_struct)` to a drop ladder which destroys upvars.
1490         // If any upvars are moved out of, drop elaboration will handle upvar destruction.
1491         // However we need to also elaborate the code generated by `insert_clean_drop`.
1492         elaborate_generator_drops(tcx, body);
1493
1494         dump_mir(tcx, false, "generator_post-transform", &0, body, |_, _| Ok(()));
1495
1496         // Create a copy of our MIR and use it to create the drop shim for the generator
1497         let drop_shim = create_generator_drop_shim(tcx, &transform, gen_ty, body, drop_clean);
1498
1499         body.generator.as_mut().unwrap().generator_drop = Some(drop_shim);
1500
1501         // Create the Generator::resume / Future::poll function
1502         create_generator_resume_function(tcx, transform, body, can_return);
1503
1504         // Run derefer to fix Derefs that are not in the first place
1505         deref_finder(tcx, body);
1506     }
1507 }
1508
1509 /// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields
1510 /// in the generator state machine but whose storage is not marked as conflicting
1511 ///
1512 /// Validation needs to happen immediately *before* `TransformVisitor` is invoked, not after.
1513 ///
1514 /// This condition would arise when the assignment is the last use of `_5` but the initial
1515 /// definition of `_4` if we weren't extra careful to mark all locals used inside a statement as
1516 /// conflicting. Non-conflicting generator saved locals may be stored at the same location within
1517 /// the generator state machine, which would result in ill-formed MIR: the left-hand and right-hand
1518 /// sides of an assignment may not alias. This caused a miscompilation in [#73137].
1519 ///
1520 /// [#73137]: https://github.com/rust-lang/rust/issues/73137
1521 struct EnsureGeneratorFieldAssignmentsNeverAlias<'a> {
1522     saved_locals: &'a GeneratorSavedLocals,
1523     storage_conflicts: &'a BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>,
1524     assigned_local: Option<GeneratorSavedLocal>,
1525 }
1526
1527 impl EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
1528     fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<GeneratorSavedLocal> {
1529         if place.is_indirect() {
1530             return None;
1531         }
1532
1533         self.saved_locals.get(place.local)
1534     }
1535
1536     fn check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self)) {
1537         if let Some(assigned_local) = self.saved_local_for_direct_place(place) {
1538             assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse");
1539
1540             self.assigned_local = Some(assigned_local);
1541             f(self);
1542             self.assigned_local = None;
1543         }
1544     }
1545 }
1546
1547 impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
1548     fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
1549         let Some(lhs) = self.assigned_local else {
1550             // This visitor only invokes `visit_place` for the right-hand side of an assignment
1551             // and only after setting `self.assigned_local`. However, the default impl of
1552             // `Visitor::super_body` may call `visit_place` with a `NonUseContext` for places
1553             // with debuginfo. Ignore them here.
1554             assert!(!context.is_use());
1555             return;
1556         };
1557
1558         let Some(rhs) = self.saved_local_for_direct_place(*place) else { return };
1559
1560         if !self.storage_conflicts.contains(lhs, rhs) {
1561             bug!(
1562                 "Assignment between generator saved locals whose storage is not \
1563                     marked as conflicting: {:?}: {:?} = {:?}",
1564                 location,
1565                 lhs,
1566                 rhs,
1567             );
1568         }
1569     }
1570
1571     fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
1572         match &statement.kind {
1573             StatementKind::Assign(box (lhs, rhs)) => {
1574                 self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location));
1575             }
1576
1577             StatementKind::FakeRead(..)
1578             | StatementKind::SetDiscriminant { .. }
1579             | StatementKind::Deinit(..)
1580             | StatementKind::StorageLive(_)
1581             | StatementKind::StorageDead(_)
1582             | StatementKind::Retag(..)
1583             | StatementKind::AscribeUserType(..)
1584             | StatementKind::Coverage(..)
1585             | StatementKind::Intrinsic(..)
1586             | StatementKind::ConstEvalCounter
1587             | StatementKind::Nop => {}
1588         }
1589     }
1590
1591     fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
1592         // Checking for aliasing in terminators is probably overkill, but until we have actual
1593         // semantics, we should be conservative here.
1594         match &terminator.kind {
1595             TerminatorKind::Call {
1596                 func,
1597                 args,
1598                 destination,
1599                 target: Some(_),
1600                 cleanup: _,
1601                 from_hir_call: _,
1602                 fn_span: _,
1603             } => {
1604                 self.check_assigned_place(*destination, |this| {
1605                     this.visit_operand(func, location);
1606                     for arg in args {
1607                         this.visit_operand(arg, location);
1608                     }
1609                 });
1610             }
1611
1612             TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => {
1613                 self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location));
1614             }
1615
1616             // FIXME: Does `asm!` have any aliasing requirements?
1617             TerminatorKind::InlineAsm { .. } => {}
1618
1619             TerminatorKind::Call { .. }
1620             | TerminatorKind::Goto { .. }
1621             | TerminatorKind::SwitchInt { .. }
1622             | TerminatorKind::Resume
1623             | TerminatorKind::Abort
1624             | TerminatorKind::Return
1625             | TerminatorKind::Unreachable
1626             | TerminatorKind::Drop { .. }
1627             | TerminatorKind::DropAndReplace { .. }
1628             | TerminatorKind::Assert { .. }
1629             | TerminatorKind::GeneratorDrop
1630             | TerminatorKind::FalseEdge { .. }
1631             | TerminatorKind::FalseUnwind { .. } => {}
1632         }
1633     }
1634 }