]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_builtin_macros/src/assert/context.rs
Rollup merge of #103842 - andrewpollack:add-fuchsia-test-script, r=tmandry
[rust.git] / compiler / rustc_builtin_macros / src / assert / context.rs
1 use rustc_ast::{
2     attr,
3     ptr::P,
4     token,
5     tokenstream::{DelimSpan, TokenStream, TokenTree},
6     BinOpKind, BorrowKind, Expr, ExprKind, ItemKind, MacArgs, MacCall, MacDelimiter, Mutability,
7     Path, PathSegment, Stmt, StructRest, UnOp, UseTree, UseTreeKind, DUMMY_NODE_ID,
8 };
9 use rustc_ast_pretty::pprust;
10 use rustc_data_structures::fx::FxHashSet;
11 use rustc_expand::base::ExtCtxt;
12 use rustc_span::{
13     symbol::{sym, Ident, Symbol},
14     Span,
15 };
16 use thin_vec::thin_vec;
17
18 pub(super) struct Context<'cx, 'a> {
19     // An optimization.
20     //
21     // Elements that aren't consumed (PartialEq, PartialOrd, ...) can be copied **after** the
22     // `assert!` expression fails rather than copied on-the-fly.
23     best_case_captures: Vec<Stmt>,
24     // Top-level `let captureN = Capture::new()` statements
25     capture_decls: Vec<Capture>,
26     cx: &'cx ExtCtxt<'a>,
27     // Formatting string used for debugging
28     fmt_string: String,
29     // If the current expression being visited consumes itself. Used to construct
30     // `best_case_captures`.
31     is_consumed: bool,
32     // Top-level `let __local_bindN = &expr` statements
33     local_bind_decls: Vec<Stmt>,
34     // Used to avoid capturing duplicated paths
35     //
36     // ```rust
37     // let a = 1i32;
38     // assert!(add(a, a) == 3);
39     // ```
40     paths: FxHashSet<Ident>,
41     span: Span,
42 }
43
44 impl<'cx, 'a> Context<'cx, 'a> {
45     pub(super) fn new(cx: &'cx ExtCtxt<'a>, span: Span) -> Self {
46         Self {
47             best_case_captures: <_>::default(),
48             capture_decls: <_>::default(),
49             cx,
50             fmt_string: <_>::default(),
51             is_consumed: true,
52             local_bind_decls: <_>::default(),
53             paths: <_>::default(),
54             span,
55         }
56     }
57
58     /// Builds the whole `assert!` expression. For example, `let elem = 1; assert!(elem == 1);` expands to:
59     ///
60     /// ```rust
61     /// #![feature(generic_assert_internals)]
62     /// let elem = 1;
63     /// {
64     ///   #[allow(unused_imports)]
65     ///   use ::core::asserting::{TryCaptureGeneric, TryCapturePrintable};
66     ///   let mut __capture0 = ::core::asserting::Capture::new();
67     ///   let __local_bind0 = &elem;
68     ///   if !(
69     ///     *{
70     ///       (&::core::asserting::Wrapper(__local_bind0)).try_capture(&mut __capture0);
71     ///       __local_bind0
72     ///     } == 1
73     ///   ) {
74     ///     panic!("Assertion failed: elem == 1\nWith captures:\n  elem = {:?}", __capture0)
75     ///   }
76     /// }
77     /// ```
78     pub(super) fn build(mut self, mut cond_expr: P<Expr>, panic_path: Path) -> P<Expr> {
79         let expr_str = pprust::expr_to_string(&cond_expr);
80         self.manage_cond_expr(&mut cond_expr);
81         let initial_imports = self.build_initial_imports();
82         let panic = self.build_panic(&expr_str, panic_path);
83         let cond_expr_with_unlikely = self.build_unlikely(cond_expr);
84
85         let Self { best_case_captures, capture_decls, cx, local_bind_decls, span, .. } = self;
86
87         let mut assert_then_stmts = Vec::with_capacity(2);
88         assert_then_stmts.extend(best_case_captures);
89         assert_then_stmts.push(self.cx.stmt_expr(panic));
90         let assert_then = self.cx.block(span, assert_then_stmts);
91
92         let mut stmts = Vec::with_capacity(4);
93         stmts.push(initial_imports);
94         stmts.extend(capture_decls.into_iter().map(|c| c.decl));
95         stmts.extend(local_bind_decls);
96         stmts.push(
97             cx.stmt_expr(cx.expr(span, ExprKind::If(cond_expr_with_unlikely, assert_then, None))),
98         );
99         cx.expr_block(cx.block(span, stmts))
100     }
101
102     /// Initial **trait** imports
103     ///
104     /// use ::core::asserting::{ ... };
105     fn build_initial_imports(&self) -> Stmt {
106         let nested_tree = |this: &Self, sym| {
107             (
108                 UseTree {
109                     prefix: this.cx.path(this.span, vec![Ident::with_dummy_span(sym)]),
110                     kind: UseTreeKind::Simple(None, DUMMY_NODE_ID, DUMMY_NODE_ID),
111                     span: this.span,
112                 },
113                 DUMMY_NODE_ID,
114             )
115         };
116         self.cx.stmt_item(
117             self.span,
118             self.cx.item(
119                 self.span,
120                 Ident::empty(),
121                 thin_vec![self.cx.attribute(attr::mk_list_item(
122                     Ident::new(sym::allow, self.span),
123                     vec![attr::mk_nested_word_item(Ident::new(sym::unused_imports, self.span))],
124                 ))],
125                 ItemKind::Use(UseTree {
126                     prefix: self.cx.path(self.span, self.cx.std_path(&[sym::asserting])),
127                     kind: UseTreeKind::Nested(vec![
128                         nested_tree(self, sym::TryCaptureGeneric),
129                         nested_tree(self, sym::TryCapturePrintable),
130                     ]),
131                     span: self.span,
132                 }),
133             ),
134         )
135     }
136
137     /// Takes the conditional expression of `assert!` and then wraps it inside `unlikely`
138     fn build_unlikely(&self, cond_expr: P<Expr>) -> P<Expr> {
139         let unlikely_path = self.cx.std_path(&[sym::intrinsics, sym::unlikely]);
140         self.cx.expr_call(
141             self.span,
142             self.cx.expr_path(self.cx.path(self.span, unlikely_path)),
143             vec![self.cx.expr(self.span, ExprKind::Unary(UnOp::Not, cond_expr))],
144         )
145     }
146
147     /// The necessary custom `panic!(...)` expression.
148     ///
149     /// panic!(
150     ///     "Assertion failed: ... \n With expansion: ...",
151     ///     __capture0,
152     ///     ...
153     /// );
154     fn build_panic(&self, expr_str: &str, panic_path: Path) -> P<Expr> {
155         let escaped_expr_str = escape_to_fmt(expr_str);
156         let initial = [
157             TokenTree::token_alone(
158                 token::Literal(token::Lit {
159                     kind: token::LitKind::Str,
160                     symbol: Symbol::intern(&if self.fmt_string.is_empty() {
161                         format!("Assertion failed: {escaped_expr_str}")
162                     } else {
163                         format!(
164                             "Assertion failed: {escaped_expr_str}\nWith captures:\n{}",
165                             &self.fmt_string
166                         )
167                     }),
168                     suffix: None,
169                 }),
170                 self.span,
171             ),
172             TokenTree::token_alone(token::Comma, self.span),
173         ];
174         let captures = self.capture_decls.iter().flat_map(|cap| {
175             [
176                 TokenTree::token_alone(token::Ident(cap.ident.name, false), cap.ident.span),
177                 TokenTree::token_alone(token::Comma, self.span),
178             ]
179         });
180         self.cx.expr(
181             self.span,
182             ExprKind::MacCall(P(MacCall {
183                 path: panic_path,
184                 args: P(MacArgs::Delimited(
185                     DelimSpan::from_single(self.span),
186                     MacDelimiter::Parenthesis,
187                     initial.into_iter().chain(captures).collect::<TokenStream>(),
188                 )),
189                 prior_type_ascription: None,
190             })),
191         )
192     }
193
194     /// Recursive function called until `cond_expr` and `fmt_str` are fully modified.
195     ///
196     /// See [Self::manage_initial_capture] and [Self::manage_try_capture]
197     fn manage_cond_expr(&mut self, expr: &mut P<Expr>) {
198         match (*expr).kind {
199             ExprKind::AddrOf(_, mutability, ref mut local_expr) => {
200                 self.with_is_consumed_management(
201                     matches!(mutability, Mutability::Mut),
202                     |this| this.manage_cond_expr(local_expr)
203                 );
204             }
205             ExprKind::Array(ref mut local_exprs) => {
206                 for local_expr in local_exprs {
207                     self.manage_cond_expr(local_expr);
208                 }
209             }
210             ExprKind::Binary(ref op, ref mut lhs, ref mut rhs) => {
211                 self.with_is_consumed_management(
212                     matches!(
213                         op.node,
214                         BinOpKind::Add
215                             | BinOpKind::And
216                             | BinOpKind::BitAnd
217                             | BinOpKind::BitOr
218                             | BinOpKind::BitXor
219                             | BinOpKind::Div
220                             | BinOpKind::Mul
221                             | BinOpKind::Or
222                             | BinOpKind::Rem
223                             | BinOpKind::Shl
224                             | BinOpKind::Shr
225                             | BinOpKind::Sub
226                     ),
227                     |this| {
228                         this.manage_cond_expr(lhs);
229                         this.manage_cond_expr(rhs);
230                     }
231                 );
232             }
233             ExprKind::Call(_, ref mut local_exprs) => {
234                 for local_expr in local_exprs {
235                     self.manage_cond_expr(local_expr);
236                 }
237             }
238             ExprKind::Cast(ref mut local_expr, _) => {
239                 self.manage_cond_expr(local_expr);
240             }
241             ExprKind::Index(ref mut prefix, ref mut suffix) => {
242                 self.manage_cond_expr(prefix);
243                 self.manage_cond_expr(suffix);
244             }
245             ExprKind::MethodCall(_, _,ref mut local_exprs, _) => {
246                 for local_expr in local_exprs.iter_mut() {
247                     self.manage_cond_expr(local_expr);
248                 }
249             }
250             ExprKind::Path(_, Path { ref segments, .. }) if let &[ref path_segment] = &segments[..] => {
251                 let path_ident = path_segment.ident;
252                 self.manage_initial_capture(expr, path_ident);
253             }
254             ExprKind::Paren(ref mut local_expr) => {
255                 self.manage_cond_expr(local_expr);
256             }
257             ExprKind::Range(ref mut prefix, ref mut suffix, _) => {
258                 if let Some(ref mut elem) = prefix {
259                     self.manage_cond_expr(elem);
260                 }
261                 if let Some(ref mut elem) = suffix {
262                     self.manage_cond_expr(elem);
263                 }
264             }
265             ExprKind::Repeat(ref mut local_expr, ref mut elem) => {
266                 self.manage_cond_expr(local_expr);
267                 self.manage_cond_expr(&mut elem.value);
268             }
269             ExprKind::Struct(ref mut elem) => {
270                 for field in &mut elem.fields {
271                     self.manage_cond_expr(&mut field.expr);
272                 }
273                 if let StructRest::Base(ref mut local_expr) = elem.rest {
274                     self.manage_cond_expr(local_expr);
275                 }
276             }
277             ExprKind::Tup(ref mut local_exprs) => {
278                 for local_expr in local_exprs {
279                     self.manage_cond_expr(local_expr);
280                 }
281             }
282             ExprKind::Unary(un_op, ref mut local_expr) => {
283                 self.with_is_consumed_management(
284                     matches!(un_op, UnOp::Neg | UnOp::Not),
285                     |this| this.manage_cond_expr(local_expr)
286                 );
287             }
288             // Expressions that are not worth or can not be captured.
289             //
290             // Full list instead of `_` to catch possible future inclusions and to
291             // sync with the `rfc-2011-nicer-assert-messages/all-expr-kinds.rs` test.
292             ExprKind::Assign(_, _, _)
293             | ExprKind::AssignOp(_, _, _)
294             | ExprKind::Async(_, _, _)
295             | ExprKind::Await(_)
296             | ExprKind::Block(_, _)
297             | ExprKind::Box(_)
298             | ExprKind::Break(_, _)
299             | ExprKind::Closure(_, _, _, _, _, _, _)
300             | ExprKind::ConstBlock(_)
301             | ExprKind::Continue(_)
302             | ExprKind::Err
303             | ExprKind::Field(_, _)
304             | ExprKind::ForLoop(_, _, _, _)
305             | ExprKind::If(_, _, _)
306             | ExprKind::IncludedBytes(..)
307             | ExprKind::InlineAsm(_)
308             | ExprKind::Let(_, _, _)
309             | ExprKind::Lit(_)
310             | ExprKind::Loop(_, _)
311             | ExprKind::MacCall(_)
312             | ExprKind::Match(_, _)
313             | ExprKind::Path(_, _)
314             | ExprKind::Ret(_)
315             | ExprKind::Try(_)
316             | ExprKind::TryBlock(_)
317             | ExprKind::Type(_, _)
318             | ExprKind::Underscore
319             | ExprKind::While(_, _, _)
320             | ExprKind::Yeet(_)
321             | ExprKind::Yield(_) => {}
322         }
323     }
324
325     /// Pushes the top-level declarations and modifies `expr` to try capturing variables.
326     ///
327     /// `fmt_str`, the formatting string used for debugging, is constructed to show possible
328     /// captured variables.
329     fn manage_initial_capture(&mut self, expr: &mut P<Expr>, path_ident: Ident) {
330         if self.paths.contains(&path_ident) {
331             return;
332         } else {
333             self.fmt_string.push_str("  ");
334             self.fmt_string.push_str(path_ident.as_str());
335             self.fmt_string.push_str(" = {:?}\n");
336             let _ = self.paths.insert(path_ident);
337         }
338         let curr_capture_idx = self.capture_decls.len();
339         let capture_string = format!("__capture{curr_capture_idx}");
340         let ident = Ident::new(Symbol::intern(&capture_string), self.span);
341         let init_std_path = self.cx.std_path(&[sym::asserting, sym::Capture, sym::new]);
342         let init = self.cx.expr_call(
343             self.span,
344             self.cx.expr_path(self.cx.path(self.span, init_std_path)),
345             vec![],
346         );
347         let capture = Capture { decl: self.cx.stmt_let(self.span, true, ident, init), ident };
348         self.capture_decls.push(capture);
349         self.manage_try_capture(ident, curr_capture_idx, expr);
350     }
351
352     /// Tries to copy `__local_bindN` into `__captureN`.
353     ///
354     /// *{
355     ///    (&Wrapper(__local_bindN)).try_capture(&mut __captureN);
356     ///    __local_bindN
357     /// }
358     fn manage_try_capture(&mut self, capture: Ident, curr_capture_idx: usize, expr: &mut P<Expr>) {
359         let local_bind_string = format!("__local_bind{curr_capture_idx}");
360         let local_bind = Ident::new(Symbol::intern(&local_bind_string), self.span);
361         self.local_bind_decls.push(self.cx.stmt_let(
362             self.span,
363             false,
364             local_bind,
365             self.cx.expr_addr_of(self.span, expr.clone()),
366         ));
367         let wrapper = self.cx.expr_call(
368             self.span,
369             self.cx.expr_path(
370                 self.cx.path(self.span, self.cx.std_path(&[sym::asserting, sym::Wrapper])),
371             ),
372             vec![self.cx.expr_path(Path::from_ident(local_bind))],
373         );
374         let try_capture_call = self
375             .cx
376             .stmt_expr(expr_method_call(
377                 self.cx,
378                 PathSegment {
379                     args: None,
380                     id: DUMMY_NODE_ID,
381                     ident: Ident::new(sym::try_capture, self.span),
382                 },
383                 expr_paren(self.cx, self.span, self.cx.expr_addr_of(self.span, wrapper)),
384                 vec![expr_addr_of_mut(
385                     self.cx,
386                     self.span,
387                     self.cx.expr_path(Path::from_ident(capture)),
388                 )],
389                 self.span,
390             ))
391             .add_trailing_semicolon();
392         let local_bind_path = self.cx.expr_path(Path::from_ident(local_bind));
393         let rslt = if self.is_consumed {
394             let ret = self.cx.stmt_expr(local_bind_path);
395             self.cx.expr_block(self.cx.block(self.span, vec![try_capture_call, ret]))
396         } else {
397             self.best_case_captures.push(try_capture_call);
398             local_bind_path
399         };
400         *expr = self.cx.expr_deref(self.span, rslt);
401     }
402
403     // Calls `f` with the internal `is_consumed` set to `curr_is_consumed` and then
404     // sets the internal `is_consumed` back to its original value.
405     fn with_is_consumed_management(&mut self, curr_is_consumed: bool, f: impl FnOnce(&mut Self)) {
406         let prev_is_consumed = self.is_consumed;
407         self.is_consumed = curr_is_consumed;
408         f(self);
409         self.is_consumed = prev_is_consumed;
410     }
411 }
412
413 /// Information about a captured element.
414 #[derive(Debug)]
415 struct Capture {
416     // Generated indexed `Capture` statement.
417     //
418     // `let __capture{} = Capture::new();`
419     decl: Stmt,
420     // The name of the generated indexed `Capture` variable.
421     //
422     // `__capture{}`
423     ident: Ident,
424 }
425
/// Escapes `s` so it can be embedded in the string literal of a format string.
///
/// Every character is Debug-escaped (quotes, newlines, ...), and literal `{` / `}` are doubled
/// so they are not interpreted as format placeholders.
fn escape_to_fmt(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut out, ch| {
        out.extend(ch.escape_debug());
        if matches!(ch, '{' | '}') {
            out.push(ch);
        }
        out
    })
}
438
439 fn expr_addr_of_mut(cx: &ExtCtxt<'_>, sp: Span, e: P<Expr>) -> P<Expr> {
440     cx.expr(sp, ExprKind::AddrOf(BorrowKind::Ref, Mutability::Mut, e))
441 }
442
443 fn expr_method_call(
444     cx: &ExtCtxt<'_>,
445     path: PathSegment,
446     receiver: P<Expr>,
447     args: Vec<P<Expr>>,
448     span: Span,
449 ) -> P<Expr> {
450     cx.expr(span, ExprKind::MethodCall(path, receiver, args, span))
451 }
452
453 fn expr_paren(cx: &ExtCtxt<'_>, sp: Span, e: P<Expr>) -> P<Expr> {
454     cx.expr(sp, ExprKind::Paren(e))
455 }