]> git.lizzy.rs Git - rust.git/blobdiff - crates/hir_def/src/body/lower.rs
Merge #8398
[rust.git] / crates / hir_def / src / body / lower.rs
index 229e81dd43a6fb9c5f93b1c094ee471b9e7e3d74..9f278d35b73a83ffdd9b8e102e7a7c0d0d4c6369 100644 (file)
@@ -1,10 +1,11 @@
 //! Transforms `ast::Expr` into an equivalent `hir_def::expr::Expr`
 //! representation.
 
-use std::mem;
+use std::{mem, sync::Arc};
 
 use either::Either;
 use hir_expand::{
+    ast_id_map::{AstIdMap, FileAstId},
     hygiene::Hygiene,
     name::{name, AsName, Name},
     ExpandError, HirFileId,
@@ -30,6 +31,7 @@
         LabelId, Literal, LogicOp, MatchArm, Ordering, Pat, PatId, RecordFieldPat, RecordLitField,
         Statement,
     },
+    intern::Interned,
     item_scope::BuiltinShadowMode,
     path::{GenericArgs, Path},
     type_ref::{Mutability, Rawness, TypeRef},
 
 use super::{diagnostics::BodyDiagnostic, ExprSource, PatSource};
 
-pub(crate) struct LowerCtx {
+pub struct LowerCtx<'a> {
+    pub db: &'a dyn DefDatabase,
     hygiene: Hygiene,
+    file_id: Option<HirFileId>,
+    source_ast_id_map: Option<Arc<AstIdMap>>,
 }
 
-impl LowerCtx {
-    pub(crate) fn new(db: &dyn DefDatabase, file_id: HirFileId) -> Self {
-        LowerCtx { hygiene: Hygiene::new(db.upcast(), file_id) }
+impl<'a> LowerCtx<'a> {
+    pub fn new(db: &'a dyn DefDatabase, file_id: HirFileId) -> Self {
+        LowerCtx {
+            db,
+            hygiene: Hygiene::new(db.upcast(), file_id),
+            file_id: Some(file_id),
+            source_ast_id_map: Some(db.ast_id_map(file_id)),
+        }
+    }
+
+    pub fn with_hygiene(db: &'a dyn DefDatabase, hygiene: &Hygiene) -> Self {
+        LowerCtx { db, hygiene: hygiene.clone(), file_id: None, source_ast_id_map: None }
     }
-    pub(crate) fn with_hygiene(hygiene: &Hygiene) -> Self {
-        LowerCtx { hygiene: hygiene.clone() }
+
+    pub(crate) fn hygiene(&self) -> &Hygiene {
+        &self.hygiene
+    }
+
+    pub(crate) fn file_id(&self) -> HirFileId {
+        self.file_id.unwrap()
     }
 
     pub(crate) fn lower_path(&self, ast: ast::Path) -> Option<Path> {
-        Path::from_src(ast, &self.hygiene)
+        Path::from_src(ast, self)
+    }
+
+    pub(crate) fn ast_id<N: AstNode>(&self, item: &N) -> Option<FileAstId<N>> {
+        self.source_ast_id_map.as_ref().map(|ast_id_map| ast_id_map.ast_id(item))
     }
 }
 
@@ -124,7 +147,7 @@ fn collect(
         (self.body, self.source_map)
     }
 
-    fn ctx(&self) -> LowerCtx {
+    fn ctx(&self) -> LowerCtx<'_> {
         LowerCtx::new(self.db, self.expander.current_file_id)
     }
 
@@ -182,7 +205,7 @@ fn collect_expr(&mut self, expr: ast::Expr) -> ExprId {
         self.maybe_collect_expr(expr).unwrap_or_else(|| self.missing_expr())
     }
 
-    /// Returns `None` if the expression is `#[cfg]`d out.
+    /// Returns `None` if and only if the expression is `#[cfg]`d out.
     fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
         let syntax_ptr = AstPtr::new(&expr);
         self.check_cfg(&expr)?;
@@ -322,8 +345,10 @@ fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
                     Vec::new()
                 };
                 let method_name = e.name_ref().map(|nr| nr.as_name()).unwrap_or_else(Name::missing);
-                let generic_args =
-                    e.generic_arg_list().and_then(|it| GenericArgs::from_ast(&self.ctx(), it));
+                let generic_args = e
+                    .generic_arg_list()
+                    .and_then(|it| GenericArgs::from_ast(&self.ctx(), it))
+                    .map(Box::new);
                 self.alloc_expr(
                     Expr::MethodCall { receiver, method_name, args, generic_args },
                     syntax_ptr,
@@ -353,7 +378,7 @@ fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
             ast::Expr::PathExpr(e) => {
                 let path = e
                     .path()
-                    .and_then(|path| self.expander.parse_path(path))
+                    .and_then(|path| self.expander.parse_path(self.db, path))
                     .map(Expr::Path)
                     .unwrap_or(Expr::Missing);
                 self.alloc_expr(path, syntax_ptr)
@@ -385,7 +410,8 @@ fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
                 self.alloc_expr(Expr::Yield { expr }, syntax_ptr)
             }
             ast::Expr::RecordExpr(e) => {
-                let path = e.path().and_then(|path| self.expander.parse_path(path));
+                let path =
+                    e.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new);
                 let record_lit = if let Some(nfl) = e.record_expr_field_list() {
                     let fields = nfl
                         .fields()
@@ -430,7 +456,7 @@ fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
             }
             ast::Expr::CastExpr(e) => {
                 let expr = self.collect_expr_opt(e.expr());
-                let type_ref = TypeRef::from_ast_opt(&self.ctx(), e.ty());
+                let type_ref = Interned::new(TypeRef::from_ast_opt(&self.ctx(), e.ty()));
                 self.alloc_expr(Expr::Cast { expr, type_ref }, syntax_ptr)
             }
             ast::Expr::RefExpr(e) => {
@@ -464,13 +490,16 @@ fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
                 if let Some(pl) = e.param_list() {
                     for param in pl.params() {
                         let pat = self.collect_pat_opt(param.pat());
-                        let type_ref = param.ty().map(|it| TypeRef::from_ast(&self.ctx(), it));
+                        let type_ref =
+                            param.ty().map(|it| Interned::new(TypeRef::from_ast(&self.ctx(), it)));
                         args.push(pat);
                         arg_types.push(type_ref);
                     }
                 }
-                let ret_type =
-                    e.ret_type().and_then(|r| r.ty()).map(|it| TypeRef::from_ast(&self.ctx(), it));
+                let ret_type = e
+                    .ret_type()
+                    .and_then(|r| r.ty())
+                    .map(|it| Interned::new(TypeRef::from_ast(&self.ctx(), it)));
                 let body = self.collect_expr_opt(e.body());
                 self.alloc_expr(Expr::Lambda { args, arg_types, ret_type, body }, syntax_ptr)
             }
@@ -525,8 +554,9 @@ fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
                 }
             }
             ast::Expr::MacroCall(e) => {
+                let macro_ptr = AstPtr::new(&e);
                 let mut ids = vec![];
-                self.collect_macro_call(e, syntax_ptr.clone(), true, |this, expansion| {
+                self.collect_macro_call(e, macro_ptr, true, |this, expansion| {
                     ids.push(match expansion {
                         Some(it) => this.collect_expr(it),
                         None => this.alloc_expr(Expr::Missing, syntax_ptr.clone()),
@@ -549,7 +579,7 @@ fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
     fn collect_macro_call<F: FnMut(&mut Self, Option<T>), T: ast::AstNode>(
         &mut self,
         e: ast::MacroCall,
-        syntax_ptr: AstPtr<ast::Expr>,
+        syntax_ptr: AstPtr<ast::MacroCall>,
         is_error_recoverable: bool,
         mut collector: F,
     ) {
@@ -561,9 +591,13 @@ fn collect_macro_call<F: FnMut(&mut Self, Option<T>), T: ast::AstNode>(
 
         let res = match res {
             Ok(res) => res,
-            Err(UnresolvedMacro) => {
+            Err(UnresolvedMacro { path }) => {
                 self.source_map.diagnostics.push(BodyDiagnostic::UnresolvedMacroCall(
-                    UnresolvedMacroCall { file: outer_file, node: syntax_ptr.cast().unwrap() },
+                    UnresolvedMacroCall {
+                        file: outer_file,
+                        node: syntax_ptr.cast().unwrap(),
+                        path,
+                    },
                 ));
                 collector(self, None);
                 return;
@@ -625,7 +659,8 @@ fn collect_stmt(&mut self, s: ast::Stmt) {
                     return;
                 }
                 let pat = self.collect_pat_opt(stmt.pat());
-                let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it));
+                let type_ref =
+                    stmt.ty().map(|it| Interned::new(TypeRef::from_ast(&self.ctx(), it)));
                 let initializer = stmt.initializer().map(|e| self.collect_expr(e));
                 self.statements_in_scope.push(Statement::Let { pat, type_ref, initializer });
             }
@@ -633,31 +668,36 @@ fn collect_stmt(&mut self, s: ast::Stmt) {
                 if self.check_cfg(&stmt).is_none() {
                     return;
                 }
-
+                let has_semi = stmt.semicolon_token().is_some();
                 // Note that macro could be expended to multiple statements
                 if let Some(ast::Expr::MacroCall(m)) = stmt.expr() {
+                    let macro_ptr = AstPtr::new(&m);
                     let syntax_ptr = AstPtr::new(&stmt.expr().unwrap());
 
-                    self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| {
-                        match expansion {
+                    self.collect_macro_call(
+                        m,
+                        macro_ptr,
+                        false,
+                        |this, expansion| match expansion {
                             Some(expansion) => {
                                 let statements: ast::MacroStmts = expansion;
 
                                 statements.statements().for_each(|stmt| this.collect_stmt(stmt));
                                 if let Some(expr) = statements.expr() {
                                     let expr = this.collect_expr(expr);
-                                    this.statements_in_scope.push(Statement::Expr(expr));
+                                    this.statements_in_scope
+                                        .push(Statement::Expr { expr, has_semi });
                                 }
                             }
                             None => {
                                 let expr = this.alloc_expr(Expr::Missing, syntax_ptr.clone());
-                                this.statements_in_scope.push(Statement::Expr(expr));
+                                this.statements_in_scope.push(Statement::Expr { expr, has_semi });
                             }
-                        }
-                    });
+                        },
+                    );
                 } else {
                     let expr = self.collect_expr_opt(stmt.expr());
-                    self.statements_in_scope.push(Statement::Expr(expr));
+                    self.statements_in_scope.push(Statement::Expr { expr, has_semi });
                 }
             }
             ast::Stmt::Item(item) => {
@@ -673,19 +713,30 @@ fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId {
         let block_loc =
             BlockLoc { ast_id, module: self.expander.def_map.module_id(self.expander.module) };
         let block_id = self.db.intern_block(block_loc);
-        self.body.block_scopes.push(block_id);
 
-        let opt_def_map = self.db.block_def_map(block_id);
-        let has_def_map = opt_def_map.is_some();
-        let def_map = opt_def_map.unwrap_or_else(|| self.expander.def_map.clone());
-        let module = if has_def_map { def_map.root() } else { self.expander.module };
+        let (module, def_map) = match self.db.block_def_map(block_id) {
+            Some(def_map) => {
+                self.body.block_scopes.push(block_id);
+                (def_map.root(), def_map)
+            }
+            None => (self.expander.module, self.expander.def_map.clone()),
+        };
         let prev_def_map = mem::replace(&mut self.expander.def_map, def_map);
         let prev_local_module = mem::replace(&mut self.expander.module, module);
         let prev_statements = std::mem::take(&mut self.statements_in_scope);
 
         block.statements().for_each(|s| self.collect_stmt(s));
+        block.tail_expr().and_then(|e| {
+            let expr = self.maybe_collect_expr(e)?;
+            Some(self.statements_in_scope.push(Statement::Expr { expr, has_semi: false }))
+        });
 
-        let tail = block.tail_expr().map(|e| self.collect_expr(e));
+        let mut tail = None;
+        if let Some(Statement::Expr { expr, has_semi: false }) = self.statements_in_scope.last() {
+            tail = Some(*expr);
+            self.statements_in_scope.pop();
+        }
+        let tail = tail;
         let statements = std::mem::replace(&mut self.statements_in_scope, prev_statements);
         let syntax_node_ptr = AstPtr::new(&block.into());
         let expr_id = self.alloc_expr(
@@ -753,7 +804,8 @@ fn collect_pat(&mut self, pat: ast::Pat) -> PatId {
                 }
             }
             ast::Pat::TupleStructPat(p) => {
-                let path = p.path().and_then(|path| self.expander.parse_path(path));
+                let path =
+                    p.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new);
                 let (args, ellipsis) = self.collect_tuple_pat(p.fields());
                 Pat::TupleStruct { path, args, ellipsis }
             }
@@ -763,7 +815,8 @@ fn collect_pat(&mut self, pat: ast::Pat) -> PatId {
                 Pat::Ref { pat, mutability }
             }
             ast::Pat::PathPat(p) => {
-                let path = p.path().and_then(|path| self.expander.parse_path(path));
+                let path =
+                    p.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new);
                 path.map(Pat::Path).unwrap_or(Pat::Missing)
             }
             ast::Pat::OrPat(p) => {
@@ -777,7 +830,8 @@ fn collect_pat(&mut self, pat: ast::Pat) -> PatId {
             }
             ast::Pat::WildcardPat(_) => Pat::Wild,
             ast::Pat::RecordPat(p) => {
-                let path = p.path().and_then(|path| self.expander.parse_path(path));
+                let path =
+                    p.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new);
                 let args: Vec<_> = p
                     .record_pat_field_list()
                     .expect("every struct should have a field list")
@@ -839,8 +893,23 @@ fn collect_pat(&mut self, pat: ast::Pat) -> PatId {
                     Pat::Missing
                 }
             }
+            ast::Pat::MacroPat(mac) => match mac.macro_call() {
+                Some(call) => {
+                    let macro_ptr = AstPtr::new(&call);
+                    let mut pat = None;
+                    self.collect_macro_call(call, macro_ptr, true, |this, expanded_pat| {
+                        pat = Some(this.collect_pat_opt(expanded_pat));
+                    });
+
+                    match pat {
+                        Some(pat) => return pat,
+                        None => Pat::Missing,
+                    }
+                }
+                None => Pat::Missing,
+            },
             // FIXME: implement
-            ast::Pat::RangePat(_) | ast::Pat::MacroPat(_) => Pat::Missing,
+            ast::Pat::RangePat(_) => Pat::Missing,
         };
         let ptr = AstPtr::new(&pat);
         self.alloc_pat(pattern, Either::Left(ptr))