]> git.lizzy.rs Git - rust.git/commitdiff
Fix recursive macro statement expansion
authorEdwin Cheng <edwin0cheng@gmail.com>
Thu, 25 Mar 2021 19:52:35 +0000 (03:52 +0800)
committerEdwin Cheng <edwin0cheng@gmail.com>
Thu, 25 Mar 2021 20:21:15 +0000 (04:21 +0800)
crates/hir_def/src/body/lower.rs
crates/hir_def/src/expr.rs
crates/hir_def/src/item_tree.rs
crates/hir_def/src/item_tree/lower.rs
crates/hir_expand/src/db.rs
crates/hir_ty/src/infer/expr.rs
crates/hir_ty/src/tests/macros.rs

index 19f5065d1381776ae01670b67b85c919f39af8b4..229e81dd43a6fb9c5f93b1c094ee471b9e7e3d74 100644 (file)
@@ -74,6 +74,7 @@ pub(super) fn lower(
             _c: Count::new(),
         },
         expander,
+        statements_in_scope: Vec::new(),
     }
     .collect(params, body)
 }
@@ -83,6 +84,7 @@ struct ExprCollector<'a> {
     expander: Expander,
     body: Body,
     source_map: BodySourceMap,
+    statements_in_scope: Vec<Statement>,
 }
 
 impl ExprCollector<'_> {
@@ -533,15 +535,13 @@ fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
                 ids[0]
             }
             ast::Expr::MacroStmts(e) => {
-                // FIXME:  these statements should be held by some hir containter
-                for stmt in e.statements() {
-                    self.collect_stmt(stmt);
-                }
-                if let Some(expr) = e.expr() {
-                    self.collect_expr(expr)
-                } else {
-                    self.alloc_expr(Expr::Missing, syntax_ptr)
-                }
+                e.statements().for_each(|s| self.collect_stmt(s));
+                let tail = e
+                    .expr()
+                    .map(|e| self.collect_expr(e))
+                    .unwrap_or_else(|| self.alloc_expr(Expr::Missing, syntax_ptr.clone()));
+
+                self.alloc_expr(Expr::MacroStmts { tail }, syntax_ptr)
             }
         })
     }
@@ -618,58 +618,54 @@ fn collect_expr_opt(&mut self, expr: Option<ast::Expr>) -> ExprId {
         }
     }
 
-    fn collect_stmt(&mut self, s: ast::Stmt) -> Option<Vec<Statement>> {
-        let stmt = match s {
+    fn collect_stmt(&mut self, s: ast::Stmt) {
+        match s {
             ast::Stmt::LetStmt(stmt) => {
-                self.check_cfg(&stmt)?;
-
+                if self.check_cfg(&stmt).is_none() {
+                    return;
+                }
                 let pat = self.collect_pat_opt(stmt.pat());
                 let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it));
                 let initializer = stmt.initializer().map(|e| self.collect_expr(e));
-                vec![Statement::Let { pat, type_ref, initializer }]
+                self.statements_in_scope.push(Statement::Let { pat, type_ref, initializer });
             }
             ast::Stmt::ExprStmt(stmt) => {
-                self.check_cfg(&stmt)?;
+                if self.check_cfg(&stmt).is_none() {
+                    return;
+                }
 
                 // Note that macro could be expended to multiple statements
                 if let Some(ast::Expr::MacroCall(m)) = stmt.expr() {
                     let syntax_ptr = AstPtr::new(&stmt.expr().unwrap());
-                    let mut stmts = vec![];
 
                     self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| {
                         match expansion {
                             Some(expansion) => {
                                 let statements: ast::MacroStmts = expansion;
 
-                                statements.statements().for_each(|stmt| {
-                                    if let Some(mut r) = this.collect_stmt(stmt) {
-                                        stmts.append(&mut r);
-                                    }
-                                });
+                                statements.statements().for_each(|stmt| this.collect_stmt(stmt));
                                 if let Some(expr) = statements.expr() {
-                                    stmts.push(Statement::Expr(this.collect_expr(expr)));
+                                    let expr = this.collect_expr(expr);
+                                    this.statements_in_scope.push(Statement::Expr(expr));
                                 }
                             }
                             None => {
-                                stmts.push(Statement::Expr(
-                                    this.alloc_expr(Expr::Missing, syntax_ptr.clone()),
-                                ));
+                                let expr = this.alloc_expr(Expr::Missing, syntax_ptr.clone());
+                                this.statements_in_scope.push(Statement::Expr(expr));
                             }
                         }
                     });
-                    stmts
                 } else {
-                    vec![Statement::Expr(self.collect_expr_opt(stmt.expr()))]
+                    let expr = self.collect_expr_opt(stmt.expr());
+                    self.statements_in_scope.push(Statement::Expr(expr));
                 }
             }
             ast::Stmt::Item(item) => {
-                self.check_cfg(&item)?;
-
-                return None;
+                if self.check_cfg(&item).is_none() {
+                    return;
+                }
             }
-        };
-
-        Some(stmt)
+        }
     }
 
     fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId {
@@ -685,10 +681,12 @@ fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId {
         let module = if has_def_map { def_map.root() } else { self.expander.module };
         let prev_def_map = mem::replace(&mut self.expander.def_map, def_map);
         let prev_local_module = mem::replace(&mut self.expander.module, module);
+        let prev_statements = std::mem::take(&mut self.statements_in_scope);
+
+        block.statements().for_each(|s| self.collect_stmt(s));
 
-        let statements =
-            block.statements().filter_map(|s| self.collect_stmt(s)).flatten().collect();
         let tail = block.tail_expr().map(|e| self.collect_expr(e));
+        let statements = std::mem::replace(&mut self.statements_in_scope, prev_statements);
         let syntax_node_ptr = AstPtr::new(&block.into());
         let expr_id = self.alloc_expr(
             Expr::Block { id: block_id, statements, tail, label: None },
index 24be9377395ec21e17bd7521218c5133896bfb31..6c7376fad3abafc5c87e8fae81d946f3eba5c1e7 100644 (file)
@@ -171,6 +171,9 @@ pub enum Expr {
     Unsafe {
         body: ExprId,
     },
+    MacroStmts {
+        tail: ExprId,
+    },
     Array(Array),
     Literal(Literal),
 }
@@ -357,6 +360,7 @@ pub fn walk_child_exprs(&self, mut f: impl FnMut(ExprId)) {
                     f(*repeat)
                 }
             },
+            Expr::MacroStmts { tail } => f(*tail),
             Expr::Literal(_) => {}
         }
     }
index ae2475b4e46abfce0fbf79449c61c3ebf07c9a9e..ca0048b1637f30ff4a44b1d75bfb473faa6febd8 100644 (file)
@@ -110,15 +110,6 @@ pub(crate) fn file_item_tree_query(db: &dyn DefDatabase, file_id: HirFileId) ->
                     // still need to collect inner items.
                     ctx.lower_inner_items(e.syntax())
                 },
-                ast::ExprStmt(stmt) => {
-                    // Macros can expand to stmt. We return an empty item tree in this case, but
-                    // still need to collect inner items.
-                    ctx.lower_inner_items(stmt.syntax())
-                },
-                ast::Item(item) => {
-                    // Macros can expand to stmt and other item, and we add it as top level item
-                    ctx.lower_single_item(item)
-                },
                 _ => {
                     panic!("cannot create item tree from {:?} {}", syntax, syntax);
                 },
index d3fe1ce1eb2f222bda385c9488669dca9caeb933..3f558edd81622e33d2feb6e8ada2b99dc5254740 100644 (file)
@@ -87,14 +87,6 @@ pub(super) fn lower_macro_stmts(mut self, stmts: ast::MacroStmts) -> ItemTree {
         self.tree
     }
 
-    pub(super) fn lower_single_item(mut self, item: ast::Item) -> ItemTree {
-        self.tree.top_level = self
-            .lower_mod_item(&item, false)
-            .map(|item| item.0)
-            .unwrap_or_else(|| Default::default());
-        self.tree
-    }
-
     pub(super) fn lower_inner_items(mut self, within: &SyntaxNode) -> ItemTree {
         self.collect_inner_items(within);
         self.tree
index fc73e435bfbfdcbcf8d00ca927e0b1c057c53f11..d672f67238a2a0bb985988957145721227909cba 100644 (file)
@@ -5,7 +5,13 @@
 use base_db::{salsa, SourceDatabase};
 use mbe::{ExpandError, ExpandResult, MacroRules};
 use parser::FragmentKind;
-use syntax::{algo::diff, ast::NameOwner, AstNode, GreenNode, Parse, SyntaxKind::*, SyntaxNode};
+use syntax::{
+    algo::diff,
+    ast::{MacroStmts, NameOwner},
+    AstNode, GreenNode, Parse,
+    SyntaxKind::*,
+    SyntaxNode,
+};
 
 use crate::{
     ast_id_map::AstIdMap, hygiene::HygieneFrame, BuiltinDeriveExpander, BuiltinFnLikeExpander,
@@ -340,13 +346,19 @@ fn parse_macro_with_arg(
         None => return ExpandResult { value: None, err: result.err },
     };
 
-    log::debug!("expanded = {}", tt.as_debug_string());
-
     let fragment_kind = to_fragment_kind(db, macro_call_id);
 
+    log::debug!("expanded = {}", tt.as_debug_string());
+    log::debug!("kind = {:?}", fragment_kind);
+
     let (parse, rev_token_map) = match mbe::token_tree_to_syntax_node(&tt, fragment_kind) {
         Ok(it) => it,
         Err(err) => {
+            log::debug!(
+                "failed to parse expanstion to {:?} = {}",
+                fragment_kind,
+                tt.as_debug_string()
+            );
             return ExpandResult::only_err(err);
         }
     };
@@ -362,15 +374,34 @@ fn parse_macro_with_arg(
                     return ExpandResult::only_err(err);
                 }
             };
-
-            if !diff(&node, &call_node.value).is_empty() {
-                ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: Some(err) }
-            } else {
+            if is_self_replicating(&node, &call_node.value) {
                 return ExpandResult::only_err(err);
+            } else {
+                ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: Some(err) }
+            }
+        }
+        None => {
+            log::debug!("parse = {:?}", parse.syntax_node().kind());
+            ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: None }
+        }
+    }
+}
+
+fn is_self_replicating(from: &SyntaxNode, to: &SyntaxNode) -> bool {
+    if diff(from, to).is_empty() {
+        return true;
+    }
+    if let Some(stmts) = MacroStmts::cast(from.clone()) {
+        if stmts.statements().any(|stmt| diff(stmt.syntax(), to).is_empty()) {
+            return true;
+        }
+        if let Some(expr) = stmts.expr() {
+            if diff(expr.syntax(), to).is_empty() {
+                return true;
             }
         }
-        None => ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: None },
     }
+    false
 }
 
 fn hygiene_frame(db: &dyn AstDatabase, file_id: HirFileId) -> Arc<HygieneFrame> {
@@ -390,21 +421,15 @@ fn to_fragment_kind(db: &dyn AstDatabase, id: MacroCallId) -> FragmentKind {
 
     let parent = match syn.parent() {
         Some(it) => it,
-        None => {
-            // FIXME:
-            // If it is root, which means the parent HirFile
-            // MacroKindFile must be non-items
-            // return expr now.
-            return FragmentKind::Expr;
-        }
+        None => return FragmentKind::Statements,
     };
 
     match parent.kind() {
         MACRO_ITEMS | SOURCE_FILE => FragmentKind::Items,
-        MACRO_STMTS => FragmentKind::Statement,
+        MACRO_STMTS => FragmentKind::Statements,
         ITEM_LIST => FragmentKind::Items,
         LET_STMT => {
-            // FIXME: Handle Pattern
+            // FIXME: Handle LHS Pattern
             FragmentKind::Expr
         }
         EXPR_STMT => FragmentKind::Statements,
index 3f3187ea275d4fd3c2c90bbb89563bfbcc95d976..e6ede05ca144289878dab272188da978f635a74a 100644 (file)
@@ -767,6 +767,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                     None => self.table.new_float_var(),
                 },
             },
+            Expr::MacroStmts { tail } => self.infer_expr(*tail, expected),
         };
         // use a new type variable if we got unknown here
         let ty = self.insert_type_vars_shallow(ty);
index 7eda518663aedbad797237d0b200c5c7785dce2b..01935ec99431542826b1edd8f7e308cff53a89d7 100644 (file)
@@ -226,11 +226,48 @@ fn foo() {
         "#,
         expect![[r#"
             !0..8 'leta=();': ()
+            !0..8 'leta=();': ()
+            !3..4 'a': ()
+            !5..7 '()': ()
             57..84 '{     ...); } }': ()
         "#]],
     );
 }
 
+#[test]
+fn recurisve_macro_expanded_in_stmts() {
+    check_infer(
+        r#"
+        macro_rules! ng {
+            ([$($tts:tt)*]) => {
+                $($tts)*;
+            };
+            ([$($tts:tt)*] $head:tt $($rest:tt)*) => {
+                ng! {
+                    [$($tts)* $head] $($rest)*
+                }
+            };
+        }
+        fn foo() {
+            ng!([] let a = 3);
+            let b = a;
+        }
+        "#,
+        expect![[r#"
+            !0..7 'leta=3;': {unknown}
+            !0..7 'leta=3;': {unknown}
+            !0..13 'ng!{[leta=3]}': {unknown}
+            !0..13 'ng!{[leta=]3}': {unknown}
+            !0..13 'ng!{[leta]=3}': {unknown}
+            !3..4 'a': i32
+            !5..6 '3': i32
+            196..237 '{     ...= a; }': ()
+            229..230 'b': i32
+            233..234 'a': i32
+        "#]],
+    );
+}
+
 #[test]
 fn recursive_inner_item_macro_rules() {
     check_infer(
@@ -246,7 +283,8 @@ fn foo() {
         "#,
         expect![[r#"
             !0..1 '1': i32
-            !0..7 'mac!($)': {unknown}
+            !0..26 'macro_...>{1};}': {unknown}
+            !0..26 'macro_...>{1};}': {unknown}
             107..143 '{     ...!(); }': ()
             129..130 'a': i32
         "#]],