]> git.lizzy.rs Git - rust.git/blob - crates/hir_def/src/body/scope.rs
parameters.split_last()
[rust.git] / crates / hir_def / src / body / scope.rs
1 //! Name resolution for expressions.
2 use std::sync::Arc;
3
4 use hir_expand::name::Name;
5 use la_arena::{Arena, Idx};
6 use rustc_hash::FxHashMap;
7
8 use crate::{
9     body::Body,
10     db::DefDatabase,
11     expr::{Expr, ExprId, LabelId, MatchGuard, Pat, PatId, Statement},
12     BlockId, DefWithBodyId,
13 };
14
15 pub type ScopeId = Idx<ScopeData>;
16
17 #[derive(Debug, PartialEq, Eq)]
18 pub struct ExprScopes {
19     scopes: Arena<ScopeData>,
20     scope_by_expr: FxHashMap<ExprId, ScopeId>,
21 }
22
23 #[derive(Debug, PartialEq, Eq)]
24 pub struct ScopeEntry {
25     name: Name,
26     pat: PatId,
27 }
28
29 impl ScopeEntry {
30     pub fn name(&self) -> &Name {
31         &self.name
32     }
33
34     pub fn pat(&self) -> PatId {
35         self.pat
36     }
37 }
38
39 #[derive(Debug, PartialEq, Eq)]
40 pub struct ScopeData {
41     parent: Option<ScopeId>,
42     block: Option<BlockId>,
43     label: Option<(LabelId, Name)>,
44     entries: Vec<ScopeEntry>,
45 }
46
47 impl ExprScopes {
48     pub(crate) fn expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes> {
49         let body = db.body(def);
50         Arc::new(ExprScopes::new(&*body))
51     }
52
53     fn new(body: &Body) -> ExprScopes {
54         let mut scopes =
55             ExprScopes { scopes: Arena::default(), scope_by_expr: FxHashMap::default() };
56         let root = scopes.root_scope();
57         scopes.add_params_bindings(body, root, &body.params);
58         compute_expr_scopes(body.body_expr, body, &mut scopes, root);
59         scopes
60     }
61
62     pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
63         &self.scopes[scope].entries
64     }
65
66     /// If `scope` refers to a block expression scope, returns the corresponding `BlockId`.
67     pub fn block(&self, scope: ScopeId) -> Option<BlockId> {
68         self.scopes[scope].block
69     }
70
71     /// If `scope` refers to a labeled expression scope, returns the corresponding `Label`.
72     pub fn label(&self, scope: ScopeId) -> Option<(LabelId, Name)> {
73         self.scopes[scope].label.clone()
74     }
75
76     pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
77         std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
78     }
79
80     pub fn resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry> {
81         self.scope_chain(Some(scope))
82             .find_map(|scope| self.entries(scope).iter().find(|it| it.name == *name))
83     }
84
85     pub fn scope_for(&self, expr: ExprId) -> Option<ScopeId> {
86         self.scope_by_expr.get(&expr).copied()
87     }
88
89     pub fn scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId> {
90         &self.scope_by_expr
91     }
92
93     fn root_scope(&mut self) -> ScopeId {
94         self.scopes.alloc(ScopeData { parent: None, block: None, label: None, entries: vec![] })
95     }
96
97     fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
98         self.scopes.alloc(ScopeData {
99             parent: Some(parent),
100             block: None,
101             label: None,
102             entries: vec![],
103         })
104     }
105
106     fn new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId {
107         self.scopes.alloc(ScopeData { parent: Some(parent), block: None, label, entries: vec![] })
108     }
109
110     fn new_block_scope(
111         &mut self,
112         parent: ScopeId,
113         block: BlockId,
114         label: Option<(LabelId, Name)>,
115     ) -> ScopeId {
116         self.scopes.alloc(ScopeData {
117             parent: Some(parent),
118             block: Some(block),
119             label,
120             entries: vec![],
121         })
122     }
123
124     fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
125         let pattern = &body[pat];
126         if let Pat::Bind { name, .. } = pattern {
127             let entry = ScopeEntry { name: name.clone(), pat };
128             self.scopes[scope].entries.push(entry);
129         }
130
131         pattern.walk_child_pats(|pat| self.add_bindings(body, scope, pat));
132     }
133
134     fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) {
135         params.iter().for_each(|pat| self.add_bindings(body, scope, *pat));
136     }
137
138     fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
139         self.scope_by_expr.insert(node, scope);
140     }
141 }
142
143 fn compute_block_scopes(
144     statements: &[Statement],
145     tail: Option<ExprId>,
146     body: &Body,
147     scopes: &mut ExprScopes,
148     mut scope: ScopeId,
149 ) {
150     for stmt in statements {
151         match stmt {
152             Statement::Let { pat, initializer, else_branch, .. } => {
153                 if let Some(expr) = initializer {
154                     compute_expr_scopes(*expr, body, scopes, scope);
155                 }
156                 if let Some(expr) = else_branch {
157                     compute_expr_scopes(*expr, body, scopes, scope);
158                 }
159                 scope = scopes.new_scope(scope);
160                 scopes.add_bindings(body, scope, *pat);
161             }
162             Statement::Expr { expr, .. } => {
163                 compute_expr_scopes(*expr, body, scopes, scope);
164             }
165         }
166     }
167     if let Some(expr) = tail {
168         compute_expr_scopes(expr, body, scopes, scope);
169     }
170 }
171
172 fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
173     let make_label =
174         |label: &Option<LabelId>| label.map(|label| (label, body.labels[label].name.clone()));
175
176     scopes.set_scope(expr, scope);
177     match &body[expr] {
178         Expr::Block { statements, tail, id, label } => {
179             let scope = scopes.new_block_scope(scope, *id, make_label(label));
180             // Overwrite the old scope for the block expr, so that every block scope can be found
181             // via the block itself (important for blocks that only contain items, no expressions).
182             scopes.set_scope(expr, scope);
183             compute_block_scopes(statements, *tail, body, scopes, scope);
184         }
185         Expr::For { iterable, pat, body: body_expr, label } => {
186             compute_expr_scopes(*iterable, body, scopes, scope);
187             let scope = scopes.new_labeled_scope(scope, make_label(label));
188             scopes.add_bindings(body, scope, *pat);
189             compute_expr_scopes(*body_expr, body, scopes, scope);
190         }
191         Expr::While { condition, body: body_expr, label } => {
192             let scope = scopes.new_labeled_scope(scope, make_label(label));
193             compute_expr_scopes(*condition, body, scopes, scope);
194             compute_expr_scopes(*body_expr, body, scopes, scope);
195         }
196         Expr::Loop { body: body_expr, label } => {
197             let scope = scopes.new_labeled_scope(scope, make_label(label));
198             compute_expr_scopes(*body_expr, body, scopes, scope);
199         }
200         Expr::Lambda { args, body: body_expr, .. } => {
201             let scope = scopes.new_scope(scope);
202             scopes.add_params_bindings(body, scope, args);
203             compute_expr_scopes(*body_expr, body, scopes, scope);
204         }
205         Expr::Match { expr, arms } => {
206             compute_expr_scopes(*expr, body, scopes, scope);
207             for arm in arms.iter() {
208                 let mut scope = scopes.new_scope(scope);
209                 scopes.add_bindings(body, scope, arm.pat);
210                 match arm.guard {
211                     Some(MatchGuard::If { expr: guard }) => {
212                         scopes.set_scope(guard, scope);
213                         compute_expr_scopes(guard, body, scopes, scope);
214                     }
215                     Some(MatchGuard::IfLet { pat, expr: guard }) => {
216                         scopes.set_scope(guard, scope);
217                         compute_expr_scopes(guard, body, scopes, scope);
218                         scope = scopes.new_scope(scope);
219                         scopes.add_bindings(body, scope, pat);
220                     }
221                     _ => {}
222                 };
223                 scopes.set_scope(arm.expr, scope);
224                 compute_expr_scopes(arm.expr, body, scopes, scope);
225             }
226         }
227         e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
228     };
229 }
230
231 #[cfg(test)]
232 mod tests {
233     use base_db::{fixture::WithFixture, FileId, SourceDatabase};
234     use hir_expand::{name::AsName, InFile};
235     use syntax::{algo::find_node_at_offset, ast, AstNode};
236     use test_utils::{assert_eq_text, extract_offset};
237
238     use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId};
239
240     fn find_function(db: &TestDB, file_id: FileId) -> FunctionId {
241         let krate = db.test_crate();
242         let crate_def_map = db.crate_def_map(krate);
243
244         let module = crate_def_map.modules_for_file(file_id).next().unwrap();
245         let (_, def) = crate_def_map[module].scope.entries().next().unwrap();
246         match def.take_values().unwrap() {
247             ModuleDefId::FunctionId(it) => it,
248             _ => panic!(),
249         }
250     }
251
252     fn do_check(ra_fixture: &str, expected: &[&str]) {
253         let (offset, code) = extract_offset(ra_fixture);
254         let code = {
255             let mut buf = String::new();
256             let off: usize = offset.into();
257             buf.push_str(&code[..off]);
258             buf.push_str("$0marker");
259             buf.push_str(&code[off..]);
260             buf
261         };
262
263         let (db, position) = TestDB::with_position(&code);
264         let file_id = position.file_id;
265         let offset = position.offset;
266
267         let file_syntax = db.parse(file_id).syntax_node();
268         let marker: ast::PathExpr = find_node_at_offset(&file_syntax, offset).unwrap();
269         let function = find_function(&db, file_id);
270
271         let scopes = db.expr_scopes(function.into());
272         let (_body, source_map) = db.body_with_source_map(function.into());
273
274         let expr_id = source_map
275             .node_expr(InFile { file_id: file_id.into(), value: &marker.into() })
276             .unwrap();
277         let scope = scopes.scope_for(expr_id);
278
279         let actual = scopes
280             .scope_chain(scope)
281             .flat_map(|scope| scopes.entries(scope))
282             .map(|it| it.name().to_smol_str())
283             .collect::<Vec<_>>()
284             .join("\n");
285         let expected = expected.join("\n");
286         assert_eq_text!(&expected, &actual);
287     }
288
289     #[test]
290     fn test_lambda_scope() {
291         do_check(
292             r"
293             fn quux(foo: i32) {
294                 let f = |bar, baz: i32| {
295                     $0
296                 };
297             }",
298             &["bar", "baz", "foo"],
299         );
300     }
301
302     #[test]
303     fn test_call_scope() {
304         do_check(
305             r"
306             fn quux() {
307                 f(|x| $0 );
308             }",
309             &["x"],
310         );
311     }
312
313     #[test]
314     fn test_method_call_scope() {
315         do_check(
316             r"
317             fn quux() {
318                 z.f(|x| $0 );
319             }",
320             &["x"],
321         );
322     }
323
324     #[test]
325     fn test_loop_scope() {
326         do_check(
327             r"
328             fn quux() {
329                 loop {
330                     let x = ();
331                     $0
332                 };
333             }",
334             &["x"],
335         );
336     }
337
338     #[test]
339     fn test_match() {
340         do_check(
341             r"
342             fn quux() {
343                 match () {
344                     Some(x) => {
345                         $0
346                     }
347                 };
348             }",
349             &["x"],
350         );
351     }
352
353     #[test]
354     fn test_shadow_variable() {
355         do_check(
356             r"
357             fn foo(x: String) {
358                 let x : &str = &x$0;
359             }",
360             &["x"],
361         );
362     }
363
364     #[test]
365     fn test_bindings_after_at() {
366         do_check(
367             r"
368 fn foo() {
369     match Some(()) {
370         opt @ Some(unit) => {
371             $0
372         }
373         _ => {}
374     }
375 }
376 ",
377             &["opt", "unit"],
378         );
379     }
380
381     #[test]
382     fn macro_inner_item() {
383         do_check(
384             r"
385             macro_rules! mac {
386                 () => {{
387                     fn inner() {}
388                     inner();
389                 }};
390             }
391
392             fn foo() {
393                 mac!();
394                 $0
395             }
396         ",
397             &[],
398         );
399     }
400
401     #[test]
402     fn broken_inner_item() {
403         do_check(
404             r"
405             fn foo() {
406                 trait {}
407                 $0
408             }
409         ",
410             &[],
411         );
412     }
413
414     fn do_check_local_name(ra_fixture: &str, expected_offset: u32) {
415         let (db, position) = TestDB::with_position(ra_fixture);
416         let file_id = position.file_id;
417         let offset = position.offset;
418
419         let file = db.parse(file_id).ok().unwrap();
420         let expected_name = find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into())
421             .expect("failed to find a name at the target offset");
422         let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), offset).unwrap();
423
424         let function = find_function(&db, file_id);
425
426         let scopes = db.expr_scopes(function.into());
427         let (_body, source_map) = db.body_with_source_map(function.into());
428
429         let expr_scope = {
430             let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap();
431             let expr_id =
432                 source_map.node_expr(InFile { file_id: file_id.into(), value: &expr_ast }).unwrap();
433             scopes.scope_for(expr_id).unwrap()
434         };
435
436         let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap();
437         let pat_src = source_map.pat_syntax(resolved.pat()).unwrap();
438
439         let local_name = pat_src.value.either(
440             |it| it.syntax_node_ptr().to_node(file.syntax()),
441             |it| it.syntax_node_ptr().to_node(file.syntax()),
442         );
443         assert_eq!(local_name.text_range(), expected_name.syntax().text_range());
444     }
445
446     #[test]
447     fn test_resolve_local_name() {
448         do_check_local_name(
449             r#"
450 fn foo(x: i32, y: u32) {
451     {
452         let z = x * 2;
453     }
454     {
455         let t = x$0 * 3;
456     }
457 }
458 "#,
459             7,
460         );
461     }
462
463     #[test]
464     fn test_resolve_local_name_declaration() {
465         do_check_local_name(
466             r#"
467 fn foo(x: String) {
468     let x : &str = &x$0;
469 }
470 "#,
471             7,
472         );
473     }
474
475     #[test]
476     fn test_resolve_local_name_shadow() {
477         do_check_local_name(
478             r"
479 fn foo(x: String) {
480     let x : &str = &x;
481     x$0
482 }
483 ",
484             28,
485         );
486     }
487
488     #[test]
489     fn ref_patterns_contribute_bindings() {
490         do_check_local_name(
491             r"
492 fn foo() {
493     if let Some(&from) = bar() {
494         from$0;
495     }
496 }
497 ",
498             28,
499         );
500     }
501
502     #[test]
503     fn while_let_desugaring() {
504         cov_mark::check!(infer_resolve_while_let);
505         do_check_local_name(
506             r#"
507 fn test() {
508     let foo: Option<f32> = None;
509     while let Option::Some(spam) = foo {
510         spam$0
511     }
512 }
513 "#,
514             75,
515         );
516     }
517 }