]> git.lizzy.rs Git - rust.git/blob - crates/hir_def/src/body/scope.rs
Merge #9879
[rust.git] / crates / hir_def / src / body / scope.rs
1 //! Name resolution for expressions.
2 use std::sync::Arc;
3
4 use hir_expand::name::Name;
5 use la_arena::{Arena, Idx};
6 use rustc_hash::FxHashMap;
7
8 use crate::{
9     body::Body,
10     db::DefDatabase,
11     expr::{Expr, ExprId, LabelId, MatchGuard, Pat, PatId, Statement},
12     BlockId, DefWithBodyId,
13 };
14
15 pub type ScopeId = Idx<ScopeData>;
16
17 #[derive(Debug, PartialEq, Eq)]
18 pub struct ExprScopes {
19     scopes: Arena<ScopeData>,
20     scope_by_expr: FxHashMap<ExprId, ScopeId>,
21 }
22
23 #[derive(Debug, PartialEq, Eq)]
24 pub struct ScopeEntry {
25     name: Name,
26     pat: PatId,
27 }
28
29 impl ScopeEntry {
30     pub fn name(&self) -> &Name {
31         &self.name
32     }
33
34     pub fn pat(&self) -> PatId {
35         self.pat
36     }
37 }
38
39 #[derive(Debug, PartialEq, Eq)]
40 pub struct ScopeData {
41     parent: Option<ScopeId>,
42     block: Option<BlockId>,
43     label: Option<(LabelId, Name)>,
44     entries: Vec<ScopeEntry>,
45 }
46
47 impl ExprScopes {
48     pub(crate) fn expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes> {
49         let body = db.body(def);
50         Arc::new(ExprScopes::new(&*body))
51     }
52
53     fn new(body: &Body) -> ExprScopes {
54         let mut scopes =
55             ExprScopes { scopes: Arena::default(), scope_by_expr: FxHashMap::default() };
56         let root = scopes.root_scope();
57         scopes.add_params_bindings(body, root, &body.params);
58         compute_expr_scopes(body.body_expr, body, &mut scopes, root);
59         scopes
60     }
61
62     pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
63         &self.scopes[scope].entries
64     }
65
66     /// If `scope` refers to a block expression scope, returns the corresponding `BlockId`.
67     pub fn block(&self, scope: ScopeId) -> Option<BlockId> {
68         self.scopes[scope].block
69     }
70
71     /// If `scope` refers to a labeled expression scope, returns the corresponding `Label`.
72     pub fn label(&self, scope: ScopeId) -> Option<(LabelId, Name)> {
73         self.scopes[scope].label.clone()
74     }
75
76     pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
77         std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
78     }
79
80     pub fn resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry> {
81         self.scope_chain(Some(scope))
82             .find_map(|scope| self.entries(scope).iter().find(|it| it.name == *name))
83     }
84
85     pub fn scope_for(&self, expr: ExprId) -> Option<ScopeId> {
86         self.scope_by_expr.get(&expr).copied()
87     }
88
89     pub fn scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId> {
90         &self.scope_by_expr
91     }
92
93     fn root_scope(&mut self) -> ScopeId {
94         self.scopes.alloc(ScopeData { parent: None, block: None, label: None, entries: vec![] })
95     }
96
97     fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
98         self.scopes.alloc(ScopeData {
99             parent: Some(parent),
100             block: None,
101             label: None,
102             entries: vec![],
103         })
104     }
105
106     fn new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId {
107         self.scopes.alloc(ScopeData { parent: Some(parent), block: None, label, entries: vec![] })
108     }
109
110     fn new_block_scope(
111         &mut self,
112         parent: ScopeId,
113         block: BlockId,
114         label: Option<(LabelId, Name)>,
115     ) -> ScopeId {
116         self.scopes.alloc(ScopeData {
117             parent: Some(parent),
118             block: Some(block),
119             label,
120             entries: vec![],
121         })
122     }
123
124     fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
125         let pattern = &body[pat];
126         if let Pat::Bind { name, .. } = pattern {
127             let entry = ScopeEntry { name: name.clone(), pat };
128             self.scopes[scope].entries.push(entry);
129         }
130
131         pattern.walk_child_pats(|pat| self.add_bindings(body, scope, pat));
132     }
133
134     fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) {
135         params.iter().for_each(|pat| self.add_bindings(body, scope, *pat));
136     }
137
138     fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
139         self.scope_by_expr.insert(node, scope);
140     }
141 }
142
143 fn compute_block_scopes(
144     statements: &[Statement],
145     tail: Option<ExprId>,
146     body: &Body,
147     scopes: &mut ExprScopes,
148     mut scope: ScopeId,
149 ) {
150     for stmt in statements {
151         match stmt {
152             Statement::Let { pat, initializer, .. } => {
153                 if let Some(expr) = initializer {
154                     scopes.set_scope(*expr, scope);
155                     compute_expr_scopes(*expr, body, scopes, scope);
156                 }
157                 scope = scopes.new_scope(scope);
158                 scopes.add_bindings(body, scope, *pat);
159             }
160             Statement::Expr { expr, .. } => {
161                 scopes.set_scope(*expr, scope);
162                 compute_expr_scopes(*expr, body, scopes, scope);
163             }
164         }
165     }
166     if let Some(expr) = tail {
167         compute_expr_scopes(expr, body, scopes, scope);
168     }
169 }
170
171 fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
172     let make_label =
173         |label: &Option<_>| label.map(|label| (label, body.labels[label].name.clone()));
174
175     scopes.set_scope(expr, scope);
176     match &body[expr] {
177         Expr::Block { statements, tail, id, label } => {
178             let scope = scopes.new_block_scope(scope, *id, make_label(label));
179             // Overwrite the old scope for the block expr, so that every block scope can be found
180             // via the block itself (important for blocks that only contain items, no expressions).
181             scopes.set_scope(expr, scope);
182             compute_block_scopes(statements, *tail, body, scopes, scope);
183         }
184         Expr::For { iterable, pat, body: body_expr, label } => {
185             compute_expr_scopes(*iterable, body, scopes, scope);
186             let scope = scopes.new_labeled_scope(scope, make_label(label));
187             scopes.add_bindings(body, scope, *pat);
188             compute_expr_scopes(*body_expr, body, scopes, scope);
189         }
190         Expr::While { condition, body: body_expr, label } => {
191             let scope = scopes.new_labeled_scope(scope, make_label(label));
192             compute_expr_scopes(*condition, body, scopes, scope);
193             compute_expr_scopes(*body_expr, body, scopes, scope);
194         }
195         Expr::Loop { body: body_expr, label } => {
196             let scope = scopes.new_labeled_scope(scope, make_label(label));
197             compute_expr_scopes(*body_expr, body, scopes, scope);
198         }
199         Expr::Lambda { args, body: body_expr, .. } => {
200             let scope = scopes.new_scope(scope);
201             scopes.add_params_bindings(body, scope, args);
202             compute_expr_scopes(*body_expr, body, scopes, scope);
203         }
204         Expr::Match { expr, arms } => {
205             compute_expr_scopes(*expr, body, scopes, scope);
206             for arm in arms {
207                 let mut scope = scopes.new_scope(scope);
208                 scopes.add_bindings(body, scope, arm.pat);
209                 match arm.guard {
210                     Some(MatchGuard::If { expr: guard }) => {
211                         scopes.set_scope(guard, scope);
212                         compute_expr_scopes(guard, body, scopes, scope);
213                     }
214                     Some(MatchGuard::IfLet { pat, expr: guard }) => {
215                         scopes.set_scope(guard, scope);
216                         compute_expr_scopes(guard, body, scopes, scope);
217                         scope = scopes.new_scope(scope);
218                         scopes.add_bindings(body, scope, pat);
219                     }
220                     _ => {}
221                 };
222                 scopes.set_scope(arm.expr, scope);
223                 compute_expr_scopes(arm.expr, body, scopes, scope);
224             }
225         }
226         e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
227     };
228 }
229
230 #[cfg(test)]
231 mod tests {
232     use base_db::{fixture::WithFixture, FileId, SourceDatabase};
233     use hir_expand::{name::AsName, InFile};
234     use syntax::{algo::find_node_at_offset, ast, AstNode};
235     use test_utils::{assert_eq_text, extract_offset};
236
237     use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId};
238
239     fn find_function(db: &TestDB, file_id: FileId) -> FunctionId {
240         let krate = db.test_crate();
241         let crate_def_map = db.crate_def_map(krate);
242
243         let module = crate_def_map.modules_for_file(file_id).next().unwrap();
244         let (_, def) = crate_def_map[module].scope.entries().next().unwrap();
245         match def.take_values().unwrap() {
246             ModuleDefId::FunctionId(it) => it,
247             _ => panic!(),
248         }
249     }
250
251     fn do_check(ra_fixture: &str, expected: &[&str]) {
252         let (offset, code) = extract_offset(ra_fixture);
253         let code = {
254             let mut buf = String::new();
255             let off: usize = offset.into();
256             buf.push_str(&code[..off]);
257             buf.push_str("$0marker");
258             buf.push_str(&code[off..]);
259             buf
260         };
261
262         let (db, position) = TestDB::with_position(&code);
263         let file_id = position.file_id;
264         let offset = position.offset;
265
266         let file_syntax = db.parse(file_id).syntax_node();
267         let marker: ast::PathExpr = find_node_at_offset(&file_syntax, offset).unwrap();
268         let function = find_function(&db, file_id);
269
270         let scopes = db.expr_scopes(function.into());
271         let (_body, source_map) = db.body_with_source_map(function.into());
272
273         let expr_id = source_map
274             .node_expr(InFile { file_id: file_id.into(), value: &marker.into() })
275             .unwrap();
276         let scope = scopes.scope_for(expr_id);
277
278         let actual = scopes
279             .scope_chain(scope)
280             .flat_map(|scope| scopes.entries(scope))
281             .map(|it| it.name().to_string())
282             .collect::<Vec<_>>()
283             .join("\n");
284         let expected = expected.join("\n");
285         assert_eq_text!(&expected, &actual);
286     }
287
288     #[test]
289     fn test_lambda_scope() {
290         do_check(
291             r"
292             fn quux(foo: i32) {
293                 let f = |bar, baz: i32| {
294                     $0
295                 };
296             }",
297             &["bar", "baz", "foo"],
298         );
299     }
300
301     #[test]
302     fn test_call_scope() {
303         do_check(
304             r"
305             fn quux() {
306                 f(|x| $0 );
307             }",
308             &["x"],
309         );
310     }
311
312     #[test]
313     fn test_method_call_scope() {
314         do_check(
315             r"
316             fn quux() {
317                 z.f(|x| $0 );
318             }",
319             &["x"],
320         );
321     }
322
323     #[test]
324     fn test_loop_scope() {
325         do_check(
326             r"
327             fn quux() {
328                 loop {
329                     let x = ();
330                     $0
331                 };
332             }",
333             &["x"],
334         );
335     }
336
337     #[test]
338     fn test_match() {
339         do_check(
340             r"
341             fn quux() {
342                 match () {
343                     Some(x) => {
344                         $0
345                     }
346                 };
347             }",
348             &["x"],
349         );
350     }
351
352     #[test]
353     fn test_shadow_variable() {
354         do_check(
355             r"
356             fn foo(x: String) {
357                 let x : &str = &x$0;
358             }",
359             &["x"],
360         );
361     }
362
363     #[test]
364     fn test_bindings_after_at() {
365         do_check(
366             r"
367 fn foo() {
368     match Some(()) {
369         opt @ Some(unit) => {
370             $0
371         }
372         _ => {}
373     }
374 }
375 ",
376             &["opt", "unit"],
377         );
378     }
379
380     #[test]
381     fn macro_inner_item() {
382         do_check(
383             r"
384             macro_rules! mac {
385                 () => {{
386                     fn inner() {}
387                     inner();
388                 }};
389             }
390
391             fn foo() {
392                 mac!();
393                 $0
394             }
395         ",
396             &[],
397         );
398     }
399
400     #[test]
401     fn broken_inner_item() {
402         do_check(
403             r"
404             fn foo() {
405                 trait {}
406                 $0
407             }
408         ",
409             &[],
410         );
411     }
412
413     fn do_check_local_name(ra_fixture: &str, expected_offset: u32) {
414         let (db, position) = TestDB::with_position(ra_fixture);
415         let file_id = position.file_id;
416         let offset = position.offset;
417
418         let file = db.parse(file_id).ok().unwrap();
419         let expected_name = find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into())
420             .expect("failed to find a name at the target offset");
421         let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), offset).unwrap();
422
423         let function = find_function(&db, file_id);
424
425         let scopes = db.expr_scopes(function.into());
426         let (_body, source_map) = db.body_with_source_map(function.into());
427
428         let expr_scope = {
429             let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap();
430             let expr_id =
431                 source_map.node_expr(InFile { file_id: file_id.into(), value: &expr_ast }).unwrap();
432             scopes.scope_for(expr_id).unwrap()
433         };
434
435         let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap();
436         let pat_src = source_map.pat_syntax(resolved.pat()).unwrap();
437
438         let local_name = pat_src.value.either(
439             |it| it.syntax_node_ptr().to_node(file.syntax()),
440             |it| it.syntax_node_ptr().to_node(file.syntax()),
441         );
442         assert_eq!(local_name.text_range(), expected_name.syntax().text_range());
443     }
444
445     #[test]
446     fn test_resolve_local_name() {
447         do_check_local_name(
448             r#"
449 fn foo(x: i32, y: u32) {
450     {
451         let z = x * 2;
452     }
453     {
454         let t = x$0 * 3;
455     }
456 }
457 "#,
458             7,
459         );
460     }
461
462     #[test]
463     fn test_resolve_local_name_declaration() {
464         do_check_local_name(
465             r#"
466 fn foo(x: String) {
467     let x : &str = &x$0;
468 }
469 "#,
470             7,
471         );
472     }
473
474     #[test]
475     fn test_resolve_local_name_shadow() {
476         do_check_local_name(
477             r"
478 fn foo(x: String) {
479     let x : &str = &x;
480     x$0
481 }
482 ",
483             28,
484         );
485     }
486
487     #[test]
488     fn ref_patterns_contribute_bindings() {
489         do_check_local_name(
490             r"
491 fn foo() {
492     if let Some(&from) = bar() {
493         from$0;
494     }
495 }
496 ",
497             28,
498         );
499     }
500
501     #[test]
502     fn while_let_desugaring() {
503         cov_mark::check!(infer_resolve_while_let);
504         do_check_local_name(
505             r#"
506 fn test() {
507     let foo: Option<f32> = None;
508     while let Option::Some(spam) = foo {
509         spam$0
510     }
511 }
512 "#,
513             75,
514         );
515     }
516 }