]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_builtin_macros/src/assert/context.rs
Rollup merge of #107242 - notriddle:notriddle/title-ordering, r=GuillaumeGomez
[rust.git] / compiler / rustc_builtin_macros / src / assert / context.rs
1 use rustc_ast::{
2     ptr::P,
3     token,
4     tokenstream::{DelimSpan, TokenStream, TokenTree},
5     BinOpKind, BorrowKind, DelimArgs, Expr, ExprKind, ItemKind, MacCall, MacDelimiter, MethodCall,
6     Mutability, Path, PathSegment, Stmt, StructRest, UnOp, UseTree, UseTreeKind, DUMMY_NODE_ID,
7 };
8 use rustc_ast_pretty::pprust;
9 use rustc_data_structures::fx::FxHashSet;
10 use rustc_expand::base::ExtCtxt;
11 use rustc_span::{
12     symbol::{sym, Ident, Symbol},
13     Span,
14 };
15 use thin_vec::thin_vec;
16
/// State accumulated while an `assert!` condition is rewritten so that the values of the
/// paths it mentions can be captured and printed when the assertion fails.
pub(super) struct Context<'cx, 'a> {
    // An optimization.
    //
    // Elements that aren't consumed (PartialEq, PartialOrd, ...) can be copied **after** the
    // `assert!` expression fails rather than copied on-the-fly. The statements stored here
    // are emitted inside the failure branch, right before the panic.
    best_case_captures: Vec<Stmt>,
    // Top-level `let mut __captureN = Capture::new()` statements, together with the
    // identifier each one declares.
    capture_decls: Vec<Capture>,
    // Expansion context used to build every generated AST node.
    cx: &'cx ExtCtxt<'a>,
    // Formatting string used for debugging: one `  name = {:?}\n` entry per captured path.
    fmt_string: String,
    // If the current expression being visited consumes itself. Used to construct
    // `best_case_captures`.
    is_consumed: bool,
    // Top-level `let __local_bindN = &expr` statements
    local_bind_decls: Vec<Stmt>,
    // Used to avoid capturing duplicated paths
    //
    // ```rust
    // let a = 1i32;
    // assert!(add(a, a) == 3);
    // ```
    paths: FxHashSet<Ident>,
    // Span given to every generated node.
    span: Span,
}
42
43 impl<'cx, 'a> Context<'cx, 'a> {
44     pub(super) fn new(cx: &'cx ExtCtxt<'a>, span: Span) -> Self {
45         Self {
46             best_case_captures: <_>::default(),
47             capture_decls: <_>::default(),
48             cx,
49             fmt_string: <_>::default(),
50             is_consumed: true,
51             local_bind_decls: <_>::default(),
52             paths: <_>::default(),
53             span,
54         }
55     }
56
57     /// Builds the whole `assert!` expression. For example, `let elem = 1; assert!(elem == 1);` expands to:
58     ///
59     /// ```rust
60     /// #![feature(generic_assert_internals)]
61     /// let elem = 1;
62     /// {
63     ///   #[allow(unused_imports)]
64     ///   use ::core::asserting::{TryCaptureGeneric, TryCapturePrintable};
65     ///   let mut __capture0 = ::core::asserting::Capture::new();
66     ///   let __local_bind0 = &elem;
67     ///   if !(
68     ///     *{
69     ///       (&::core::asserting::Wrapper(__local_bind0)).try_capture(&mut __capture0);
70     ///       __local_bind0
71     ///     } == 1
72     ///   ) {
73     ///     panic!("Assertion failed: elem == 1\nWith captures:\n  elem = {:?}", __capture0)
74     ///   }
75     /// }
76     /// ```
77     pub(super) fn build(mut self, mut cond_expr: P<Expr>, panic_path: Path) -> P<Expr> {
78         let expr_str = pprust::expr_to_string(&cond_expr);
79         self.manage_cond_expr(&mut cond_expr);
80         let initial_imports = self.build_initial_imports();
81         let panic = self.build_panic(&expr_str, panic_path);
82         let cond_expr_with_unlikely = self.build_unlikely(cond_expr);
83
84         let Self { best_case_captures, capture_decls, cx, local_bind_decls, span, .. } = self;
85
86         let mut assert_then_stmts = Vec::with_capacity(2);
87         assert_then_stmts.extend(best_case_captures);
88         assert_then_stmts.push(self.cx.stmt_expr(panic));
89         let assert_then = self.cx.block(span, assert_then_stmts);
90
91         let mut stmts = Vec::with_capacity(4);
92         stmts.push(initial_imports);
93         stmts.extend(capture_decls.into_iter().map(|c| c.decl));
94         stmts.extend(local_bind_decls);
95         stmts.push(
96             cx.stmt_expr(cx.expr(span, ExprKind::If(cond_expr_with_unlikely, assert_then, None))),
97         );
98         cx.expr_block(cx.block(span, stmts))
99     }
100
101     /// Initial **trait** imports
102     ///
103     /// use ::core::asserting::{ ... };
104     fn build_initial_imports(&self) -> Stmt {
105         let nested_tree = |this: &Self, sym| {
106             (
107                 UseTree {
108                     prefix: this.cx.path(this.span, vec![Ident::with_dummy_span(sym)]),
109                     kind: UseTreeKind::Simple(None),
110                     span: this.span,
111                 },
112                 DUMMY_NODE_ID,
113             )
114         };
115         self.cx.stmt_item(
116             self.span,
117             self.cx.item(
118                 self.span,
119                 Ident::empty(),
120                 thin_vec![self.cx.attr_nested_word(sym::allow, sym::unused_imports, self.span)],
121                 ItemKind::Use(UseTree {
122                     prefix: self.cx.path(self.span, self.cx.std_path(&[sym::asserting])),
123                     kind: UseTreeKind::Nested(vec![
124                         nested_tree(self, sym::TryCaptureGeneric),
125                         nested_tree(self, sym::TryCapturePrintable),
126                     ]),
127                     span: self.span,
128                 }),
129             ),
130         )
131     }
132
133     /// Takes the conditional expression of `assert!` and then wraps it inside `unlikely`
134     fn build_unlikely(&self, cond_expr: P<Expr>) -> P<Expr> {
135         let unlikely_path = self.cx.std_path(&[sym::intrinsics, sym::unlikely]);
136         self.cx.expr_call(
137             self.span,
138             self.cx.expr_path(self.cx.path(self.span, unlikely_path)),
139             vec![self.cx.expr(self.span, ExprKind::Unary(UnOp::Not, cond_expr))],
140         )
141     }
142
143     /// The necessary custom `panic!(...)` expression.
144     ///
145     /// panic!(
146     ///     "Assertion failed: ... \n With expansion: ...",
147     ///     __capture0,
148     ///     ...
149     /// );
150     fn build_panic(&self, expr_str: &str, panic_path: Path) -> P<Expr> {
151         let escaped_expr_str = escape_to_fmt(expr_str);
152         let initial = [
153             TokenTree::token_alone(
154                 token::Literal(token::Lit {
155                     kind: token::LitKind::Str,
156                     symbol: Symbol::intern(&if self.fmt_string.is_empty() {
157                         format!("Assertion failed: {escaped_expr_str}")
158                     } else {
159                         format!(
160                             "Assertion failed: {escaped_expr_str}\nWith captures:\n{}",
161                             &self.fmt_string
162                         )
163                     }),
164                     suffix: None,
165                 }),
166                 self.span,
167             ),
168             TokenTree::token_alone(token::Comma, self.span),
169         ];
170         let captures = self.capture_decls.iter().flat_map(|cap| {
171             [
172                 TokenTree::token_alone(token::Ident(cap.ident.name, false), cap.ident.span),
173                 TokenTree::token_alone(token::Comma, self.span),
174             ]
175         });
176         self.cx.expr(
177             self.span,
178             ExprKind::MacCall(P(MacCall {
179                 path: panic_path,
180                 args: P(DelimArgs {
181                     dspan: DelimSpan::from_single(self.span),
182                     delim: MacDelimiter::Parenthesis,
183                     tokens: initial.into_iter().chain(captures).collect::<TokenStream>(),
184                 }),
185                 prior_type_ascription: None,
186             })),
187         )
188     }
189
190     /// Recursive function called until `cond_expr` and `fmt_str` are fully modified.
191     ///
192     /// See [Self::manage_initial_capture] and [Self::manage_try_capture]
193     fn manage_cond_expr(&mut self, expr: &mut P<Expr>) {
194         match &mut expr.kind {
195             ExprKind::AddrOf(_, mutability, local_expr) => {
196                 self.with_is_consumed_management(
197                     matches!(mutability, Mutability::Mut),
198                     |this| this.manage_cond_expr(local_expr)
199                 );
200             }
201             ExprKind::Array(local_exprs) => {
202                 for local_expr in local_exprs {
203                     self.manage_cond_expr(local_expr);
204                 }
205             }
206             ExprKind::Binary(op, lhs, rhs) => {
207                 self.with_is_consumed_management(
208                     matches!(
209                         op.node,
210                         BinOpKind::Add
211                             | BinOpKind::And
212                             | BinOpKind::BitAnd
213                             | BinOpKind::BitOr
214                             | BinOpKind::BitXor
215                             | BinOpKind::Div
216                             | BinOpKind::Mul
217                             | BinOpKind::Or
218                             | BinOpKind::Rem
219                             | BinOpKind::Shl
220                             | BinOpKind::Shr
221                             | BinOpKind::Sub
222                     ),
223                     |this| {
224                         this.manage_cond_expr(lhs);
225                         this.manage_cond_expr(rhs);
226                     }
227                 );
228             }
229             ExprKind::Call(_, local_exprs) => {
230                 for local_expr in local_exprs {
231                     self.manage_cond_expr(local_expr);
232                 }
233             }
234             ExprKind::Cast(local_expr, _) => {
235                 self.manage_cond_expr(local_expr);
236             }
237             ExprKind::Index(prefix, suffix) => {
238                 self.manage_cond_expr(prefix);
239                 self.manage_cond_expr(suffix);
240             }
241             ExprKind::MethodCall(call) => {
242                 for arg in &mut call.args {
243                     self.manage_cond_expr(arg);
244                 }
245             }
246             ExprKind::Path(_, Path { segments, .. }) if let [path_segment] = &segments[..] => {
247                 let path_ident = path_segment.ident;
248                 self.manage_initial_capture(expr, path_ident);
249             }
250             ExprKind::Paren(local_expr) => {
251                 self.manage_cond_expr(local_expr);
252             }
253             ExprKind::Range(prefix, suffix, _) => {
254                 if let Some(elem) = prefix {
255                     self.manage_cond_expr(elem);
256                 }
257                 if let Some(elem) = suffix {
258                     self.manage_cond_expr(elem);
259                 }
260             }
261             ExprKind::Repeat(local_expr, elem) => {
262                 self.manage_cond_expr(local_expr);
263                 self.manage_cond_expr(&mut elem.value);
264             }
265             ExprKind::Struct(elem) => {
266                 for field in &mut elem.fields {
267                     self.manage_cond_expr(&mut field.expr);
268                 }
269                 if let StructRest::Base(local_expr) = &mut elem.rest {
270                     self.manage_cond_expr(local_expr);
271                 }
272             }
273             ExprKind::Tup(local_exprs) => {
274                 for local_expr in local_exprs {
275                     self.manage_cond_expr(local_expr);
276                 }
277             }
278             ExprKind::Unary(un_op, local_expr) => {
279                 self.with_is_consumed_management(
280                     matches!(un_op, UnOp::Neg | UnOp::Not),
281                     |this| this.manage_cond_expr(local_expr)
282                 );
283             }
284             // Expressions that are not worth or can not be captured.
285             //
286             // Full list instead of `_` to catch possible future inclusions and to
287             // sync with the `rfc-2011-nicer-assert-messages/all-expr-kinds.rs` test.
288             ExprKind::Assign(_, _, _)
289             | ExprKind::AssignOp(_, _, _)
290             | ExprKind::Async(_, _, _)
291             | ExprKind::Await(_)
292             | ExprKind::Block(_, _)
293             | ExprKind::Box(_)
294             | ExprKind::Break(_, _)
295             | ExprKind::Closure(_)
296             | ExprKind::ConstBlock(_)
297             | ExprKind::Continue(_)
298             | ExprKind::Err
299             | ExprKind::Field(_, _)
300             | ExprKind::FormatArgs(_)
301             | ExprKind::ForLoop(_, _, _, _)
302             | ExprKind::If(_, _, _)
303             | ExprKind::IncludedBytes(..)
304             | ExprKind::InlineAsm(_)
305             | ExprKind::Let(_, _, _)
306             | ExprKind::Lit(_)
307             | ExprKind::Loop(_, _, _)
308             | ExprKind::MacCall(_)
309             | ExprKind::Match(_, _)
310             | ExprKind::Path(_, _)
311             | ExprKind::Ret(_)
312             | ExprKind::Try(_)
313             | ExprKind::TryBlock(_)
314             | ExprKind::Type(_, _)
315             | ExprKind::Underscore
316             | ExprKind::While(_, _, _)
317             | ExprKind::Yeet(_)
318             | ExprKind::Yield(_) => {}
319         }
320     }
321
322     /// Pushes the top-level declarations and modifies `expr` to try capturing variables.
323     ///
324     /// `fmt_str`, the formatting string used for debugging, is constructed to show possible
325     /// captured variables.
326     fn manage_initial_capture(&mut self, expr: &mut P<Expr>, path_ident: Ident) {
327         if self.paths.contains(&path_ident) {
328             return;
329         } else {
330             self.fmt_string.push_str("  ");
331             self.fmt_string.push_str(path_ident.as_str());
332             self.fmt_string.push_str(" = {:?}\n");
333             let _ = self.paths.insert(path_ident);
334         }
335         let curr_capture_idx = self.capture_decls.len();
336         let capture_string = format!("__capture{curr_capture_idx}");
337         let ident = Ident::new(Symbol::intern(&capture_string), self.span);
338         let init_std_path = self.cx.std_path(&[sym::asserting, sym::Capture, sym::new]);
339         let init = self.cx.expr_call(
340             self.span,
341             self.cx.expr_path(self.cx.path(self.span, init_std_path)),
342             vec![],
343         );
344         let capture = Capture { decl: self.cx.stmt_let(self.span, true, ident, init), ident };
345         self.capture_decls.push(capture);
346         self.manage_try_capture(ident, curr_capture_idx, expr);
347     }
348
349     /// Tries to copy `__local_bindN` into `__captureN`.
350     ///
351     /// *{
352     ///    (&Wrapper(__local_bindN)).try_capture(&mut __captureN);
353     ///    __local_bindN
354     /// }
355     fn manage_try_capture(&mut self, capture: Ident, curr_capture_idx: usize, expr: &mut P<Expr>) {
356         let local_bind_string = format!("__local_bind{curr_capture_idx}");
357         let local_bind = Ident::new(Symbol::intern(&local_bind_string), self.span);
358         self.local_bind_decls.push(self.cx.stmt_let(
359             self.span,
360             false,
361             local_bind,
362             self.cx.expr_addr_of(self.span, expr.clone()),
363         ));
364         let wrapper = self.cx.expr_call(
365             self.span,
366             self.cx.expr_path(
367                 self.cx.path(self.span, self.cx.std_path(&[sym::asserting, sym::Wrapper])),
368             ),
369             vec![self.cx.expr_path(Path::from_ident(local_bind))],
370         );
371         let try_capture_call = self
372             .cx
373             .stmt_expr(expr_method_call(
374                 self.cx,
375                 PathSegment {
376                     args: None,
377                     id: DUMMY_NODE_ID,
378                     ident: Ident::new(sym::try_capture, self.span),
379                 },
380                 expr_paren(self.cx, self.span, self.cx.expr_addr_of(self.span, wrapper)),
381                 vec![expr_addr_of_mut(
382                     self.cx,
383                     self.span,
384                     self.cx.expr_path(Path::from_ident(capture)),
385                 )],
386                 self.span,
387             ))
388             .add_trailing_semicolon();
389         let local_bind_path = self.cx.expr_path(Path::from_ident(local_bind));
390         let rslt = if self.is_consumed {
391             let ret = self.cx.stmt_expr(local_bind_path);
392             self.cx.expr_block(self.cx.block(self.span, vec![try_capture_call, ret]))
393         } else {
394             self.best_case_captures.push(try_capture_call);
395             local_bind_path
396         };
397         *expr = self.cx.expr_deref(self.span, rslt);
398     }
399
400     // Calls `f` with the internal `is_consumed` set to `curr_is_consumed` and then
401     // sets the internal `is_consumed` back to its original value.
402     fn with_is_consumed_management(&mut self, curr_is_consumed: bool, f: impl FnOnce(&mut Self)) {
403         let prev_is_consumed = self.is_consumed;
404         self.is_consumed = curr_is_consumed;
405         f(self);
406         self.is_consumed = prev_is_consumed;
407     }
408 }
409
/// Information about a captured element.
///
/// Pairs the generated declaration statement with the identifier it declares so that the
/// identifier can later be passed as an argument to the panic macro.
#[derive(Debug)]
struct Capture {
    // Generated indexed `Capture` statement.
    //
    // `let mut __capture{} = Capture::new();`
    decl: Stmt,
    // The name of the generated indexed `Capture` variable.
    //
    // `__capture{}`
    ident: Ident,
}
422
/// Turns `s` into text that can be embedded in a format-string literal.
///
/// Every character goes through `char::escape_debug`, and each brace is doubled so that
/// the formatting machinery treats it as a literal brace rather than a placeholder.
fn escape_to_fmt(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut escaped, ch| {
        escaped.extend(ch.escape_debug());
        // `{` and `}` are written twice.
        if matches!(ch, '{' | '}') {
            escaped.push(ch);
        }
        escaped
    })
}
435
436 fn expr_addr_of_mut(cx: &ExtCtxt<'_>, sp: Span, e: P<Expr>) -> P<Expr> {
437     cx.expr(sp, ExprKind::AddrOf(BorrowKind::Ref, Mutability::Mut, e))
438 }
439
440 fn expr_method_call(
441     cx: &ExtCtxt<'_>,
442     seg: PathSegment,
443     receiver: P<Expr>,
444     args: Vec<P<Expr>>,
445     span: Span,
446 ) -> P<Expr> {
447     cx.expr(span, ExprKind::MethodCall(Box::new(MethodCall { seg, receiver, args, span })))
448 }
449
450 fn expr_paren(cx: &ExtCtxt<'_>, sp: Span, e: P<Expr>) -> P<Expr> {
451     cx.expr(sp, ExprKind::Paren(e))
452 }