]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/unnecessary_async.rs
Auto merge of #103913 - Neutron3529:patch-1, r=thomcc
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / unnecessary_async.rs
1 use ide_db::{
2     assists::{AssistId, AssistKind},
3     base_db::FileId,
4     defs::Definition,
5     search::FileReference,
6     syntax_helpers::node_ext::full_path_of_name_ref,
7 };
8 use syntax::{
9     ast::{self, NameLike, NameRef},
10     AstNode, SyntaxKind, TextRange,
11 };
12
13 use crate::{AssistContext, Assists};
14
15 // Assist: unnecessary_async
16 //
17 // Removes the `async` mark from functions which have no `.await` in their body.
18 // Looks for calls to the functions and removes the `.await` on the call site.
19 //
20 // ```
21 // pub async f$0n foo() {}
22 // pub async fn bar() { foo().await }
23 // ```
24 // ->
25 // ```
26 // pub fn foo() {}
27 // pub async fn bar() { foo() }
28 // ```
29 pub(crate) fn unnecessary_async(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
30     let function: ast::Fn = ctx.find_node_at_offset()?;
31
32     // Do nothing if the cursor is not on the prototype. This is so that the check does not pollute
33     // when the user asks us for assists when in the middle of the function body.
34     // We consider the prototype to be anything that is before the body of the function.
35     let cursor_position = ctx.offset();
36     if cursor_position >= function.body()?.syntax().text_range().start() {
37         return None;
38     }
39     // Do nothing if the function isn't async.
40     if let None = function.async_token() {
41         return None;
42     }
43     // Do nothing if the function has an `await` expression in its body.
44     if function.body()?.syntax().descendants().find_map(ast::AwaitExpr::cast).is_some() {
45         return None;
46     }
47     // Do nothing if the method is a member of trait.
48     if let Some(impl_) = function.syntax().ancestors().nth(2).and_then(ast::Impl::cast) {
49         if let Some(_) = impl_.trait_() {
50             return None;
51         }
52     }
53
54     // Remove the `async` keyword plus whitespace after it, if any.
55     let async_range = {
56         let async_token = function.async_token()?;
57         let next_token = async_token.next_token()?;
58         if matches!(next_token.kind(), SyntaxKind::WHITESPACE) {
59             TextRange::new(async_token.text_range().start(), next_token.text_range().end())
60         } else {
61             async_token.text_range()
62         }
63     };
64
65     // Otherwise, we may remove the `async` keyword.
66     acc.add(
67         AssistId("unnecessary_async", AssistKind::QuickFix),
68         "Remove unnecessary async",
69         async_range,
70         |edit| {
71             // Remove async on the function definition.
72             edit.replace(async_range, "");
73
74             // Remove all `.await`s from calls to the function we remove `async` from.
75             if let Some(fn_def) = ctx.sema.to_def(&function) {
76                 for await_expr in find_all_references(ctx, &Definition::Function(fn_def))
77                     // Keep only references that correspond NameRefs.
78                     .filter_map(|(_, reference)| match reference.name {
79                         NameLike::NameRef(nameref) => Some(nameref),
80                         _ => None,
81                     })
82                     // Keep only references that correspond to await expressions
83                     .filter_map(|nameref| find_await_expression(ctx, &nameref))
84                 {
85                     if let Some(await_token) = &await_expr.await_token() {
86                         edit.replace(await_token.text_range(), "");
87                     }
88                     if let Some(dot_token) = &await_expr.dot_token() {
89                         edit.replace(dot_token.text_range(), "");
90                     }
91                 }
92             }
93         },
94     )
95 }
96
97 fn find_all_references(
98     ctx: &AssistContext<'_>,
99     def: &Definition,
100 ) -> impl Iterator<Item = (FileId, FileReference)> {
101     def.usages(&ctx.sema).all().into_iter().flat_map(|(file_id, references)| {
102         references.into_iter().map(move |reference| (file_id, reference))
103     })
104 }
105
106 /// Finds the await expression for the given `NameRef`.
107 /// If no await expression is found, returns None.
108 fn find_await_expression(ctx: &AssistContext<'_>, nameref: &NameRef) -> Option<ast::AwaitExpr> {
109     // From the nameref, walk up the tree to the await expression.
110     let await_expr = if let Some(path) = full_path_of_name_ref(&nameref) {
111         // Function calls.
112         path.syntax()
113             .parent()
114             .and_then(ast::PathExpr::cast)?
115             .syntax()
116             .parent()
117             .and_then(ast::CallExpr::cast)?
118             .syntax()
119             .parent()
120             .and_then(ast::AwaitExpr::cast)
121     } else {
122         // Method calls.
123         nameref
124             .syntax()
125             .parent()
126             .and_then(ast::MethodCallExpr::cast)?
127             .syntax()
128             .parent()
129             .and_then(ast::AwaitExpr::cast)
130     };
131
132     ctx.sema.original_ast_node(await_expr?)
133 }
134
135 #[cfg(test)]
136 mod tests {
137     use super::*;
138
139     use crate::tests::{check_assist, check_assist_not_applicable};
140
141     #[test]
142     fn applies_on_empty_function() {
143         check_assist(unnecessary_async, "pub async f$0n f() {}", "pub fn f() {}")
144     }
145
146     #[test]
147     fn applies_and_removes_whitespace() {
148         check_assist(unnecessary_async, "pub async       f$0n f() {}", "pub fn f() {}")
149     }
150
151     #[test]
152     fn does_not_apply_on_non_async_function() {
153         check_assist_not_applicable(unnecessary_async, "pub f$0n f() {}")
154     }
155
156     #[test]
157     fn applies_on_function_with_a_non_await_expr() {
158         check_assist(unnecessary_async, "pub async f$0n f() { f2() }", "pub fn f() { f2() }")
159     }
160
161     #[test]
162     fn does_not_apply_on_function_with_an_await_expr() {
163         check_assist_not_applicable(unnecessary_async, "pub async f$0n f() { f2().await }")
164     }
165
166     #[test]
167     fn applies_and_removes_await_on_reference() {
168         check_assist(
169             unnecessary_async,
170             r#"
171 pub async fn f4() { }
172 pub async f$0n f2() { }
173 pub async fn f() { f2().await }
174 pub async fn f3() { f2().await }"#,
175             r#"
176 pub async fn f4() { }
177 pub fn f2() { }
178 pub async fn f() { f2() }
179 pub async fn f3() { f2() }"#,
180         )
181     }
182
183     #[test]
184     fn applies_and_removes_await_from_within_module() {
185         check_assist(
186             unnecessary_async,
187             r#"
188 pub async fn f4() { }
189 mod a { pub async f$0n f2() { } }
190 pub async fn f() { a::f2().await }
191 pub async fn f3() { a::f2().await }"#,
192             r#"
193 pub async fn f4() { }
194 mod a { pub fn f2() { } }
195 pub async fn f() { a::f2() }
196 pub async fn f3() { a::f2() }"#,
197         )
198     }
199
200     #[test]
201     fn applies_and_removes_await_on_inner_await() {
202         check_assist(
203             unnecessary_async,
204             // Ensure that it is the first await on the 3rd line that is removed
205             r#"
206 pub async fn f() { f2().await }
207 pub async f$0n f2() -> i32 { 1 }
208 pub async fn f3() { f4(f2().await).await }
209 pub async fn f4(i: i32) { }"#,
210             r#"
211 pub async fn f() { f2() }
212 pub fn f2() -> i32 { 1 }
213 pub async fn f3() { f4(f2()).await }
214 pub async fn f4(i: i32) { }"#,
215         )
216     }
217
218     #[test]
219     fn applies_and_removes_await_on_outer_await() {
220         check_assist(
221             unnecessary_async,
222             // Ensure that it is the second await on the 3rd line that is removed
223             r#"
224 pub async fn f() { f2().await }
225 pub async f$0n f2(i: i32) { }
226 pub async fn f3() { f2(f4().await).await }
227 pub async fn f4() -> i32 { 1 }"#,
228             r#"
229 pub async fn f() { f2() }
230 pub fn f2(i: i32) { }
231 pub async fn f3() { f2(f4().await) }
232 pub async fn f4() -> i32 { 1 }"#,
233         )
234     }
235
236     #[test]
237     fn applies_on_method_call() {
238         check_assist(
239             unnecessary_async,
240             r#"
241 pub struct S { }
242 impl S { pub async f$0n f2(&self) { } }
243 pub async fn f(s: &S) { s.f2().await }"#,
244             r#"
245 pub struct S { }
246 impl S { pub fn f2(&self) { } }
247 pub async fn f(s: &S) { s.f2() }"#,
248         )
249     }
250
251     #[test]
252     fn does_not_apply_on_function_with_a_nested_await_expr() {
253         check_assist_not_applicable(
254             unnecessary_async,
255             "async f$0n f() { if true { loop { f2().await } } }",
256         )
257     }
258
259     #[test]
260     fn does_not_apply_when_not_on_prototype() {
261         check_assist_not_applicable(unnecessary_async, "pub async fn f() { $0f2() }")
262     }
263
264     #[test]
265     fn does_not_apply_on_async_trait_method() {
266         check_assist_not_applicable(
267             unnecessary_async,
268             r#"
269 trait Trait {
270     async fn foo();
271 }
272 impl Trait for () {
273     $0async fn foo() {}
274 }"#,
275         );
276     }
277 }