]> git.lizzy.rs Git - rust.git/blob - crates/ide/src/highlight_related.rs
Refine tail exit point highlighting to highlight inner tails
[rust.git] / crates / ide / src / highlight_related.rs
1 use hir::Semantics;
2 use ide_db::{
3     base_db::FilePosition,
4     defs::Definition,
5     helpers::pick_best_token,
6     search::{FileReference, ReferenceAccess, SearchScope},
7     RootDatabase,
8 };
9 use syntax::{
10     ast::{self, LoopBodyOwner},
11     match_ast, AstNode, SyntaxNode, SyntaxToken, TextRange, WalkEvent, T,
12 };
13
14 use crate::{display::TryToNav, references, NavigationTarget};
15
16 pub struct HighlightedRange {
17     pub range: TextRange,
18     pub access: Option<ReferenceAccess>,
19 }
20
21 // Feature: Highlight Related
22 //
23 // Highlights constructs related to the thing under the cursor:
24 // - if on an identifier, highlights all references to that identifier in the current file
25 // - if on an `async` or `await token, highlights all yield points for that async context
26 // - if on a `return` token, `?` character or `->` return type arrow, highlights all exit points for that context
27 pub(crate) fn highlight_related(
28     sema: &Semantics<RootDatabase>,
29     position: FilePosition,
30 ) -> Option<Vec<HighlightedRange>> {
31     let _p = profile::span("highlight_related");
32     let syntax = sema.parse(position.file_id).syntax().clone();
33
34     let token = pick_best_token(syntax.token_at_offset(position.offset), |kind| match kind {
35         T![?] => 2, // prefer `?` when the cursor is sandwiched like `await$0?`
36         T![await] | T![async] | T![return] | T![->] => 1,
37         _ => 0,
38     })?;
39
40     match token.kind() {
41         T![return] | T![?] | T![->] => highlight_exit_points(sema, token),
42         T![await] | T![async] => highlight_yield_points(token),
43         _ => highlight_references(sema, &syntax, position),
44     }
45 }
46
47 fn highlight_references(
48     sema: &Semantics<RootDatabase>,
49     syntax: &SyntaxNode,
50     FilePosition { offset, file_id }: FilePosition,
51 ) -> Option<Vec<HighlightedRange>> {
52     let def = references::find_def(sema, syntax, offset)?;
53     let usages = def.usages(sema).set_scope(Some(SearchScope::single_file(file_id))).all();
54
55     let declaration = match def {
56         Definition::ModuleDef(hir::ModuleDef::Module(module)) => {
57             Some(NavigationTarget::from_module_to_decl(sema.db, module))
58         }
59         def => def.try_to_nav(sema.db),
60     }
61     .filter(|decl| decl.file_id == file_id)
62     .and_then(|decl| {
63         let range = decl.focus_range?;
64         let access = references::decl_access(&def, syntax, range);
65         Some(HighlightedRange { range, access })
66     });
67
68     let file_refs = usages.references.get(&file_id).map_or(&[][..], Vec::as_slice);
69     let mut res = Vec::with_capacity(file_refs.len() + 1);
70     res.extend(declaration);
71     res.extend(
72         file_refs
73             .iter()
74             .map(|&FileReference { access, range, .. }| HighlightedRange { range, access }),
75     );
76     Some(res)
77 }
78
79 fn highlight_exit_points(
80     sema: &Semantics<RootDatabase>,
81     token: SyntaxToken,
82 ) -> Option<Vec<HighlightedRange>> {
83     fn hl(
84         sema: &Semantics<RootDatabase>,
85         body: Option<ast::Expr>,
86     ) -> Option<Vec<HighlightedRange>> {
87         let mut highlights = Vec::new();
88         let body = body?;
89         walk(&body, &mut |expr| match expr {
90             ast::Expr::ReturnExpr(expr) => {
91                 if let Some(token) = expr.return_token() {
92                     highlights.push(HighlightedRange { access: None, range: token.text_range() });
93                 }
94             }
95             ast::Expr::TryExpr(try_) => {
96                 if let Some(token) = try_.question_mark_token() {
97                     highlights.push(HighlightedRange { access: None, range: token.text_range() });
98                 }
99             }
100             ast::Expr::MethodCallExpr(_) | ast::Expr::CallExpr(_) | ast::Expr::MacroCall(_) => {
101                 if sema.type_of_expr(&expr).map_or(false, |ty| ty.is_never()) {
102                     highlights
103                         .push(HighlightedRange { access: None, range: expr.syntax().text_range() });
104                 }
105             }
106             _ => (),
107         });
108         let tail = match body {
109             ast::Expr::BlockExpr(b) => b.tail_expr(),
110             e => Some(e),
111         };
112
113         if let Some(tail) = tail {
114             for_each_inner_tail(&tail, &mut |tail| {
115                 highlights
116                     .push(HighlightedRange { access: None, range: tail.syntax().text_range() })
117             });
118         }
119         Some(highlights)
120     }
121     for anc in token.ancestors() {
122         return match_ast! {
123             match anc {
124                 ast::Fn(fn_) => hl(sema, fn_.body().map(ast::Expr::BlockExpr)),
125                 ast::ClosureExpr(closure) => hl(sema, closure.body()),
126                 ast::EffectExpr(effect) => if matches!(effect.effect(), ast::Effect::Async(_) | ast::Effect::Try(_)| ast::Effect::Const(_)) {
127                     hl(sema, effect.block_expr().map(ast::Expr::BlockExpr))
128                 } else {
129                     continue;
130                 },
131                 _ => continue,
132             }
133         };
134     }
135     None
136 }
137
138 fn highlight_yield_points(token: SyntaxToken) -> Option<Vec<HighlightedRange>> {
139     fn hl(
140         async_token: Option<SyntaxToken>,
141         body: Option<ast::Expr>,
142     ) -> Option<Vec<HighlightedRange>> {
143         let mut highlights = Vec::new();
144         highlights.push(HighlightedRange { access: None, range: async_token?.text_range() });
145         if let Some(body) = body {
146             walk(&body, &mut |expr| {
147                 if let ast::Expr::AwaitExpr(expr) = expr {
148                     if let Some(token) = expr.await_token() {
149                         highlights
150                             .push(HighlightedRange { access: None, range: token.text_range() });
151                     }
152                 }
153             });
154         }
155         Some(highlights)
156     }
157     for anc in token.ancestors() {
158         return match_ast! {
159             match anc {
160                 ast::Fn(fn_) => hl(fn_.async_token(), fn_.body().map(ast::Expr::BlockExpr)),
161                 ast::EffectExpr(effect) => hl(effect.async_token(), effect.block_expr().map(ast::Expr::BlockExpr)),
162                 ast::ClosureExpr(closure) => hl(closure.async_token(), closure.body()),
163                 _ => continue,
164             }
165         };
166     }
167     None
168 }
169
170 /// Preorder walk all the expression's child expressions
171 fn walk(expr: &ast::Expr, cb: &mut dyn FnMut(ast::Expr)) {
172     let mut preorder = expr.syntax().preorder();
173     while let Some(event) = preorder.next() {
174         let node = match event {
175             WalkEvent::Enter(node) => node,
176             WalkEvent::Leave(_) => continue,
177         };
178         match ast::Stmt::cast(node.clone()) {
179             // recursively walk the initializer, skipping potential const pat expressions
180             // lets statements aren't usually nested too deeply so this is fine to recurse on
181             Some(ast::Stmt::LetStmt(l)) => {
182                 if let Some(expr) = l.initializer() {
183                     walk(&expr, cb);
184                 }
185                 preorder.skip_subtree();
186             }
187             // Don't skip subtree since we want to process the expression child next
188             Some(ast::Stmt::ExprStmt(_)) => (),
189             // skip inner items which might have their own expressions
190             Some(ast::Stmt::Item(_)) => preorder.skip_subtree(),
191             None => {
192                 if let Some(expr) = ast::Expr::cast(node) {
193                     let is_different_context = match &expr {
194                         ast::Expr::EffectExpr(effect) => {
195                             matches!(
196                                 effect.effect(),
197                                 ast::Effect::Async(_) | ast::Effect::Try(_) | ast::Effect::Const(_)
198                             )
199                         }
200                         ast::Expr::ClosureExpr(__) => true,
201                         _ => false,
202                     };
203                     cb(expr);
204                     if is_different_context {
205                         preorder.skip_subtree();
206                     }
207                 } else {
208                     preorder.skip_subtree();
209                 }
210             }
211         }
212     }
213 }
214
215 // FIXME: doesn't account for labeled breaks in labeled blocks
216 fn for_each_inner_tail(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) {
217     match expr {
218         ast::Expr::BlockExpr(b) => {
219             if let Some(e) = b.tail_expr() {
220                 for_each_inner_tail(&e, cb);
221             }
222         }
223         ast::Expr::EffectExpr(e) => match e.effect() {
224             ast::Effect::Label(_) | ast::Effect::Unsafe(_) => {
225                 if let Some(e) = e.block_expr().and_then(|b| b.tail_expr()) {
226                     for_each_inner_tail(&e, cb);
227                 }
228             }
229             ast::Effect::Async(_) | ast::Effect::Try(_) | ast::Effect::Const(_) => cb(expr),
230         },
231         ast::Expr::IfExpr(if_) => {
232             if_.blocks().for_each(|block| for_each_inner_tail(&ast::Expr::BlockExpr(block), cb))
233         }
234         ast::Expr::LoopExpr(l) => for_each_break(l, cb),
235         ast::Expr::MatchExpr(m) => {
236             if let Some(arms) = m.match_arm_list() {
237                 arms.arms().filter_map(|arm| arm.expr()).for_each(|e| for_each_inner_tail(&e, cb));
238             }
239         }
240         ast::Expr::ArrayExpr(_)
241         | ast::Expr::AwaitExpr(_)
242         | ast::Expr::BinExpr(_)
243         | ast::Expr::BoxExpr(_)
244         | ast::Expr::BreakExpr(_)
245         | ast::Expr::CallExpr(_)
246         | ast::Expr::CastExpr(_)
247         | ast::Expr::ClosureExpr(_)
248         | ast::Expr::ContinueExpr(_)
249         | ast::Expr::FieldExpr(_)
250         | ast::Expr::ForExpr(_)
251         | ast::Expr::IndexExpr(_)
252         | ast::Expr::Literal(_)
253         | ast::Expr::MacroCall(_)
254         | ast::Expr::MacroStmts(_)
255         | ast::Expr::MethodCallExpr(_)
256         | ast::Expr::ParenExpr(_)
257         | ast::Expr::PathExpr(_)
258         | ast::Expr::PrefixExpr(_)
259         | ast::Expr::RangeExpr(_)
260         | ast::Expr::RecordExpr(_)
261         | ast::Expr::RefExpr(_)
262         | ast::Expr::ReturnExpr(_)
263         | ast::Expr::TryExpr(_)
264         | ast::Expr::TupleExpr(_)
265         | ast::Expr::WhileExpr(_)
266         | ast::Expr::YieldExpr(_) => cb(expr),
267     }
268 }
269
270 fn for_each_break(l: &ast::LoopExpr, cb: &mut dyn FnMut(&ast::Expr)) {
271     let label = l.label().and_then(|lbl| lbl.lifetime());
272     let mut depth = 0;
273     if let Some(b) = l.loop_body() {
274         let preorder = &mut b.syntax().preorder();
275         let ev_as_expr = |ev| match ev {
276             WalkEvent::Enter(it) => Some(WalkEvent::Enter(ast::Expr::cast(it)?)),
277             WalkEvent::Leave(it) => Some(WalkEvent::Leave(ast::Expr::cast(it)?)),
278         };
279         let eq_label = |lt: Option<ast::Lifetime>| {
280             lt.zip(label.as_ref()).map_or(false, |(lt, lbl)| lt.text() == lbl.text())
281         };
282         while let Some(node) = preorder.find_map(ev_as_expr) {
283             match node {
284                 WalkEvent::Enter(expr) => match &expr {
285                     ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_) => {
286                         depth += 1
287                     }
288                     ast::Expr::EffectExpr(e) if e.label().is_some() => depth += 1,
289                     ast::Expr::BreakExpr(b) if depth == 0 || eq_label(b.lifetime()) => {
290                         cb(&expr);
291                     }
292                     _ => (),
293                 },
294                 WalkEvent::Leave(expr) => match expr {
295                     ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_) => {
296                         depth -= 1
297                     }
298                     ast::Expr::EffectExpr(e) if e.label().is_some() => depth -= 1,
299                     _ => (),
300                 },
301             }
302         }
303     }
304 }
305
306 #[cfg(test)]
307 mod tests {
308     use crate::fixture;
309
310     use super::*;
311
312     fn check(ra_fixture: &str) {
313         let (analysis, pos, annotations) = fixture::annotations(ra_fixture);
314         let hls = analysis.highlight_related(pos).unwrap().unwrap();
315
316         let mut expected = annotations
317             .into_iter()
318             .map(|(r, access)| (r.range, (!access.is_empty()).then(|| access)))
319             .collect::<Vec<_>>();
320
321         let mut actual = hls
322             .into_iter()
323             .map(|hl| {
324                 (
325                     hl.range,
326                     hl.access.map(|it| {
327                         match it {
328                             ReferenceAccess::Read => "read",
329                             ReferenceAccess::Write => "write",
330                         }
331                         .to_string()
332                     }),
333                 )
334             })
335             .collect::<Vec<_>>();
336         actual.sort_by_key(|(range, _)| range.start());
337         expected.sort_by_key(|(range, _)| range.start());
338
339         assert_eq!(expected, actual);
340     }
341
342     #[test]
343     fn test_hl_module() {
344         check(
345             r#"
346 //- /lib.rs
347 mod foo$0;
348  // ^^^
349 //- /foo.rs
350 struct Foo;
351 "#,
352         );
353     }
354
355     #[test]
356     fn test_hl_self_in_crate_root() {
357         check(
358             r#"
359 use self$0;
360 "#,
361         );
362     }
363
364     #[test]
365     fn test_hl_self_in_module() {
366         check(
367             r#"
368 //- /lib.rs
369 mod foo;
370 //- /foo.rs
371 use self$0;
372 "#,
373         );
374     }
375
376     #[test]
377     fn test_hl_local() {
378         check(
379             r#"
380 fn foo() {
381     let mut bar = 3;
382          // ^^^ write
383     bar$0;
384  // ^^^ read
385 }
386 "#,
387         );
388     }
389
390     #[test]
391     fn test_hl_yield_points() {
392         check(
393             r#"
394 pub async fn foo() {
395  // ^^^^^
396     let x = foo()
397         .await$0
398       // ^^^^^
399         .await;
400       // ^^^^^
401     || { 0.await };
402     (async { 0.await }).await
403                      // ^^^^^
404 }
405 "#,
406         );
407     }
408
409     #[test]
410     fn test_hl_yield_points2() {
411         check(
412             r#"
413 pub async$0 fn foo() {
414  // ^^^^^
415     let x = foo()
416         .await
417       // ^^^^^
418         .await;
419       // ^^^^^
420     || { 0.await };
421     (async { 0.await }).await
422                      // ^^^^^
423 }
424 "#,
425         );
426     }
427
428     #[test]
429     fn test_hl_yield_nested_fn() {
430         check(
431             r#"
432 async fn foo() {
433     async fn foo2() {
434  // ^^^^^
435         async fn foo3() {
436             0.await
437         }
438         0.await$0
439        // ^^^^^
440     }
441     0.await
442 }
443 "#,
444         );
445     }
446
447     #[test]
448     fn test_hl_yield_nested_async_blocks() {
449         check(
450             r#"
451 async fn foo() {
452     (async {
453   // ^^^^^
454         (async {
455            0.await
456         }).await$0 }
457         // ^^^^^
458     ).await;
459 }
460 "#,
461         );
462     }
463
464     #[test]
465     fn test_hl_exit_points() {
466         check(
467             r#"
468 fn foo() -> u32 {
469     if true {
470         return$0 0;
471      // ^^^^^^
472     }
473
474     0?;
475   // ^
476     0xDEAD_BEEF
477  // ^^^^^^^^^^^
478 }
479 "#,
480         );
481     }
482
483     #[test]
484     fn test_hl_exit_points2() {
485         check(
486             r#"
487 fn foo() ->$0 u32 {
488     if true {
489         return 0;
490      // ^^^^^^
491     }
492
493     0?;
494   // ^
495     0xDEAD_BEEF
496  // ^^^^^^^^^^^
497 }
498 "#,
499         );
500     }
501
502     #[test]
503     fn test_hl_prefer_ref_over_tail_exit() {
504         check(
505             r#"
506 fn foo() -> u32 {
507 // ^^^
508     if true {
509         return 0;
510     }
511
512     0?;
513
514     foo$0()
515  // ^^^
516 }
517 "#,
518         );
519     }
520
521     #[test]
522     fn test_hl_never_call_is_exit_point() {
523         check(
524             r#"
525 struct Never;
526 impl Never {
527     fn never(self) -> ! { loop {} }
528 }
529 macro_rules! never {
530     () => { never() }
531 }
532 fn never() -> ! { loop {} }
533 fn foo() ->$0 u32 {
534     never();
535  // ^^^^^^^
536     never!();
537  // FIXME sema doesn't give us types for macrocalls
538
539     Never.never();
540  // ^^^^^^^^^^^^^
541
542     0
543  // ^
544 }
545 "#,
546         );
547     }
548
549     #[test]
550     fn test_hl_inner_tail_exit_points() {
551         check(
552             r#"
553 fn foo() ->$0 u32 {
554     if true {
555         unsafe {
556             return 5;
557          // ^^^^^^
558             5
559          // ^
560         }
561     } else {
562         match 5 {
563             6 => 100,
564               // ^^^
565             7 => loop {
566                 break 5;
567              // ^^^^^^^
568             }
569             8 => 'a: loop {
570                 'b: loop {
571                     break 'a 5;
572                  // ^^^^^^^^^^
573                     break 'b 5;
574                     break 5;
575                 };
576             }
577             //
578             _ => 500,
579               // ^^^
580         }
581     }
582 }
583 "#,
584         );
585     }
586 }