]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_iter_for_each_to_for.rs
Auto merge of #103913 - Neutron3529:patch-1, r=thomcc
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / convert_iter_for_each_to_for.rs
1 use hir::known;
2 use ide_db::famous_defs::FamousDefs;
3 use stdx::format_to;
4 use syntax::{
5     ast::{self, edit_in_place::Indent, make, HasArgList, HasLoopBody},
6     AstNode,
7 };
8
9 use crate::{AssistContext, AssistId, AssistKind, Assists};
10
11 // Assist: convert_iter_for_each_to_for
12 //
13 // Converts an Iterator::for_each function into a for loop.
14 //
15 // ```
16 // # //- minicore: iterators
17 // # use core::iter;
18 // fn main() {
19 //     let iter = iter::repeat((9, 2));
20 //     iter.for_each$0(|(x, y)| {
21 //         println!("x: {}, y: {}", x, y);
22 //     });
23 // }
24 // ```
25 // ->
26 // ```
27 // # use core::iter;
28 // fn main() {
29 //     let iter = iter::repeat((9, 2));
30 //     for (x, y) in iter {
31 //         println!("x: {}, y: {}", x, y);
32 //     }
33 // }
34 // ```
35 pub(crate) fn convert_iter_for_each_to_for(
36     acc: &mut Assists,
37     ctx: &AssistContext<'_>,
38 ) -> Option<()> {
39     let method = ctx.find_node_at_offset::<ast::MethodCallExpr>()?;
40
41     let closure = match method.arg_list()?.args().next()? {
42         ast::Expr::ClosureExpr(expr) => expr,
43         _ => return None,
44     };
45
46     let (method, receiver) = validate_method_call_expr(ctx, method)?;
47
48     let param_list = closure.param_list()?;
49     let param = param_list.params().next()?.pat()?;
50     let body = closure.body()?;
51
52     let stmt = method.syntax().parent().and_then(ast::ExprStmt::cast);
53     let range = stmt.as_ref().map_or(method.syntax(), AstNode::syntax).text_range();
54
55     acc.add(
56         AssistId("convert_iter_for_each_to_for", AssistKind::RefactorRewrite),
57         "Replace this `Iterator::for_each` with a for loop",
58         range,
59         |builder| {
60             let indent =
61                 stmt.as_ref().map_or_else(|| method.indent_level(), ast::ExprStmt::indent_level);
62
63             let block = match body {
64                 ast::Expr::BlockExpr(block) => block,
65                 _ => make::block_expr(Vec::new(), Some(body)),
66             }
67             .clone_for_update();
68             block.reindent_to(indent);
69
70             let expr_for_loop = make::expr_for_loop(param, receiver, block);
71             builder.replace(range, expr_for_loop.to_string())
72         },
73     )
74 }
75
76 // Assist: convert_for_loop_with_for_each
77 //
78 // Converts a for loop into a for_each loop on the Iterator.
79 //
80 // ```
81 // fn main() {
82 //     let x = vec![1, 2, 3];
83 //     for$0 v in x {
84 //         let y = v * 2;
85 //     }
86 // }
87 // ```
88 // ->
89 // ```
90 // fn main() {
91 //     let x = vec![1, 2, 3];
92 //     x.into_iter().for_each(|v| {
93 //         let y = v * 2;
94 //     });
95 // }
96 // ```
97 pub(crate) fn convert_for_loop_with_for_each(
98     acc: &mut Assists,
99     ctx: &AssistContext<'_>,
100 ) -> Option<()> {
101     let for_loop = ctx.find_node_at_offset::<ast::ForExpr>()?;
102     let iterable = for_loop.iterable()?;
103     let pat = for_loop.pat()?;
104     let body = for_loop.loop_body()?;
105     if body.syntax().text_range().start() < ctx.offset() {
106         cov_mark::hit!(not_available_in_body);
107         return None;
108     }
109
110     acc.add(
111         AssistId("convert_for_loop_with_for_each", AssistKind::RefactorRewrite),
112         "Replace this for loop with `Iterator::for_each`",
113         for_loop.syntax().text_range(),
114         |builder| {
115             let mut buf = String::new();
116
117             if let Some((expr_behind_ref, method)) =
118                 is_ref_and_impls_iter_method(&ctx.sema, &iterable)
119             {
120                 // We have either "for x in &col" and col implements a method called iter
121                 //             or "for x in &mut col" and col implements a method called iter_mut
122                 format_to!(buf, "{expr_behind_ref}.{method}()");
123             } else if let ast::Expr::RangeExpr(..) = iterable {
124                 // range expressions need to be parenthesized for the syntax to be correct
125                 format_to!(buf, "({iterable})");
126             } else if impls_core_iter(&ctx.sema, &iterable) {
127                 format_to!(buf, "{iterable}");
128             } else if let ast::Expr::RefExpr(_) = iterable {
129                 format_to!(buf, "({iterable}).into_iter()");
130             } else {
131                 format_to!(buf, "{iterable}.into_iter()");
132             }
133
134             format_to!(buf, ".for_each(|{pat}| {body});");
135
136             builder.replace(for_loop.syntax().text_range(), buf)
137         },
138     )
139 }
140
141 /// If iterable is a reference where the expression behind the reference implements a method
142 /// returning an Iterator called iter or iter_mut (depending on the type of reference) then return
143 /// the expression behind the reference and the method name
144 fn is_ref_and_impls_iter_method(
145     sema: &hir::Semantics<'_, ide_db::RootDatabase>,
146     iterable: &ast::Expr,
147 ) -> Option<(ast::Expr, hir::Name)> {
148     let ref_expr = match iterable {
149         ast::Expr::RefExpr(r) => r,
150         _ => return None,
151     };
152     let wanted_method = if ref_expr.mut_token().is_some() { known::iter_mut } else { known::iter };
153     let expr_behind_ref = ref_expr.expr()?;
154     let ty = sema.type_of_expr(&expr_behind_ref)?.adjusted();
155     let scope = sema.scope(iterable.syntax())?;
156     let krate = scope.krate();
157     let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
158
159     let has_wanted_method = ty
160         .iterate_method_candidates(
161             sema.db,
162             &scope,
163             &scope.visible_traits().0,
164             None,
165             Some(&wanted_method),
166             |func| {
167                 if func.ret_type(sema.db).impls_trait(sema.db, iter_trait, &[]) {
168                     return Some(());
169                 }
170                 None
171             },
172         )
173         .is_some();
174     if !has_wanted_method {
175         return None;
176     }
177
178     Some((expr_behind_ref, wanted_method))
179 }
180
181 /// Whether iterable implements core::Iterator
182 fn impls_core_iter(sema: &hir::Semantics<'_, ide_db::RootDatabase>, iterable: &ast::Expr) -> bool {
183     (|| {
184         let it_typ = sema.type_of_expr(iterable)?.adjusted();
185
186         let module = sema.scope(iterable.syntax())?.module();
187
188         let krate = module.krate();
189         let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
190         cov_mark::hit!(test_already_impls_iterator);
191         Some(it_typ.impls_trait(sema.db, iter_trait, &[]))
192     })()
193     .unwrap_or(false)
194 }
195
196 fn validate_method_call_expr(
197     ctx: &AssistContext<'_>,
198     expr: ast::MethodCallExpr,
199 ) -> Option<(ast::Expr, ast::Expr)> {
200     let name_ref = expr.name_ref()?;
201     if !name_ref.syntax().text_range().contains_range(ctx.selection_trimmed()) {
202         cov_mark::hit!(test_for_each_not_applicable_invalid_cursor_pos);
203         return None;
204     }
205     if name_ref.text() != "for_each" {
206         return None;
207     }
208
209     let sema = &ctx.sema;
210
211     let receiver = expr.receiver()?;
212     let expr = ast::Expr::MethodCallExpr(expr);
213
214     let it_type = sema.type_of_expr(&receiver)?.adjusted();
215     let module = sema.scope(receiver.syntax())?.module();
216     let krate = module.krate();
217
218     let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
219     it_type.impls_trait(sema.db, iter_trait, &[]).then(|| (expr, receiver))
220 }
221
222 #[cfg(test)]
223 mod tests {
224     use crate::tests::{check_assist, check_assist_not_applicable};
225
226     use super::*;
227
228     #[test]
229     fn test_for_each_in_method_stmt() {
230         check_assist(
231             convert_iter_for_each_to_for,
232             r#"
233 //- minicore: iterators
234 fn main() {
235     let it = core::iter::repeat(92);
236     it.$0for_each(|(x, y)| {
237         println!("x: {}, y: {}", x, y);
238     });
239 }
240 "#,
241             r#"
242 fn main() {
243     let it = core::iter::repeat(92);
244     for (x, y) in it {
245         println!("x: {}, y: {}", x, y);
246     }
247 }
248 "#,
249         )
250     }
251
252     #[test]
253     fn test_for_each_in_method() {
254         check_assist(
255             convert_iter_for_each_to_for,
256             r#"
257 //- minicore: iterators
258 fn main() {
259     let it = core::iter::repeat(92);
260     it.$0for_each(|(x, y)| {
261         println!("x: {}, y: {}", x, y);
262     })
263 }
264 "#,
265             r#"
266 fn main() {
267     let it = core::iter::repeat(92);
268     for (x, y) in it {
269         println!("x: {}, y: {}", x, y);
270     }
271 }
272 "#,
273         )
274     }
275
276     #[test]
277     fn test_for_each_without_braces_stmt() {
278         check_assist(
279             convert_iter_for_each_to_for,
280             r#"
281 //- minicore: iterators
282 fn main() {
283     let it = core::iter::repeat(92);
284     it.$0for_each(|(x, y)| println!("x: {}, y: {}", x, y));
285 }
286 "#,
287             r#"
288 fn main() {
289     let it = core::iter::repeat(92);
290     for (x, y) in it {
291         println!("x: {}, y: {}", x, y)
292     }
293 }
294 "#,
295         )
296     }
297
298     #[test]
299     fn test_for_each_not_applicable() {
300         check_assist_not_applicable(
301             convert_iter_for_each_to_for,
302             r#"
303 //- minicore: iterators
304 fn main() {
305     ().$0for_each(|x| println!("{}", x));
306 }"#,
307         )
308     }
309
310     #[test]
311     fn test_for_each_not_applicable_invalid_cursor_pos() {
312         cov_mark::check!(test_for_each_not_applicable_invalid_cursor_pos);
313         check_assist_not_applicable(
314             convert_iter_for_each_to_for,
315             r#"
316 //- minicore: iterators
317 fn main() {
318     core::iter::repeat(92).for_each(|(x, y)| $0println!("x: {}, y: {}", x, y));
319 }"#,
320         )
321     }
322
323     #[test]
324     fn each_to_for_not_for() {
325         check_assist_not_applicable(
326             convert_for_loop_with_for_each,
327             r"
328 let mut x = vec![1, 2, 3];
329 x.iter_mut().$0for_each(|v| *v *= 2);
330         ",
331         )
332     }
333
334     #[test]
335     fn each_to_for_simple_for() {
336         check_assist(
337             convert_for_loop_with_for_each,
338             r"
339 fn main() {
340     let x = vec![1, 2, 3];
341     for $0v in x {
342         v *= 2;
343     }
344 }",
345             r"
346 fn main() {
347     let x = vec![1, 2, 3];
348     x.into_iter().for_each(|v| {
349         v *= 2;
350     });
351 }",
352         )
353     }
354
355     #[test]
356     fn each_to_for_for_in_range() {
357         check_assist(
358             convert_for_loop_with_for_each,
359             r#"
360 //- minicore: range, iterators
361 impl<T> core::iter::Iterator for core::ops::Range<T> {
362     type Item = T;
363
364     fn next(&mut self) -> Option<Self::Item> {
365         None
366     }
367 }
368
369 fn main() {
370     for $0x in 0..92 {
371         print!("{}", x);
372     }
373 }"#,
374             r#"
375 impl<T> core::iter::Iterator for core::ops::Range<T> {
376     type Item = T;
377
378     fn next(&mut self) -> Option<Self::Item> {
379         None
380     }
381 }
382
383 fn main() {
384     (0..92).for_each(|x| {
385         print!("{}", x);
386     });
387 }"#,
388         )
389     }
390
391     #[test]
392     fn each_to_for_not_available_in_body() {
393         cov_mark::check!(not_available_in_body);
394         check_assist_not_applicable(
395             convert_for_loop_with_for_each,
396             r"
397 fn main() {
398     let x = vec![1, 2, 3];
399     for v in x {
400         $0v *= 2;
401     }
402 }",
403         )
404     }
405
406     #[test]
407     fn each_to_for_for_borrowed() {
408         check_assist(
409             convert_for_loop_with_for_each,
410             r#"
411 //- minicore: iterators
412 use core::iter::{Repeat, repeat};
413
414 struct S;
415 impl S {
416     fn iter(&self) -> Repeat<i32> { repeat(92) }
417     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
418 }
419
420 fn main() {
421     let x = S;
422     for $0v in &x {
423         let a = v * 2;
424     }
425 }
426 "#,
427             r#"
428 use core::iter::{Repeat, repeat};
429
430 struct S;
431 impl S {
432     fn iter(&self) -> Repeat<i32> { repeat(92) }
433     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
434 }
435
436 fn main() {
437     let x = S;
438     x.iter().for_each(|v| {
439         let a = v * 2;
440     });
441 }
442 "#,
443         )
444     }
445
446     #[test]
447     fn each_to_for_for_borrowed_no_iter_method() {
448         check_assist(
449             convert_for_loop_with_for_each,
450             r"
451 struct NoIterMethod;
452 fn main() {
453     let x = NoIterMethod;
454     for $0v in &x {
455         let a = v * 2;
456     }
457 }
458 ",
459             r"
460 struct NoIterMethod;
461 fn main() {
462     let x = NoIterMethod;
463     (&x).into_iter().for_each(|v| {
464         let a = v * 2;
465     });
466 }
467 ",
468         )
469     }
470
471     #[test]
472     fn each_to_for_for_borrowed_mut() {
473         check_assist(
474             convert_for_loop_with_for_each,
475             r#"
476 //- minicore: iterators
477 use core::iter::{Repeat, repeat};
478
479 struct S;
480 impl S {
481     fn iter(&self) -> Repeat<i32> { repeat(92) }
482     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
483 }
484
485 fn main() {
486     let x = S;
487     for $0v in &mut x {
488         let a = v * 2;
489     }
490 }
491 "#,
492             r#"
493 use core::iter::{Repeat, repeat};
494
495 struct S;
496 impl S {
497     fn iter(&self) -> Repeat<i32> { repeat(92) }
498     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
499 }
500
501 fn main() {
502     let x = S;
503     x.iter_mut().for_each(|v| {
504         let a = v * 2;
505     });
506 }
507 "#,
508         )
509     }
510
511     #[test]
512     fn each_to_for_for_borrowed_mut_behind_var() {
513         check_assist(
514             convert_for_loop_with_for_each,
515             r"
516 fn main() {
517     let x = vec![1, 2, 3];
518     let y = &mut x;
519     for $0v in y {
520         *v *= 2;
521     }
522 }",
523             r"
524 fn main() {
525     let x = vec![1, 2, 3];
526     let y = &mut x;
527     y.into_iter().for_each(|v| {
528         *v *= 2;
529     });
530 }",
531         )
532     }
533
534     #[test]
535     fn each_to_for_already_impls_iterator() {
536         cov_mark::check!(test_already_impls_iterator);
537         check_assist(
538             convert_for_loop_with_for_each,
539             r#"
540 //- minicore: iterators
541 fn main() {
542     for$0 a in core::iter::repeat(92).take(1) {
543         println!("{}", a);
544     }
545 }
546 "#,
547             r#"
548 fn main() {
549     core::iter::repeat(92).take(1).for_each(|a| {
550         println!("{}", a);
551     });
552 }
553 "#,
554         );
555     }
556 }