X-Git-Url: https://git.lizzy.rs/?a=blobdiff_plain;f=crates%2Fhir_def%2Fsrc%2Fbody%2Fscope.rs;h=2658eece8e85e60fa93a8d95dbd4cd9bcb7a3bb4;hb=9c6542f2097df1cfcc9491036ec607c6a2842070;hp=065785da7fa4b40502b71a47c8b29b44732a38f5;hpb=42e00032c6ba07eaf2f7d5886c60133c65d84cf5;p=rust.git diff --git a/crates/hir_def/src/body/scope.rs b/crates/hir_def/src/body/scope.rs index 065785da7fa..2658eece8e8 100644 --- a/crates/hir_def/src/body/scope.rs +++ b/crates/hir_def/src/body/scope.rs @@ -1,15 +1,15 @@ //! Name resolution for expressions. use std::sync::Arc; -use arena::{Arena, Idx}; use hir_expand::name::Name; +use la_arena::{Arena, Idx}; use rustc_hash::FxHashMap; use crate::{ body::Body, db::DefDatabase, - expr::{Expr, ExprId, Pat, PatId, Statement}, - DefWithBodyId, + expr::{Expr, ExprId, LabelId, MatchGuard, Pat, PatId, Statement}, + BlockId, DefWithBodyId, }; pub type ScopeId = Idx; @@ -39,6 +39,8 @@ pub fn pat(&self) -> PatId { #[derive(Debug, PartialEq, Eq)] pub struct ScopeData { parent: Option, + block: Option, + label: Option<(LabelId, Name)>, entries: Vec, } @@ -61,6 +63,16 @@ pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] { &self.scopes[scope].entries } + /// If `scope` refers to a block expression scope, returns the corresponding `BlockId`. + pub fn block(&self, scope: ScopeId) -> Option { + self.scopes[scope].block + } + + /// If `scope` refers to a labeled expression scope, returns the corresponding `Label`. + pub fn label(&self, scope: ScopeId) -> Option<(LabelId, Name)> { + self.scopes[scope].label.clone() + } + pub fn scope_chain(&self, scope: Option) -> impl Iterator + '_ { std::iter::successors(scope, move |&scope| self.scopes[scope].parent) } @@ -79,11 +91,34 @@ pub fn scope_by_expr(&self) -> &FxHashMap { } fn root_scope(&mut self) -> ScopeId { - self.scopes.alloc(ScopeData { parent: None, entries: vec![] }) + self.scopes.alloc(ScopeData { parent: None, block: None, label: None, entries: vec![] }) } fn new_scope(&mut self, parent: ScopeId) -> ScopeId { - self.scopes.alloc(ScopeData { parent: Some(parent), entries: vec![] }) + self.scopes.alloc(ScopeData { + parent: Some(parent), + block: None, + label: None, + entries: vec![], + }) + } + + fn new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId { + self.scopes.alloc(ScopeData { parent: Some(parent), block: None, label, entries: vec![] }) + } + + fn new_block_scope( + &mut self, + parent: ScopeId, + block: BlockId, + label: Option<(LabelId, Name)>, + ) -> ScopeId { + self.scopes.alloc(ScopeData { + parent: Some(parent), + block: Some(block), + label, + entries: vec![], + }) } fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) { @@ -114,16 +149,17 @@ fn compute_block_scopes( ) { for stmt in statements { match stmt { - Statement::Let { pat, initializer, .. } => { + Statement::Let { pat, initializer, else_branch, .. } => { if let Some(expr) = initializer { - scopes.set_scope(*expr, scope); + compute_expr_scopes(*expr, body, scopes, scope); + } + if let Some(expr) = else_branch { compute_expr_scopes(*expr, body, scopes, scope); } scope = scopes.new_scope(scope); scopes.add_bindings(body, scope, *pat); } - Statement::Expr(expr) => { - scopes.set_scope(*expr, scope); + Statement::Expr { expr, .. } => { compute_expr_scopes(*expr, body, scopes, scope); } } @@ -134,31 +170,56 @@ fn compute_block_scopes( } fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) { + let make_label = + |label: &Option| label.map(|label| (label, body.labels[label].name.clone())); + scopes.set_scope(expr, scope); match &body[expr] { - Expr::Block { statements, tail, .. } => { - compute_block_scopes(&statements, *tail, body, scopes, scope); + Expr::Block { statements, tail, id, label } => { + let scope = scopes.new_block_scope(scope, *id, make_label(label)); + // Overwrite the old scope for the block expr, so that every block scope can be found + // via the block itself (important for blocks that only contain items, no expressions). + scopes.set_scope(expr, scope); + compute_block_scopes(statements, *tail, body, scopes, scope); } - Expr::For { iterable, pat, body: body_expr, .. } => { + Expr::For { iterable, pat, body: body_expr, label } => { compute_expr_scopes(*iterable, body, scopes, scope); - let scope = scopes.new_scope(scope); + let scope = scopes.new_labeled_scope(scope, make_label(label)); scopes.add_bindings(body, scope, *pat); compute_expr_scopes(*body_expr, body, scopes, scope); } + Expr::While { condition, body: body_expr, label } => { + let scope = scopes.new_labeled_scope(scope, make_label(label)); + compute_expr_scopes(*condition, body, scopes, scope); + compute_expr_scopes(*body_expr, body, scopes, scope); + } + Expr::Loop { body: body_expr, label } => { + let scope = scopes.new_labeled_scope(scope, make_label(label)); + compute_expr_scopes(*body_expr, body, scopes, scope); + } Expr::Lambda { args, body: body_expr, .. } => { let scope = scopes.new_scope(scope); - scopes.add_params_bindings(body, scope, &args); + scopes.add_params_bindings(body, scope, args); compute_expr_scopes(*body_expr, body, scopes, scope); } Expr::Match { expr, arms } => { compute_expr_scopes(*expr, body, scopes, scope); - for arm in arms { - let scope = scopes.new_scope(scope); + for arm in arms.iter() { + let mut scope = scopes.new_scope(scope); scopes.add_bindings(body, scope, arm.pat); - if let Some(guard) = arm.guard { - scopes.set_scope(guard, scope); - compute_expr_scopes(guard, body, scopes, scope); - } + match arm.guard { + Some(MatchGuard::If { expr: guard }) => { + scopes.set_scope(guard, scope); + compute_expr_scopes(guard, body, scopes, scope); + } + Some(MatchGuard::IfLet { pat, expr: guard }) => { + scopes.set_scope(guard, scope); + compute_expr_scopes(guard, body, scopes, scope); + scope = scopes.new_scope(scope); + scopes.add_bindings(body, scope, pat); + } + _ => {} + }; scopes.set_scope(arm.expr, scope); compute_expr_scopes(arm.expr, body, scopes, scope); } @@ -172,7 +233,7 @@ mod tests { use base_db::{fixture::WithFixture, FileId, SourceDatabase}; use hir_expand::{name::AsName, InFile}; use syntax::{algo::find_node_at_offset, ast, AstNode}; - use test_utils::{assert_eq_text, extract_offset, mark}; + use test_utils::{assert_eq_text, extract_offset}; use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId}; @@ -218,7 +279,7 @@ fn do_check(ra_fixture: &str, expected: &[&str]) { let actual = scopes .scope_chain(scope) .flat_map(|scope| scopes.entries(scope)) - .map(|it| it.name().to_string()) + .map(|it| it.name().to_smol_str()) .collect::>() .join("\n"); let expected = expected.join("\n"); @@ -440,7 +501,7 @@ fn foo() { #[test] fn while_let_desugaring() { - mark::check!(infer_resolve_while_let); + cov_mark::check!(infer_resolve_while_let); do_check_local_name( r#" fn test() {