]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_builtin_macros/src/assert/context.rs
Rollup merge of #100520 - jakubdabek:patch-1, r=thomcc
[rust.git] / compiler / rustc_builtin_macros / src / assert / context.rs
1 use rustc_ast::{
2     attr,
3     ptr::P,
4     token,
5     tokenstream::{DelimSpan, TokenStream, TokenTree},
6     BinOpKind, BorrowKind, Expr, ExprKind, ItemKind, MacArgs, MacCall, MacDelimiter, Mutability,
7     Path, PathSegment, Stmt, StructRest, UnOp, UseTree, UseTreeKind, DUMMY_NODE_ID,
8 };
9 use rustc_ast_pretty::pprust;
10 use rustc_data_structures::fx::FxHashSet;
11 use rustc_expand::base::ExtCtxt;
12 use rustc_span::{
13     symbol::{sym, Ident, Symbol},
14     Span,
15 };
16
/// State accumulated while rewriting the condition of a single `assert!` invocation into an
/// expansion that captures the values of the variables it mentions.
pub(super) struct Context<'cx, 'a> {
    // An optimization.
    //
    // Elements that aren't consumed (PartialEq, PartialOrd, ...) can be copied **after** the
    // `assert!` expression fails rather than copied on-the-fly.
    //
    // Holds the deferred `try_capture` statements; they are emitted right before the panic.
    best_case_captures: Vec<Stmt>,
    // Top-level `let mut __captureN = Capture::new()` statements, paired with their identifiers
    capture_decls: Vec<Capture>,
    // Expansion context used to build every generated AST node
    cx: &'cx ExtCtxt<'a>,
    // Formatting string used for debugging: one `  name = {:?}\n` line per captured path
    fmt_string: String,
    // Whether the expression currently being visited is consumed by the operation that
    // contains it. Decides between an on-the-fly capture and an entry in
    // `best_case_captures`.
    is_consumed: bool,
    // Top-level `let __local_bindN = &expr` statements
    local_bind_decls: Vec<Stmt>,
    // Used to avoid capturing duplicated paths
    //
    // ```rust
    // let a = 1i32;
    // assert!(add(a, a) == 3);
    // ```
    paths: FxHashSet<Ident>,
    // Span given to every generated node
    span: Span,
}
42
43 impl<'cx, 'a> Context<'cx, 'a> {
44     pub(super) fn new(cx: &'cx ExtCtxt<'a>, span: Span) -> Self {
45         Self {
46             best_case_captures: <_>::default(),
47             capture_decls: <_>::default(),
48             cx,
49             fmt_string: <_>::default(),
50             is_consumed: true,
51             local_bind_decls: <_>::default(),
52             paths: <_>::default(),
53             span,
54         }
55     }
56
57     /// Builds the whole `assert!` expression. For example, `let elem = 1; assert!(elem == 1);` expands to:
58     ///
59     /// ```rust
60     /// let elem = 1;
61     /// {
62     ///   #[allow(unused_imports)]
63     ///   use ::core::asserting::{TryCaptureGeneric, TryCapturePrintable};
64     ///   let mut __capture0 = ::core::asserting::Capture::new();
65     ///   let __local_bind0 = &elem;
66     ///   if !(
67     ///     *{
68     ///       (&::core::asserting::Wrapper(__local_bind0)).try_capture(&mut __capture0);
69     ///       __local_bind0
70     ///     } == 1
71     ///   ) {
72     ///     panic!("Assertion failed: elem == 1\nWith captures:\n  elem = {}", __capture0)
73     ///   }
74     /// }
75     /// ```
76     pub(super) fn build(mut self, mut cond_expr: P<Expr>, panic_path: Path) -> P<Expr> {
77         let expr_str = pprust::expr_to_string(&cond_expr);
78         self.manage_cond_expr(&mut cond_expr);
79         let initial_imports = self.build_initial_imports();
80         let panic = self.build_panic(&expr_str, panic_path);
81         let cond_expr_with_unlikely = self.build_unlikely(cond_expr);
82
83         let Self { best_case_captures, capture_decls, cx, local_bind_decls, span, .. } = self;
84
85         let mut assert_then_stmts = Vec::with_capacity(2);
86         assert_then_stmts.extend(best_case_captures);
87         assert_then_stmts.push(self.cx.stmt_expr(panic));
88         let assert_then = self.cx.block(span, assert_then_stmts);
89
90         let mut stmts = Vec::with_capacity(4);
91         stmts.push(initial_imports);
92         stmts.extend(capture_decls.into_iter().map(|c| c.decl));
93         stmts.extend(local_bind_decls);
94         stmts.push(
95             cx.stmt_expr(cx.expr(span, ExprKind::If(cond_expr_with_unlikely, assert_then, None))),
96         );
97         cx.expr_block(cx.block(span, stmts))
98     }
99
100     /// Initial **trait** imports
101     ///
102     /// use ::core::asserting::{ ... };
103     fn build_initial_imports(&self) -> Stmt {
104         let nested_tree = |this: &Self, sym| {
105             (
106                 UseTree {
107                     prefix: this.cx.path(this.span, vec![Ident::with_dummy_span(sym)]),
108                     kind: UseTreeKind::Simple(None, DUMMY_NODE_ID, DUMMY_NODE_ID),
109                     span: this.span,
110                 },
111                 DUMMY_NODE_ID,
112             )
113         };
114         self.cx.stmt_item(
115             self.span,
116             self.cx.item(
117                 self.span,
118                 Ident::empty(),
119                 vec![self.cx.attribute(attr::mk_list_item(
120                     Ident::new(sym::allow, self.span),
121                     vec![attr::mk_nested_word_item(Ident::new(sym::unused_imports, self.span))],
122                 ))]
123                 .into(),
124                 ItemKind::Use(UseTree {
125                     prefix: self.cx.path(self.span, self.cx.std_path(&[sym::asserting])),
126                     kind: UseTreeKind::Nested(vec![
127                         nested_tree(self, sym::TryCaptureGeneric),
128                         nested_tree(self, sym::TryCapturePrintable),
129                     ]),
130                     span: self.span,
131                 }),
132             ),
133         )
134     }
135
136     /// Takes the conditional expression of `assert!` and then wraps it inside `unlikely`
137     fn build_unlikely(&self, cond_expr: P<Expr>) -> P<Expr> {
138         let unlikely_path = self.cx.std_path(&[sym::intrinsics, sym::unlikely]);
139         self.cx.expr_call(
140             self.span,
141             self.cx.expr_path(self.cx.path(self.span, unlikely_path)),
142             vec![self.cx.expr(self.span, ExprKind::Unary(UnOp::Not, cond_expr))],
143         )
144     }
145
146     /// The necessary custom `panic!(...)` expression.
147     ///
148     /// panic!(
149     ///     "Assertion failed: ... \n With expansion: ...",
150     ///     __capture0,
151     ///     ...
152     /// );
153     fn build_panic(&self, expr_str: &str, panic_path: Path) -> P<Expr> {
154         let escaped_expr_str = escape_to_fmt(expr_str);
155         let initial = [
156             TokenTree::token_alone(
157                 token::Literal(token::Lit {
158                     kind: token::LitKind::Str,
159                     symbol: Symbol::intern(&if self.fmt_string.is_empty() {
160                         format!("Assertion failed: {escaped_expr_str}")
161                     } else {
162                         format!(
163                             "Assertion failed: {escaped_expr_str}\nWith captures:\n{}",
164                             &self.fmt_string
165                         )
166                     }),
167                     suffix: None,
168                 }),
169                 self.span,
170             ),
171             TokenTree::token_alone(token::Comma, self.span),
172         ];
173         let captures = self.capture_decls.iter().flat_map(|cap| {
174             [
175                 TokenTree::token_alone(token::Ident(cap.ident.name, false), cap.ident.span),
176                 TokenTree::token_alone(token::Comma, self.span),
177             ]
178         });
179         self.cx.expr(
180             self.span,
181             ExprKind::MacCall(P(MacCall {
182                 path: panic_path,
183                 args: P(MacArgs::Delimited(
184                     DelimSpan::from_single(self.span),
185                     MacDelimiter::Parenthesis,
186                     initial.into_iter().chain(captures).collect::<TokenStream>(),
187                 )),
188                 prior_type_ascription: None,
189             })),
190         )
191     }
192
193     /// Recursive function called until `cond_expr` and `fmt_str` are fully modified.
194     ///
195     /// See [Self::manage_initial_capture] and [Self::manage_try_capture]
196     fn manage_cond_expr(&mut self, expr: &mut P<Expr>) {
197         match (*expr).kind {
198             ExprKind::AddrOf(_, mutability, ref mut local_expr) => {
199                 self.with_is_consumed_management(
200                     matches!(mutability, Mutability::Mut),
201                     |this| this.manage_cond_expr(local_expr)
202                 );
203             }
204             ExprKind::Array(ref mut local_exprs) => {
205                 for local_expr in local_exprs {
206                     self.manage_cond_expr(local_expr);
207                 }
208             }
209             ExprKind::Binary(ref op, ref mut lhs, ref mut rhs) => {
210                 self.with_is_consumed_management(
211                     matches!(
212                         op.node,
213                         BinOpKind::Add
214                             | BinOpKind::And
215                             | BinOpKind::BitAnd
216                             | BinOpKind::BitOr
217                             | BinOpKind::BitXor
218                             | BinOpKind::Div
219                             | BinOpKind::Mul
220                             | BinOpKind::Or
221                             | BinOpKind::Rem
222                             | BinOpKind::Shl
223                             | BinOpKind::Shr
224                             | BinOpKind::Sub
225                     ),
226                     |this| {
227                         this.manage_cond_expr(lhs);
228                         this.manage_cond_expr(rhs);
229                     }
230                 );
231             }
232             ExprKind::Call(_, ref mut local_exprs) => {
233                 for local_expr in local_exprs {
234                     self.manage_cond_expr(local_expr);
235                 }
236             }
237             ExprKind::Cast(ref mut local_expr, _) => {
238                 self.manage_cond_expr(local_expr);
239             }
240             ExprKind::Index(ref mut prefix, ref mut suffix) => {
241                 self.manage_cond_expr(prefix);
242                 self.manage_cond_expr(suffix);
243             }
244             ExprKind::MethodCall(_, _,ref mut local_exprs, _) => {
245                 for local_expr in local_exprs.iter_mut() {
246                     self.manage_cond_expr(local_expr);
247                 }
248             }
249             ExprKind::Path(_, Path { ref segments, .. }) if let &[ref path_segment] = &segments[..] => {
250                 let path_ident = path_segment.ident;
251                 self.manage_initial_capture(expr, path_ident);
252             }
253             ExprKind::Paren(ref mut local_expr) => {
254                 self.manage_cond_expr(local_expr);
255             }
256             ExprKind::Range(ref mut prefix, ref mut suffix, _) => {
257                 if let Some(ref mut elem) = prefix {
258                     self.manage_cond_expr(elem);
259                 }
260                 if let Some(ref mut elem) = suffix {
261                     self.manage_cond_expr(elem);
262                 }
263             }
264             ExprKind::Repeat(ref mut local_expr, ref mut elem) => {
265                 self.manage_cond_expr(local_expr);
266                 self.manage_cond_expr(&mut elem.value);
267             }
268             ExprKind::Struct(ref mut elem) => {
269                 for field in &mut elem.fields {
270                     self.manage_cond_expr(&mut field.expr);
271                 }
272                 if let StructRest::Base(ref mut local_expr) = elem.rest {
273                     self.manage_cond_expr(local_expr);
274                 }
275             }
276             ExprKind::Tup(ref mut local_exprs) => {
277                 for local_expr in local_exprs {
278                     self.manage_cond_expr(local_expr);
279                 }
280             }
281             ExprKind::Unary(un_op, ref mut local_expr) => {
282                 self.with_is_consumed_management(
283                     matches!(un_op, UnOp::Neg | UnOp::Not),
284                     |this| this.manage_cond_expr(local_expr)
285                 );
286             }
287             // Expressions that are not worth or can not be captured.
288             //
289             // Full list instead of `_` to catch possible future inclusions and to
290             // sync with the `rfc-2011-nicer-assert-messages/all-expr-kinds.rs` test.
291             ExprKind::Assign(_, _, _)
292             | ExprKind::AssignOp(_, _, _)
293             | ExprKind::Async(_, _, _)
294             | ExprKind::Await(_)
295             | ExprKind::Block(_, _)
296             | ExprKind::Box(_)
297             | ExprKind::Break(_, _)
298             | ExprKind::Closure(_, _, _, _, _, _, _)
299             | ExprKind::ConstBlock(_)
300             | ExprKind::Continue(_)
301             | ExprKind::Err
302             | ExprKind::Field(_, _)
303             | ExprKind::ForLoop(_, _, _, _)
304             | ExprKind::If(_, _, _)
305             | ExprKind::InlineAsm(_)
306             | ExprKind::Let(_, _, _)
307             | ExprKind::Lit(_)
308             | ExprKind::Loop(_, _)
309             | ExprKind::MacCall(_)
310             | ExprKind::Match(_, _)
311             | ExprKind::Path(_, _)
312             | ExprKind::Ret(_)
313             | ExprKind::Try(_)
314             | ExprKind::TryBlock(_)
315             | ExprKind::Type(_, _)
316             | ExprKind::Underscore
317             | ExprKind::While(_, _, _)
318             | ExprKind::Yeet(_)
319             | ExprKind::Yield(_) => {}
320         }
321     }
322
323     /// Pushes the top-level declarations and modifies `expr` to try capturing variables.
324     ///
325     /// `fmt_str`, the formatting string used for debugging, is constructed to show possible
326     /// captured variables.
327     fn manage_initial_capture(&mut self, expr: &mut P<Expr>, path_ident: Ident) {
328         if self.paths.contains(&path_ident) {
329             return;
330         } else {
331             self.fmt_string.push_str("  ");
332             self.fmt_string.push_str(path_ident.as_str());
333             self.fmt_string.push_str(" = {:?}\n");
334             let _ = self.paths.insert(path_ident);
335         }
336         let curr_capture_idx = self.capture_decls.len();
337         let capture_string = format!("__capture{curr_capture_idx}");
338         let ident = Ident::new(Symbol::intern(&capture_string), self.span);
339         let init_std_path = self.cx.std_path(&[sym::asserting, sym::Capture, sym::new]);
340         let init = self.cx.expr_call(
341             self.span,
342             self.cx.expr_path(self.cx.path(self.span, init_std_path)),
343             vec![],
344         );
345         let capture = Capture { decl: self.cx.stmt_let(self.span, true, ident, init), ident };
346         self.capture_decls.push(capture);
347         self.manage_try_capture(ident, curr_capture_idx, expr);
348     }
349
350     /// Tries to copy `__local_bindN` into `__captureN`.
351     ///
352     /// *{
353     ///    (&Wrapper(__local_bindN)).try_capture(&mut __captureN);
354     ///    __local_bindN
355     /// }
356     fn manage_try_capture(&mut self, capture: Ident, curr_capture_idx: usize, expr: &mut P<Expr>) {
357         let local_bind_string = format!("__local_bind{curr_capture_idx}");
358         let local_bind = Ident::new(Symbol::intern(&local_bind_string), self.span);
359         self.local_bind_decls.push(self.cx.stmt_let(
360             self.span,
361             false,
362             local_bind,
363             self.cx.expr_addr_of(self.span, expr.clone()),
364         ));
365         let wrapper = self.cx.expr_call(
366             self.span,
367             self.cx.expr_path(
368                 self.cx.path(self.span, self.cx.std_path(&[sym::asserting, sym::Wrapper])),
369             ),
370             vec![self.cx.expr_path(Path::from_ident(local_bind))],
371         );
372         let try_capture_call = self
373             .cx
374             .stmt_expr(expr_method_call(
375                 self.cx,
376                 PathSegment {
377                     args: None,
378                     id: DUMMY_NODE_ID,
379                     ident: Ident::new(sym::try_capture, self.span),
380                 },
381                 expr_paren(self.cx, self.span, self.cx.expr_addr_of(self.span, wrapper)),
382                 vec![expr_addr_of_mut(
383                     self.cx,
384                     self.span,
385                     self.cx.expr_path(Path::from_ident(capture)),
386                 )],
387                 self.span,
388             ))
389             .add_trailing_semicolon();
390         let local_bind_path = self.cx.expr_path(Path::from_ident(local_bind));
391         let rslt = if self.is_consumed {
392             let ret = self.cx.stmt_expr(local_bind_path);
393             self.cx.expr_block(self.cx.block(self.span, vec![try_capture_call, ret]))
394         } else {
395             self.best_case_captures.push(try_capture_call);
396             local_bind_path
397         };
398         *expr = self.cx.expr_deref(self.span, rslt);
399     }
400
401     // Calls `f` with the internal `is_consumed` set to `curr_is_consumed` and then
402     // sets the internal `is_consumed` back to its original value.
403     fn with_is_consumed_management(&mut self, curr_is_consumed: bool, f: impl FnOnce(&mut Self)) {
404         let prev_is_consumed = self.is_consumed;
405         self.is_consumed = curr_is_consumed;
406         f(self);
407         self.is_consumed = prev_is_consumed;
408     }
409 }
410
/// Information about a captured element.
///
/// One instance is created per distinct captured path; its position in
/// `Context::capture_decls` is the `N` used in the generated names.
#[derive(Debug)]
struct Capture {
    // Generated indexed `Capture` statement.
    //
    // `let mut __captureN = Capture::new();`
    decl: Stmt,
    // The name of the generated indexed `Capture` variable.
    //
    // `__captureN`
    ident: Ident,
}
423
/// Escapes `s` so that it can be embedded verbatim in a formatting string literal.
///
/// Every character goes through `char::escape_debug`, and `{` / `}` are additionally doubled
/// so that they are not read as formatting placeholders.
fn escape_to_fmt(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut escaped, ch| {
        escaped.extend(ch.escape_debug());
        if matches!(ch, '{' | '}') {
            escaped.push(ch);
        }
        escaped
    })
}
436
437 fn expr_addr_of_mut(cx: &ExtCtxt<'_>, sp: Span, e: P<Expr>) -> P<Expr> {
438     cx.expr(sp, ExprKind::AddrOf(BorrowKind::Ref, Mutability::Mut, e))
439 }
440
441 fn expr_method_call(
442     cx: &ExtCtxt<'_>,
443     path: PathSegment,
444     receiver: P<Expr>,
445     args: Vec<P<Expr>>,
446     span: Span,
447 ) -> P<Expr> {
448     cx.expr(span, ExprKind::MethodCall(path, receiver, args, span))
449 }
450
451 fn expr_paren(cx: &ExtCtxt<'_>, sp: Span, e: P<Expr>) -> P<Expr> {
452     cx.expr(sp, ExprKind::Paren(e))
453 }