]> git.lizzy.rs Git - rust.git/blobdiff - crates/hir_def/src/body/scope.rs
parameters.split_last()
[rust.git] / crates / hir_def / src / body / scope.rs
index 49f1427b45e2767579088f731207e3196059c1ec..2658eece8e85e60fa93a8d95dbd4cd9bcb7a3bb4 100644 (file)
@@ -8,8 +8,8 @@
 use crate::{
     body::Body,
     db::DefDatabase,
-    expr::{Expr, ExprId, Pat, PatId, Statement},
-    DefWithBodyId,
+    expr::{Expr, ExprId, LabelId, MatchGuard, Pat, PatId, Statement},
+    BlockId, DefWithBodyId,
 };
 
 pub type ScopeId = Idx<ScopeData>;
@@ -39,6 +39,8 @@ pub fn pat(&self) -> PatId {
 #[derive(Debug, PartialEq, Eq)]
 pub struct ScopeData {
     parent: Option<ScopeId>,
+    block: Option<BlockId>,
+    label: Option<(LabelId, Name)>,
     entries: Vec<ScopeEntry>,
 }
 
@@ -61,6 +63,16 @@ pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
         &self.scopes[scope].entries
     }
 
+    /// If `scope` refers to a block expression scope, returns the corresponding `BlockId`.
+    pub fn block(&self, scope: ScopeId) -> Option<BlockId> {
+        self.scopes[scope].block
+    }
+
+    /// If `scope` refers to a labeled expression scope, returns the corresponding `Label`.
+    pub fn label(&self, scope: ScopeId) -> Option<(LabelId, Name)> {
+        self.scopes[scope].label.clone()
+    }
+
     pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
         std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
     }
@@ -79,11 +91,34 @@ pub fn scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId> {
     }
 
     fn root_scope(&mut self) -> ScopeId {
-        self.scopes.alloc(ScopeData { parent: None, entries: vec![] })
+        self.scopes.alloc(ScopeData { parent: None, block: None, label: None, entries: vec![] })
     }
 
     fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
-        self.scopes.alloc(ScopeData { parent: Some(parent), entries: vec![] })
+        self.scopes.alloc(ScopeData {
+            parent: Some(parent),
+            block: None,
+            label: None,
+            entries: vec![],
+        })
+    }
+
+    fn new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId {
+        self.scopes.alloc(ScopeData { parent: Some(parent), block: None, label, entries: vec![] })
+    }
+
+    fn new_block_scope(
+        &mut self,
+        parent: ScopeId,
+        block: BlockId,
+        label: Option<(LabelId, Name)>,
+    ) -> ScopeId {
+        self.scopes.alloc(ScopeData {
+            parent: Some(parent),
+            block: Some(block),
+            label,
+            entries: vec![],
+        })
     }
 
     fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
@@ -114,16 +149,17 @@ fn compute_block_scopes(
 ) {
     for stmt in statements {
         match stmt {
-            Statement::Let { pat, initializer, .. } => {
+            Statement::Let { pat, initializer, else_branch, .. } => {
                 if let Some(expr) = initializer {
-                    scopes.set_scope(*expr, scope);
+                    compute_expr_scopes(*expr, body, scopes, scope);
+                }
+                if let Some(expr) = else_branch {
                     compute_expr_scopes(*expr, body, scopes, scope);
                 }
                 scope = scopes.new_scope(scope);
                 scopes.add_bindings(body, scope, *pat);
             }
-            Statement::Expr(expr) => {
-                scopes.set_scope(*expr, scope);
+            Statement::Expr { expr, .. } => {
                 compute_expr_scopes(*expr, body, scopes, scope);
             }
         }
@@ -134,31 +170,56 @@ fn compute_block_scopes(
 }
 
 fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
+    let make_label =
+        |label: &Option<LabelId>| label.map(|label| (label, body.labels[label].name.clone()));
+
     scopes.set_scope(expr, scope);
     match &body[expr] {
-        Expr::Block { statements, tail, .. } => {
-            compute_block_scopes(&statements, *tail, body, scopes, scope);
+        Expr::Block { statements, tail, id, label } => {
+            let scope = scopes.new_block_scope(scope, *id, make_label(label));
+            // Overwrite the old scope for the block expr, so that every block scope can be found
+            // via the block itself (important for blocks that only contain items, no expressions).
+            scopes.set_scope(expr, scope);
+            compute_block_scopes(statements, *tail, body, scopes, scope);
         }
-        Expr::For { iterable, pat, body: body_expr, .. } => {
+        Expr::For { iterable, pat, body: body_expr, label } => {
             compute_expr_scopes(*iterable, body, scopes, scope);
-            let scope = scopes.new_scope(scope);
+            let scope = scopes.new_labeled_scope(scope, make_label(label));
             scopes.add_bindings(body, scope, *pat);
             compute_expr_scopes(*body_expr, body, scopes, scope);
         }
+        Expr::While { condition, body: body_expr, label } => {
+            let scope = scopes.new_labeled_scope(scope, make_label(label));
+            compute_expr_scopes(*condition, body, scopes, scope);
+            compute_expr_scopes(*body_expr, body, scopes, scope);
+        }
+        Expr::Loop { body: body_expr, label } => {
+            let scope = scopes.new_labeled_scope(scope, make_label(label));
+            compute_expr_scopes(*body_expr, body, scopes, scope);
+        }
         Expr::Lambda { args, body: body_expr, .. } => {
             let scope = scopes.new_scope(scope);
-            scopes.add_params_bindings(body, scope, &args);
+            scopes.add_params_bindings(body, scope, args);
             compute_expr_scopes(*body_expr, body, scopes, scope);
         }
         Expr::Match { expr, arms } => {
             compute_expr_scopes(*expr, body, scopes, scope);
-            for arm in arms {
-                let scope = scopes.new_scope(scope);
+            for arm in arms.iter() {
+                let mut scope = scopes.new_scope(scope);
                 scopes.add_bindings(body, scope, arm.pat);
-                if let Some(guard) = arm.guard {
-                    scopes.set_scope(guard, scope);
-                    compute_expr_scopes(guard, body, scopes, scope);
-                }
+                match arm.guard {
+                    Some(MatchGuard::If { expr: guard }) => {
+                        scopes.set_scope(guard, scope);
+                        compute_expr_scopes(guard, body, scopes, scope);
+                    }
+                    Some(MatchGuard::IfLet { pat, expr: guard }) => {
+                        scopes.set_scope(guard, scope);
+                        compute_expr_scopes(guard, body, scopes, scope);
+                        scope = scopes.new_scope(scope);
+                        scopes.add_bindings(body, scope, pat);
+                    }
+                    _ => {}
+                };
                 scopes.set_scope(arm.expr, scope);
                 compute_expr_scopes(arm.expr, body, scopes, scope);
             }
@@ -172,7 +233,7 @@ mod tests {
     use base_db::{fixture::WithFixture, FileId, SourceDatabase};
     use hir_expand::{name::AsName, InFile};
     use syntax::{algo::find_node_at_offset, ast, AstNode};
-    use test_utils::{assert_eq_text, extract_offset, mark};
+    use test_utils::{assert_eq_text, extract_offset};
 
     use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId};
 
@@ -218,7 +279,7 @@ fn do_check(ra_fixture: &str, expected: &[&str]) {
         let actual = scopes
             .scope_chain(scope)
             .flat_map(|scope| scopes.entries(scope))
-            .map(|it| it.name().to_string())
+            .map(|it| it.name().to_smol_str())
             .collect::<Vec<_>>()
             .join("\n");
         let expected = expected.join("\n");
@@ -440,7 +501,7 @@ fn foo() {
 
     #[test]
     fn while_let_desugaring() {
-        mark::check!(infer_resolve_while_let);
+        cov_mark::check!(infer_resolve_while_let);
         do_check_local_name(
             r#"
 fn test() {