]> git.lizzy.rs Git - rust.git/blobdiff - crates/hir_def/src/body/scope.rs
parameters.split_last()
[rust.git] / crates / hir_def / src / body / scope.rs
index 9142bc05b8d836251389c1c1fe292b01826063e9..2658eece8e85e60fa93a8d95dbd4cd9bcb7a3bb4 100644 (file)
@@ -1,15 +1,15 @@
 //! Name resolution for expressions.
 use std::sync::Arc;
 
-use arena::{Arena, Idx};
 use hir_expand::name::Name;
+use la_arena::{Arena, Idx};
 use rustc_hash::FxHashMap;
 
 use crate::{
     body::Body,
     db::DefDatabase,
-    expr::{Expr, ExprId, Pat, PatId, Statement},
-    DefWithBodyId,
+    expr::{Expr, ExprId, LabelId, MatchGuard, Pat, PatId, Statement},
+    BlockId, DefWithBodyId,
 };
 
 pub type ScopeId = Idx<ScopeData>;
@@ -39,6 +39,8 @@ pub fn pat(&self) -> PatId {
 #[derive(Debug, PartialEq, Eq)]
 pub struct ScopeData {
     parent: Option<ScopeId>,
+    block: Option<BlockId>,
+    label: Option<(LabelId, Name)>,
     entries: Vec<ScopeEntry>,
 }
 
@@ -61,6 +63,16 @@ pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
         &self.scopes[scope].entries
     }
 
+    /// If `scope` refers to a block expression scope, returns the corresponding `BlockId`.
+    pub fn block(&self, scope: ScopeId) -> Option<BlockId> {
+        self.scopes[scope].block
+    }
+
+    /// If `scope` refers to a labeled expression scope, returns the corresponding `Label`.
+    pub fn label(&self, scope: ScopeId) -> Option<(LabelId, Name)> {
+        self.scopes[scope].label.clone()
+    }
+
     pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
         std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
     }
@@ -79,11 +91,34 @@ pub fn scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId> {
     }
 
     fn root_scope(&mut self) -> ScopeId {
-        self.scopes.alloc(ScopeData { parent: None, entries: vec![] })
+        self.scopes.alloc(ScopeData { parent: None, block: None, label: None, entries: vec![] })
     }
 
     fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
-        self.scopes.alloc(ScopeData { parent: Some(parent), entries: vec![] })
+        self.scopes.alloc(ScopeData {
+            parent: Some(parent),
+            block: None,
+            label: None,
+            entries: vec![],
+        })
+    }
+
+    fn new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId {
+        self.scopes.alloc(ScopeData { parent: Some(parent), block: None, label, entries: vec![] })
+    }
+
+    fn new_block_scope(
+        &mut self,
+        parent: ScopeId,
+        block: BlockId,
+        label: Option<(LabelId, Name)>,
+    ) -> ScopeId {
+        self.scopes.alloc(ScopeData {
+            parent: Some(parent),
+            block: Some(block),
+            label,
+            entries: vec![],
+        })
     }
 
     fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
@@ -114,16 +149,17 @@ fn compute_block_scopes(
 ) {
     for stmt in statements {
         match stmt {
-            Statement::Let { pat, initializer, .. } => {
+            Statement::Let { pat, initializer, else_branch, .. } => {
                 if let Some(expr) = initializer {
-                    scopes.set_scope(*expr, scope);
+                    compute_expr_scopes(*expr, body, scopes, scope);
+                }
+                if let Some(expr) = else_branch {
                     compute_expr_scopes(*expr, body, scopes, scope);
                 }
                 scope = scopes.new_scope(scope);
                 scopes.add_bindings(body, scope, *pat);
             }
-            Statement::Expr(expr) => {
-                scopes.set_scope(*expr, scope);
+            Statement::Expr { expr, .. } => {
                 compute_expr_scopes(*expr, body, scopes, scope);
             }
         }
@@ -134,31 +170,56 @@ fn compute_block_scopes(
 }
 
 fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
+    let make_label =
+        |label: &Option<LabelId>| label.map(|label| (label, body.labels[label].name.clone()));
+
     scopes.set_scope(expr, scope);
     match &body[expr] {
-        Expr::Block { statements, tail, .. } => {
-            compute_block_scopes(&statements, *tail, body, scopes, scope);
+        Expr::Block { statements, tail, id, label } => {
+            let scope = scopes.new_block_scope(scope, *id, make_label(label));
+            // Overwrite the old scope for the block expr, so that every block scope can be found
+            // via the block itself (important for blocks that only contain items, no expressions).
+            scopes.set_scope(expr, scope);
+            compute_block_scopes(statements, *tail, body, scopes, scope);
         }
-        Expr::For { iterable, pat, body: body_expr, .. } => {
+        Expr::For { iterable, pat, body: body_expr, label } => {
             compute_expr_scopes(*iterable, body, scopes, scope);
-            let scope = scopes.new_scope(scope);
+            let scope = scopes.new_labeled_scope(scope, make_label(label));
             scopes.add_bindings(body, scope, *pat);
             compute_expr_scopes(*body_expr, body, scopes, scope);
         }
+        Expr::While { condition, body: body_expr, label } => {
+            let scope = scopes.new_labeled_scope(scope, make_label(label));
+            compute_expr_scopes(*condition, body, scopes, scope);
+            compute_expr_scopes(*body_expr, body, scopes, scope);
+        }
+        Expr::Loop { body: body_expr, label } => {
+            let scope = scopes.new_labeled_scope(scope, make_label(label));
+            compute_expr_scopes(*body_expr, body, scopes, scope);
+        }
         Expr::Lambda { args, body: body_expr, .. } => {
             let scope = scopes.new_scope(scope);
-            scopes.add_params_bindings(body, scope, &args);
+            scopes.add_params_bindings(body, scope, args);
             compute_expr_scopes(*body_expr, body, scopes, scope);
         }
         Expr::Match { expr, arms } => {
             compute_expr_scopes(*expr, body, scopes, scope);
-            for arm in arms {
-                let scope = scopes.new_scope(scope);
+            for arm in arms.iter() {
+                let mut scope = scopes.new_scope(scope);
                 scopes.add_bindings(body, scope, arm.pat);
-                if let Some(guard) = arm.guard {
-                    scopes.set_scope(guard, scope);
-                    compute_expr_scopes(guard, body, scopes, scope);
-                }
+                match arm.guard {
+                    Some(MatchGuard::If { expr: guard }) => {
+                        scopes.set_scope(guard, scope);
+                        compute_expr_scopes(guard, body, scopes, scope);
+                    }
+                    Some(MatchGuard::IfLet { pat, expr: guard }) => {
+                        scopes.set_scope(guard, scope);
+                        compute_expr_scopes(guard, body, scopes, scope);
+                        scope = scopes.new_scope(scope);
+                        scopes.add_bindings(body, scope, pat);
+                    }
+                    _ => {}
+                };
                 scopes.set_scope(arm.expr, scope);
                 compute_expr_scopes(arm.expr, body, scopes, scope);
             }
@@ -172,7 +233,7 @@ mod tests {
     use base_db::{fixture::WithFixture, FileId, SourceDatabase};
     use hir_expand::{name::AsName, InFile};
     use syntax::{algo::find_node_at_offset, ast, AstNode};
-    use test_utils::{assert_eq_text, extract_offset, mark};
+    use test_utils::{assert_eq_text, extract_offset};
 
     use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId};
 
@@ -194,7 +255,7 @@ fn do_check(ra_fixture: &str, expected: &[&str]) {
             let mut buf = String::new();
             let off: usize = offset.into();
             buf.push_str(&code[..off]);
-            buf.push_str("<|>marker");
+            buf.push_str("$0marker");
             buf.push_str(&code[off..]);
             buf
         };
@@ -218,7 +279,7 @@ fn do_check(ra_fixture: &str, expected: &[&str]) {
         let actual = scopes
             .scope_chain(scope)
             .flat_map(|scope| scopes.entries(scope))
-            .map(|it| it.name().to_string())
+            .map(|it| it.name().to_smol_str())
             .collect::<Vec<_>>()
             .join("\n");
         let expected = expected.join("\n");
@@ -231,7 +292,7 @@ fn test_lambda_scope() {
             r"
             fn quux(foo: i32) {
                 let f = |bar, baz: i32| {
-                    <|>
+                    $0
                 };
             }",
             &["bar", "baz", "foo"],
@@ -243,7 +304,7 @@ fn test_call_scope() {
         do_check(
             r"
             fn quux() {
-                f(|x| <|> );
+                f(|x| $0 );
             }",
             &["x"],
         );
@@ -254,7 +315,7 @@ fn test_method_call_scope() {
         do_check(
             r"
             fn quux() {
-                z.f(|x| <|> );
+                z.f(|x| $0 );
             }",
             &["x"],
         );
@@ -267,7 +328,7 @@ fn test_loop_scope() {
             fn quux() {
                 loop {
                     let x = ();
-                    <|>
+                    $0
                 };
             }",
             &["x"],
@@ -281,7 +342,7 @@ fn test_match() {
             fn quux() {
                 match () {
                     Some(x) => {
-                        <|>
+                        $0
                     }
                 };
             }",
@@ -294,7 +355,7 @@ fn test_shadow_variable() {
         do_check(
             r"
             fn foo(x: String) {
-                let x : &str = &x<|>;
+                let x : &str = &x$0;
             }",
             &["x"],
         );
@@ -307,7 +368,7 @@ fn test_bindings_after_at() {
 fn foo() {
     match Some(()) {
         opt @ Some(unit) => {
-            <|>
+            $0
         }
         _ => {}
     }
@@ -330,7 +391,7 @@ fn inner() {}
 
             fn foo() {
                 mac!();
-                <|>
+                $0
             }
         ",
             &[],
@@ -343,7 +404,7 @@ fn broken_inner_item() {
             r"
             fn foo() {
                 trait {}
-                <|>
+                $0
             }
         ",
             &[],
@@ -391,7 +452,7 @@ fn foo(x: i32, y: u32) {
         let z = x * 2;
     }
     {
-        let t = x<|> * 3;
+        let t = x$0 * 3;
     }
 }
 "#,
@@ -404,7 +465,7 @@ fn test_resolve_local_name_declaration() {
         do_check_local_name(
             r#"
 fn foo(x: String) {
-    let x : &str = &x<|>;
+    let x : &str = &x$0;
 }
 "#,
             7,
@@ -417,7 +478,7 @@ fn test_resolve_local_name_shadow() {
             r"
 fn foo(x: String) {
     let x : &str = &x;
-    x<|>
+    x$0
 }
 ",
             28,
@@ -430,7 +491,7 @@ fn ref_patterns_contribute_bindings() {
             r"
 fn foo() {
     if let Some(&from) = bar() {
-        from<|>;
+        from$0;
     }
 }
 ",
@@ -440,13 +501,13 @@ fn foo() {
 
     #[test]
     fn while_let_desugaring() {
-        mark::check!(infer_resolve_while_let);
+        cov_mark::check!(infer_resolve_while_let);
         do_check_local_name(
             r#"
 fn test() {
     let foo: Option<f32> = None;
     while let Option::Some(spam) = foo {
-        spam<|>
+        spam$0
     }
 }
 "#,