]> git.lizzy.rs Git - rust.git/blob - crates/hir_def/src/body/scope.rs
Merge #7616
[rust.git] / crates / hir_def / src / body / scope.rs
1 //! Name resolution for expressions.
2 use std::sync::Arc;
3
4 use hir_expand::name::Name;
5 use la_arena::{Arena, Idx};
6 use rustc_hash::FxHashMap;
7
8 use crate::{
9     body::Body,
10     db::DefDatabase,
11     expr::{Expr, ExprId, Pat, PatId, Statement},
12     BlockId, DefWithBodyId,
13 };
14
15 pub type ScopeId = Idx<ScopeData>;
16
17 #[derive(Debug, PartialEq, Eq)]
18 pub struct ExprScopes {
19     scopes: Arena<ScopeData>,
20     scope_by_expr: FxHashMap<ExprId, ScopeId>,
21 }
22
23 #[derive(Debug, PartialEq, Eq)]
24 pub struct ScopeEntry {
25     name: Name,
26     pat: PatId,
27 }
28
29 impl ScopeEntry {
30     pub fn name(&self) -> &Name {
31         &self.name
32     }
33
34     pub fn pat(&self) -> PatId {
35         self.pat
36     }
37 }
38
39 #[derive(Debug, PartialEq, Eq)]
40 pub struct ScopeData {
41     parent: Option<ScopeId>,
42     block: Option<BlockId>,
43     entries: Vec<ScopeEntry>,
44 }
45
46 impl ExprScopes {
47     pub(crate) fn expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes> {
48         let body = db.body(def);
49         Arc::new(ExprScopes::new(&*body))
50     }
51
52     fn new(body: &Body) -> ExprScopes {
53         let mut scopes =
54             ExprScopes { scopes: Arena::default(), scope_by_expr: FxHashMap::default() };
55         let root = scopes.root_scope();
56         scopes.add_params_bindings(body, root, &body.params);
57         compute_expr_scopes(body.body_expr, body, &mut scopes, root);
58         scopes
59     }
60
61     pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
62         &self.scopes[scope].entries
63     }
64
65     /// If `scope` refers to a block expression scope, returns the corresponding `BlockId`.
66     pub fn block(&self, scope: ScopeId) -> Option<BlockId> {
67         self.scopes[scope].block
68     }
69
70     pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
71         std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
72     }
73
74     pub fn resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry> {
75         self.scope_chain(Some(scope))
76             .find_map(|scope| self.entries(scope).iter().find(|it| it.name == *name))
77     }
78
79     pub fn scope_for(&self, expr: ExprId) -> Option<ScopeId> {
80         self.scope_by_expr.get(&expr).copied()
81     }
82
83     pub fn scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId> {
84         &self.scope_by_expr
85     }
86
87     fn root_scope(&mut self) -> ScopeId {
88         self.scopes.alloc(ScopeData { parent: None, block: None, entries: vec![] })
89     }
90
91     fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
92         self.scopes.alloc(ScopeData { parent: Some(parent), block: None, entries: vec![] })
93     }
94
95     fn new_block_scope(&mut self, parent: ScopeId, block: BlockId) -> ScopeId {
96         self.scopes.alloc(ScopeData { parent: Some(parent), block: Some(block), entries: vec![] })
97     }
98
99     fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
100         let pattern = &body[pat];
101         if let Pat::Bind { name, .. } = pattern {
102             let entry = ScopeEntry { name: name.clone(), pat };
103             self.scopes[scope].entries.push(entry);
104         }
105
106         pattern.walk_child_pats(|pat| self.add_bindings(body, scope, pat));
107     }
108
109     fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) {
110         params.iter().for_each(|pat| self.add_bindings(body, scope, *pat));
111     }
112
113     fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
114         self.scope_by_expr.insert(node, scope);
115     }
116 }
117
118 fn compute_block_scopes(
119     statements: &[Statement],
120     tail: Option<ExprId>,
121     body: &Body,
122     scopes: &mut ExprScopes,
123     mut scope: ScopeId,
124 ) {
125     for stmt in statements {
126         match stmt {
127             Statement::Let { pat, initializer, .. } => {
128                 if let Some(expr) = initializer {
129                     scopes.set_scope(*expr, scope);
130                     compute_expr_scopes(*expr, body, scopes, scope);
131                 }
132                 scope = scopes.new_scope(scope);
133                 scopes.add_bindings(body, scope, *pat);
134             }
135             Statement::Expr(expr) => {
136                 scopes.set_scope(*expr, scope);
137                 compute_expr_scopes(*expr, body, scopes, scope);
138             }
139         }
140     }
141     if let Some(expr) = tail {
142         compute_expr_scopes(expr, body, scopes, scope);
143     }
144 }
145
146 fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
147     scopes.set_scope(expr, scope);
148     match &body[expr] {
149         Expr::Block { statements, tail, id, .. } => {
150             let scope = scopes.new_block_scope(scope, *id);
151             // Overwrite the old scope for the block expr, so that every block scope can be found
152             // via the block itself (important for blocks that only contain items, no expressions).
153             scopes.set_scope(expr, scope);
154             compute_block_scopes(&statements, *tail, body, scopes, scope);
155         }
156         Expr::For { iterable, pat, body: body_expr, .. } => {
157             compute_expr_scopes(*iterable, body, scopes, scope);
158             let scope = scopes.new_scope(scope);
159             scopes.add_bindings(body, scope, *pat);
160             compute_expr_scopes(*body_expr, body, scopes, scope);
161         }
162         Expr::Lambda { args, body: body_expr, .. } => {
163             let scope = scopes.new_scope(scope);
164             scopes.add_params_bindings(body, scope, &args);
165             compute_expr_scopes(*body_expr, body, scopes, scope);
166         }
167         Expr::Match { expr, arms } => {
168             compute_expr_scopes(*expr, body, scopes, scope);
169             for arm in arms {
170                 let scope = scopes.new_scope(scope);
171                 scopes.add_bindings(body, scope, arm.pat);
172                 if let Some(guard) = arm.guard {
173                     scopes.set_scope(guard, scope);
174                     compute_expr_scopes(guard, body, scopes, scope);
175                 }
176                 scopes.set_scope(arm.expr, scope);
177                 compute_expr_scopes(arm.expr, body, scopes, scope);
178             }
179         }
180         e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
181     };
182 }
183
184 #[cfg(test)]
185 mod tests {
186     use base_db::{fixture::WithFixture, FileId, SourceDatabase};
187     use hir_expand::{name::AsName, InFile};
188     use syntax::{algo::find_node_at_offset, ast, AstNode};
189     use test_utils::{assert_eq_text, extract_offset, mark};
190
191     use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId};
192
193     fn find_function(db: &TestDB, file_id: FileId) -> FunctionId {
194         let krate = db.test_crate();
195         let crate_def_map = db.crate_def_map(krate);
196
197         let module = crate_def_map.modules_for_file(file_id).next().unwrap();
198         let (_, def) = crate_def_map[module].scope.entries().next().unwrap();
199         match def.take_values().unwrap() {
200             ModuleDefId::FunctionId(it) => it,
201             _ => panic!(),
202         }
203     }
204
205     fn do_check(ra_fixture: &str, expected: &[&str]) {
206         let (offset, code) = extract_offset(ra_fixture);
207         let code = {
208             let mut buf = String::new();
209             let off: usize = offset.into();
210             buf.push_str(&code[..off]);
211             buf.push_str("$0marker");
212             buf.push_str(&code[off..]);
213             buf
214         };
215
216         let (db, position) = TestDB::with_position(&code);
217         let file_id = position.file_id;
218         let offset = position.offset;
219
220         let file_syntax = db.parse(file_id).syntax_node();
221         let marker: ast::PathExpr = find_node_at_offset(&file_syntax, offset).unwrap();
222         let function = find_function(&db, file_id);
223
224         let scopes = db.expr_scopes(function.into());
225         let (_body, source_map) = db.body_with_source_map(function.into());
226
227         let expr_id = source_map
228             .node_expr(InFile { file_id: file_id.into(), value: &marker.into() })
229             .unwrap();
230         let scope = scopes.scope_for(expr_id);
231
232         let actual = scopes
233             .scope_chain(scope)
234             .flat_map(|scope| scopes.entries(scope))
235             .map(|it| it.name().to_string())
236             .collect::<Vec<_>>()
237             .join("\n");
238         let expected = expected.join("\n");
239         assert_eq_text!(&expected, &actual);
240     }
241
242     #[test]
243     fn test_lambda_scope() {
244         do_check(
245             r"
246             fn quux(foo: i32) {
247                 let f = |bar, baz: i32| {
248                     $0
249                 };
250             }",
251             &["bar", "baz", "foo"],
252         );
253     }
254
255     #[test]
256     fn test_call_scope() {
257         do_check(
258             r"
259             fn quux() {
260                 f(|x| $0 );
261             }",
262             &["x"],
263         );
264     }
265
266     #[test]
267     fn test_method_call_scope() {
268         do_check(
269             r"
270             fn quux() {
271                 z.f(|x| $0 );
272             }",
273             &["x"],
274         );
275     }
276
277     #[test]
278     fn test_loop_scope() {
279         do_check(
280             r"
281             fn quux() {
282                 loop {
283                     let x = ();
284                     $0
285                 };
286             }",
287             &["x"],
288         );
289     }
290
291     #[test]
292     fn test_match() {
293         do_check(
294             r"
295             fn quux() {
296                 match () {
297                     Some(x) => {
298                         $0
299                     }
300                 };
301             }",
302             &["x"],
303         );
304     }
305
306     #[test]
307     fn test_shadow_variable() {
308         do_check(
309             r"
310             fn foo(x: String) {
311                 let x : &str = &x$0;
312             }",
313             &["x"],
314         );
315     }
316
317     #[test]
318     fn test_bindings_after_at() {
319         do_check(
320             r"
321 fn foo() {
322     match Some(()) {
323         opt @ Some(unit) => {
324             $0
325         }
326         _ => {}
327     }
328 }
329 ",
330             &["opt", "unit"],
331         );
332     }
333
334     #[test]
335     fn macro_inner_item() {
336         do_check(
337             r"
338             macro_rules! mac {
339                 () => {{
340                     fn inner() {}
341                     inner();
342                 }};
343             }
344
345             fn foo() {
346                 mac!();
347                 $0
348             }
349         ",
350             &[],
351         );
352     }
353
354     #[test]
355     fn broken_inner_item() {
356         do_check(
357             r"
358             fn foo() {
359                 trait {}
360                 $0
361             }
362         ",
363             &[],
364         );
365     }
366
367     fn do_check_local_name(ra_fixture: &str, expected_offset: u32) {
368         let (db, position) = TestDB::with_position(ra_fixture);
369         let file_id = position.file_id;
370         let offset = position.offset;
371
372         let file = db.parse(file_id).ok().unwrap();
373         let expected_name = find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into())
374             .expect("failed to find a name at the target offset");
375         let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), offset).unwrap();
376
377         let function = find_function(&db, file_id);
378
379         let scopes = db.expr_scopes(function.into());
380         let (_body, source_map) = db.body_with_source_map(function.into());
381
382         let expr_scope = {
383             let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap();
384             let expr_id =
385                 source_map.node_expr(InFile { file_id: file_id.into(), value: &expr_ast }).unwrap();
386             scopes.scope_for(expr_id).unwrap()
387         };
388
389         let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap();
390         let pat_src = source_map.pat_syntax(resolved.pat()).unwrap();
391
392         let local_name = pat_src.value.either(
393             |it| it.syntax_node_ptr().to_node(file.syntax()),
394             |it| it.syntax_node_ptr().to_node(file.syntax()),
395         );
396         assert_eq!(local_name.text_range(), expected_name.syntax().text_range());
397     }
398
399     #[test]
400     fn test_resolve_local_name() {
401         do_check_local_name(
402             r#"
403 fn foo(x: i32, y: u32) {
404     {
405         let z = x * 2;
406     }
407     {
408         let t = x$0 * 3;
409     }
410 }
411 "#,
412             7,
413         );
414     }
415
416     #[test]
417     fn test_resolve_local_name_declaration() {
418         do_check_local_name(
419             r#"
420 fn foo(x: String) {
421     let x : &str = &x$0;
422 }
423 "#,
424             7,
425         );
426     }
427
428     #[test]
429     fn test_resolve_local_name_shadow() {
430         do_check_local_name(
431             r"
432 fn foo(x: String) {
433     let x : &str = &x;
434     x$0
435 }
436 ",
437             28,
438         );
439     }
440
441     #[test]
442     fn ref_patterns_contribute_bindings() {
443         do_check_local_name(
444             r"
445 fn foo() {
446     if let Some(&from) = bar() {
447         from$0;
448     }
449 }
450 ",
451             28,
452         );
453     }
454
455     #[test]
456     fn while_let_desugaring() {
457         mark::check!(infer_resolve_while_let);
458         do_check_local_name(
459             r#"
460 fn test() {
461     let foo: Option<f32> = None;
462     while let Option::Some(spam) = foo {
463         spam$0
464     }
465 }
466 "#,
467             75,
468         );
469     }
470 }