]> git.lizzy.rs Git - rust.git/blob - crates/hir_def/src/body/scope.rs
prepare to publish el libro de arena
[rust.git] / crates / hir_def / src / body / scope.rs
1 //! Name resolution for expressions.
2 use std::sync::Arc;
3
4 use hir_expand::name::Name;
5 use la_arena::{Arena, Idx};
6 use rustc_hash::FxHashMap;
7
8 use crate::{
9     body::Body,
10     db::DefDatabase,
11     expr::{Expr, ExprId, Pat, PatId, Statement},
12     DefWithBodyId,
13 };
14
15 pub type ScopeId = Idx<ScopeData>;
16
17 #[derive(Debug, PartialEq, Eq)]
18 pub struct ExprScopes {
19     scopes: Arena<ScopeData>,
20     scope_by_expr: FxHashMap<ExprId, ScopeId>,
21 }
22
23 #[derive(Debug, PartialEq, Eq)]
24 pub struct ScopeEntry {
25     name: Name,
26     pat: PatId,
27 }
28
29 impl ScopeEntry {
30     pub fn name(&self) -> &Name {
31         &self.name
32     }
33
34     pub fn pat(&self) -> PatId {
35         self.pat
36     }
37 }
38
39 #[derive(Debug, PartialEq, Eq)]
40 pub struct ScopeData {
41     parent: Option<ScopeId>,
42     entries: Vec<ScopeEntry>,
43 }
44
45 impl ExprScopes {
46     pub(crate) fn expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes> {
47         let body = db.body(def);
48         Arc::new(ExprScopes::new(&*body))
49     }
50
51     fn new(body: &Body) -> ExprScopes {
52         let mut scopes =
53             ExprScopes { scopes: Arena::default(), scope_by_expr: FxHashMap::default() };
54         let root = scopes.root_scope();
55         scopes.add_params_bindings(body, root, &body.params);
56         compute_expr_scopes(body.body_expr, body, &mut scopes, root);
57         scopes
58     }
59
60     pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
61         &self.scopes[scope].entries
62     }
63
64     pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
65         std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
66     }
67
68     pub fn resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry> {
69         self.scope_chain(Some(scope))
70             .find_map(|scope| self.entries(scope).iter().find(|it| it.name == *name))
71     }
72
73     pub fn scope_for(&self, expr: ExprId) -> Option<ScopeId> {
74         self.scope_by_expr.get(&expr).copied()
75     }
76
77     pub fn scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId> {
78         &self.scope_by_expr
79     }
80
81     fn root_scope(&mut self) -> ScopeId {
82         self.scopes.alloc(ScopeData { parent: None, entries: vec![] })
83     }
84
85     fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
86         self.scopes.alloc(ScopeData { parent: Some(parent), entries: vec![] })
87     }
88
89     fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
90         let pattern = &body[pat];
91         if let Pat::Bind { name, .. } = pattern {
92             let entry = ScopeEntry { name: name.clone(), pat };
93             self.scopes[scope].entries.push(entry);
94         }
95
96         pattern.walk_child_pats(|pat| self.add_bindings(body, scope, pat));
97     }
98
99     fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) {
100         params.iter().for_each(|pat| self.add_bindings(body, scope, *pat));
101     }
102
103     fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
104         self.scope_by_expr.insert(node, scope);
105     }
106 }
107
108 fn compute_block_scopes(
109     statements: &[Statement],
110     tail: Option<ExprId>,
111     body: &Body,
112     scopes: &mut ExprScopes,
113     mut scope: ScopeId,
114 ) {
115     for stmt in statements {
116         match stmt {
117             Statement::Let { pat, initializer, .. } => {
118                 if let Some(expr) = initializer {
119                     scopes.set_scope(*expr, scope);
120                     compute_expr_scopes(*expr, body, scopes, scope);
121                 }
122                 scope = scopes.new_scope(scope);
123                 scopes.add_bindings(body, scope, *pat);
124             }
125             Statement::Expr(expr) => {
126                 scopes.set_scope(*expr, scope);
127                 compute_expr_scopes(*expr, body, scopes, scope);
128             }
129         }
130     }
131     if let Some(expr) = tail {
132         compute_expr_scopes(expr, body, scopes, scope);
133     }
134 }
135
136 fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
137     scopes.set_scope(expr, scope);
138     match &body[expr] {
139         Expr::Block { statements, tail, .. } => {
140             compute_block_scopes(&statements, *tail, body, scopes, scope);
141         }
142         Expr::For { iterable, pat, body: body_expr, .. } => {
143             compute_expr_scopes(*iterable, body, scopes, scope);
144             let scope = scopes.new_scope(scope);
145             scopes.add_bindings(body, scope, *pat);
146             compute_expr_scopes(*body_expr, body, scopes, scope);
147         }
148         Expr::Lambda { args, body: body_expr, .. } => {
149             let scope = scopes.new_scope(scope);
150             scopes.add_params_bindings(body, scope, &args);
151             compute_expr_scopes(*body_expr, body, scopes, scope);
152         }
153         Expr::Match { expr, arms } => {
154             compute_expr_scopes(*expr, body, scopes, scope);
155             for arm in arms {
156                 let scope = scopes.new_scope(scope);
157                 scopes.add_bindings(body, scope, arm.pat);
158                 if let Some(guard) = arm.guard {
159                     scopes.set_scope(guard, scope);
160                     compute_expr_scopes(guard, body, scopes, scope);
161                 }
162                 scopes.set_scope(arm.expr, scope);
163                 compute_expr_scopes(arm.expr, body, scopes, scope);
164             }
165         }
166         e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
167     };
168 }
169
170 #[cfg(test)]
171 mod tests {
172     use base_db::{fixture::WithFixture, FileId, SourceDatabase};
173     use hir_expand::{name::AsName, InFile};
174     use syntax::{algo::find_node_at_offset, ast, AstNode};
175     use test_utils::{assert_eq_text, extract_offset, mark};
176
177     use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId};
178
179     fn find_function(db: &TestDB, file_id: FileId) -> FunctionId {
180         let krate = db.test_crate();
181         let crate_def_map = db.crate_def_map(krate);
182
183         let module = crate_def_map.modules_for_file(file_id).next().unwrap();
184         let (_, def) = crate_def_map[module].scope.entries().next().unwrap();
185         match def.take_values().unwrap() {
186             ModuleDefId::FunctionId(it) => it,
187             _ => panic!(),
188         }
189     }
190
191     fn do_check(ra_fixture: &str, expected: &[&str]) {
192         let (offset, code) = extract_offset(ra_fixture);
193         let code = {
194             let mut buf = String::new();
195             let off: usize = offset.into();
196             buf.push_str(&code[..off]);
197             buf.push_str("$0marker");
198             buf.push_str(&code[off..]);
199             buf
200         };
201
202         let (db, position) = TestDB::with_position(&code);
203         let file_id = position.file_id;
204         let offset = position.offset;
205
206         let file_syntax = db.parse(file_id).syntax_node();
207         let marker: ast::PathExpr = find_node_at_offset(&file_syntax, offset).unwrap();
208         let function = find_function(&db, file_id);
209
210         let scopes = db.expr_scopes(function.into());
211         let (_body, source_map) = db.body_with_source_map(function.into());
212
213         let expr_id = source_map
214             .node_expr(InFile { file_id: file_id.into(), value: &marker.into() })
215             .unwrap();
216         let scope = scopes.scope_for(expr_id);
217
218         let actual = scopes
219             .scope_chain(scope)
220             .flat_map(|scope| scopes.entries(scope))
221             .map(|it| it.name().to_string())
222             .collect::<Vec<_>>()
223             .join("\n");
224         let expected = expected.join("\n");
225         assert_eq_text!(&expected, &actual);
226     }
227
228     #[test]
229     fn test_lambda_scope() {
230         do_check(
231             r"
232             fn quux(foo: i32) {
233                 let f = |bar, baz: i32| {
234                     $0
235                 };
236             }",
237             &["bar", "baz", "foo"],
238         );
239     }
240
241     #[test]
242     fn test_call_scope() {
243         do_check(
244             r"
245             fn quux() {
246                 f(|x| $0 );
247             }",
248             &["x"],
249         );
250     }
251
252     #[test]
253     fn test_method_call_scope() {
254         do_check(
255             r"
256             fn quux() {
257                 z.f(|x| $0 );
258             }",
259             &["x"],
260         );
261     }
262
263     #[test]
264     fn test_loop_scope() {
265         do_check(
266             r"
267             fn quux() {
268                 loop {
269                     let x = ();
270                     $0
271                 };
272             }",
273             &["x"],
274         );
275     }
276
277     #[test]
278     fn test_match() {
279         do_check(
280             r"
281             fn quux() {
282                 match () {
283                     Some(x) => {
284                         $0
285                     }
286                 };
287             }",
288             &["x"],
289         );
290     }
291
292     #[test]
293     fn test_shadow_variable() {
294         do_check(
295             r"
296             fn foo(x: String) {
297                 let x : &str = &x$0;
298             }",
299             &["x"],
300         );
301     }
302
303     #[test]
304     fn test_bindings_after_at() {
305         do_check(
306             r"
307 fn foo() {
308     match Some(()) {
309         opt @ Some(unit) => {
310             $0
311         }
312         _ => {}
313     }
314 }
315 ",
316             &["opt", "unit"],
317         );
318     }
319
320     #[test]
321     fn macro_inner_item() {
322         do_check(
323             r"
324             macro_rules! mac {
325                 () => {{
326                     fn inner() {}
327                     inner();
328                 }};
329             }
330
331             fn foo() {
332                 mac!();
333                 $0
334             }
335         ",
336             &[],
337         );
338     }
339
340     #[test]
341     fn broken_inner_item() {
342         do_check(
343             r"
344             fn foo() {
345                 trait {}
346                 $0
347             }
348         ",
349             &[],
350         );
351     }
352
353     fn do_check_local_name(ra_fixture: &str, expected_offset: u32) {
354         let (db, position) = TestDB::with_position(ra_fixture);
355         let file_id = position.file_id;
356         let offset = position.offset;
357
358         let file = db.parse(file_id).ok().unwrap();
359         let expected_name = find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into())
360             .expect("failed to find a name at the target offset");
361         let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), offset).unwrap();
362
363         let function = find_function(&db, file_id);
364
365         let scopes = db.expr_scopes(function.into());
366         let (_body, source_map) = db.body_with_source_map(function.into());
367
368         let expr_scope = {
369             let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap();
370             let expr_id =
371                 source_map.node_expr(InFile { file_id: file_id.into(), value: &expr_ast }).unwrap();
372             scopes.scope_for(expr_id).unwrap()
373         };
374
375         let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap();
376         let pat_src = source_map.pat_syntax(resolved.pat()).unwrap();
377
378         let local_name = pat_src.value.either(
379             |it| it.syntax_node_ptr().to_node(file.syntax()),
380             |it| it.syntax_node_ptr().to_node(file.syntax()),
381         );
382         assert_eq!(local_name.text_range(), expected_name.syntax().text_range());
383     }
384
385     #[test]
386     fn test_resolve_local_name() {
387         do_check_local_name(
388             r#"
389 fn foo(x: i32, y: u32) {
390     {
391         let z = x * 2;
392     }
393     {
394         let t = x$0 * 3;
395     }
396 }
397 "#,
398             7,
399         );
400     }
401
402     #[test]
403     fn test_resolve_local_name_declaration() {
404         do_check_local_name(
405             r#"
406 fn foo(x: String) {
407     let x : &str = &x$0;
408 }
409 "#,
410             7,
411         );
412     }
413
414     #[test]
415     fn test_resolve_local_name_shadow() {
416         do_check_local_name(
417             r"
418 fn foo(x: String) {
419     let x : &str = &x;
420     x$0
421 }
422 ",
423             28,
424         );
425     }
426
427     #[test]
428     fn ref_patterns_contribute_bindings() {
429         do_check_local_name(
430             r"
431 fn foo() {
432     if let Some(&from) = bar() {
433         from$0;
434     }
435 }
436 ",
437             28,
438         );
439     }
440
441     #[test]
442     fn while_let_desugaring() {
443         mark::check!(infer_resolve_while_let);
444         do_check_local_name(
445             r#"
446 fn test() {
447     let foo: Option<f32> = None;
448     while let Option::Some(spam) = foo {
449         spam$0
450     }
451 }
452 "#,
453             75,
454         );
455     }
456 }