]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/convert_iter_for_each_to_for.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / convert_iter_for_each_to_for.rs
1 use hir::known;
2 use ide_db::helpers::FamousDefs;
3 use stdx::format_to;
4 use syntax::{
5     ast::{self, edit_in_place::Indent, make, HasArgList, HasLoopBody},
6     AstNode,
7 };
8
9 use crate::{AssistContext, AssistId, AssistKind, Assists};
10
11 // Assist: convert_iter_for_each_to_for
12 //
13 // Converts an Iterator::for_each function into a for loop.
14 //
15 // ```
16 // # //- minicore: iterators
17 // # use core::iter;
18 // fn main() {
19 //     let iter = iter::repeat((9, 2));
20 //     iter.for_each$0(|(x, y)| {
21 //         println!("x: {}, y: {}", x, y);
22 //     });
23 // }
24 // ```
25 // ->
26 // ```
27 // # use core::iter;
28 // fn main() {
29 //     let iter = iter::repeat((9, 2));
30 //     for (x, y) in iter {
31 //         println!("x: {}, y: {}", x, y);
32 //     }
33 // }
34 // ```
35 pub(crate) fn convert_iter_for_each_to_for(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
36     let method = ctx.find_node_at_offset::<ast::MethodCallExpr>()?;
37
38     let closure = match method.arg_list()?.args().next()? {
39         ast::Expr::ClosureExpr(expr) => expr,
40         _ => return None,
41     };
42
43     let (method, receiver) = validate_method_call_expr(ctx, method)?;
44
45     let param_list = closure.param_list()?;
46     let param = param_list.params().next()?.pat()?;
47     let body = closure.body()?;
48
49     let stmt = method.syntax().parent().and_then(ast::ExprStmt::cast);
50     let range = stmt.as_ref().map_or(method.syntax(), AstNode::syntax).text_range();
51
52     acc.add(
53         AssistId("convert_iter_for_each_to_for", AssistKind::RefactorRewrite),
54         "Replace this `Iterator::for_each` with a for loop",
55         range,
56         |builder| {
57             let indent =
58                 stmt.as_ref().map_or_else(|| method.indent_level(), ast::ExprStmt::indent_level);
59
60             let block = match body {
61                 ast::Expr::BlockExpr(block) => block,
62                 _ => make::block_expr(Vec::new(), Some(body)),
63             }
64             .clone_for_update();
65             block.reindent_to(indent);
66
67             let expr_for_loop = make::expr_for_loop(param, receiver, block);
68             builder.replace(range, expr_for_loop.to_string())
69         },
70     )
71 }
72
73 // Assist: convert_for_loop_with_for_each
74 //
75 // Converts a for loop into a for_each loop on the Iterator.
76 //
77 // ```
78 // fn main() {
79 //     let x = vec![1, 2, 3];
80 //     for$0 v in x {
81 //         let y = v * 2;
82 //     }
83 // }
84 // ```
85 // ->
86 // ```
87 // fn main() {
88 //     let x = vec![1, 2, 3];
89 //     x.into_iter().for_each(|v| {
90 //         let y = v * 2;
91 //     });
92 // }
93 // ```
94 pub(crate) fn convert_for_loop_with_for_each(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
95     let for_loop = ctx.find_node_at_offset::<ast::ForExpr>()?;
96     let iterable = for_loop.iterable()?;
97     let pat = for_loop.pat()?;
98     let body = for_loop.loop_body()?;
99     if body.syntax().text_range().start() < ctx.offset() {
100         cov_mark::hit!(not_available_in_body);
101         return None;
102     }
103
104     acc.add(
105         AssistId("convert_for_loop_with_for_each", AssistKind::RefactorRewrite),
106         "Replace this for loop with `Iterator::for_each`",
107         for_loop.syntax().text_range(),
108         |builder| {
109             let mut buf = String::new();
110
111             if let Some((expr_behind_ref, method)) =
112                 is_ref_and_impls_iter_method(&ctx.sema, &iterable)
113             {
114                 // We have either "for x in &col" and col implements a method called iter
115                 //             or "for x in &mut col" and col implements a method called iter_mut
116                 format_to!(buf, "{}.{}()", expr_behind_ref, method);
117             } else if let ast::Expr::RangeExpr(..) = iterable {
118                 // range expressions need to be parenthesized for the syntax to be correct
119                 format_to!(buf, "({})", iterable);
120             } else if impls_core_iter(&ctx.sema, &iterable) {
121                 format_to!(buf, "{}", iterable);
122             } else if let ast::Expr::RefExpr(_) = iterable {
123                 format_to!(buf, "({}).into_iter()", iterable);
124             } else {
125                 format_to!(buf, "{}.into_iter()", iterable);
126             }
127
128             format_to!(buf, ".for_each(|{}| {});", pat, body);
129
130             builder.replace(for_loop.syntax().text_range(), buf)
131         },
132     )
133 }
134
135 /// If iterable is a reference where the expression behind the reference implements a method
136 /// returning an Iterator called iter or iter_mut (depending on the type of reference) then return
137 /// the expression behind the reference and the method name
138 fn is_ref_and_impls_iter_method(
139     sema: &hir::Semantics<ide_db::RootDatabase>,
140     iterable: &ast::Expr,
141 ) -> Option<(ast::Expr, hir::Name)> {
142     let ref_expr = match iterable {
143         ast::Expr::RefExpr(r) => r,
144         _ => return None,
145     };
146     let wanted_method = if ref_expr.mut_token().is_some() { known::iter_mut } else { known::iter };
147     let expr_behind_ref = ref_expr.expr()?;
148     let ty = sema.type_of_expr(&expr_behind_ref)?.adjusted();
149     let scope = sema.scope(iterable.syntax());
150     let krate = scope.module()?.krate();
151     let traits_in_scope = scope.visible_traits();
152     let iter_trait = FamousDefs(sema, Some(krate)).core_iter_Iterator()?;
153
154     let has_wanted_method = ty
155         .iterate_method_candidates(
156             sema.db,
157             krate,
158             &traits_in_scope,
159             None,
160             Some(&wanted_method),
161             |_, func| {
162                 if func.ret_type(sema.db).impls_trait(sema.db, iter_trait, &[]) {
163                     return Some(());
164                 }
165                 None
166             },
167         )
168         .is_some();
169     if !has_wanted_method {
170         return None;
171     }
172
173     Some((expr_behind_ref, wanted_method))
174 }
175
176 /// Whether iterable implements core::Iterator
177 fn impls_core_iter(sema: &hir::Semantics<ide_db::RootDatabase>, iterable: &ast::Expr) -> bool {
178     let it_typ = match sema.type_of_expr(iterable) {
179         Some(it) => it.adjusted(),
180         None => return false,
181     };
182
183     let module = match sema.scope(iterable.syntax()).module() {
184         Some(it) => it,
185         None => return false,
186     };
187
188     let krate = module.krate();
189     match FamousDefs(sema, Some(krate)).core_iter_Iterator() {
190         Some(iter_trait) => {
191             cov_mark::hit!(test_already_impls_iterator);
192             it_typ.impls_trait(sema.db, iter_trait, &[])
193         }
194         None => false,
195     }
196 }
197
198 fn validate_method_call_expr(
199     ctx: &AssistContext,
200     expr: ast::MethodCallExpr,
201 ) -> Option<(ast::Expr, ast::Expr)> {
202     let name_ref = expr.name_ref()?;
203     if !name_ref.syntax().text_range().contains_range(ctx.selection_trimmed()) {
204         cov_mark::hit!(test_for_each_not_applicable_invalid_cursor_pos);
205         return None;
206     }
207     if name_ref.text() != "for_each" {
208         return None;
209     }
210
211     let sema = &ctx.sema;
212
213     let receiver = expr.receiver()?;
214     let expr = ast::Expr::MethodCallExpr(expr);
215
216     let it_type = sema.type_of_expr(&receiver)?.adjusted();
217     let module = sema.scope(receiver.syntax()).module()?;
218     let krate = module.krate();
219
220     let iter_trait = FamousDefs(sema, Some(krate)).core_iter_Iterator()?;
221     it_type.impls_trait(sema.db, iter_trait, &[]).then(|| (expr, receiver))
222 }
223
224 #[cfg(test)]
225 mod tests {
226     use crate::tests::{check_assist, check_assist_not_applicable};
227
228     use super::*;
229
230     #[test]
231     fn test_for_each_in_method_stmt() {
232         check_assist(
233             convert_iter_for_each_to_for,
234             r#"
235 //- minicore: iterators
236 fn main() {
237     let it = core::iter::repeat(92);
238     it.$0for_each(|(x, y)| {
239         println!("x: {}, y: {}", x, y);
240     });
241 }
242 "#,
243             r#"
244 fn main() {
245     let it = core::iter::repeat(92);
246     for (x, y) in it {
247         println!("x: {}, y: {}", x, y);
248     }
249 }
250 "#,
251         )
252     }
253
254     #[test]
255     fn test_for_each_in_method() {
256         check_assist(
257             convert_iter_for_each_to_for,
258             r#"
259 //- minicore: iterators
260 fn main() {
261     let it = core::iter::repeat(92);
262     it.$0for_each(|(x, y)| {
263         println!("x: {}, y: {}", x, y);
264     })
265 }
266 "#,
267             r#"
268 fn main() {
269     let it = core::iter::repeat(92);
270     for (x, y) in it {
271         println!("x: {}, y: {}", x, y);
272     }
273 }
274 "#,
275         )
276     }
277
278     #[test]
279     fn test_for_each_without_braces_stmt() {
280         check_assist(
281             convert_iter_for_each_to_for,
282             r#"
283 //- minicore: iterators
284 fn main() {
285     let it = core::iter::repeat(92);
286     it.$0for_each(|(x, y)| println!("x: {}, y: {}", x, y));
287 }
288 "#,
289             r#"
290 fn main() {
291     let it = core::iter::repeat(92);
292     for (x, y) in it {
293         println!("x: {}, y: {}", x, y)
294     }
295 }
296 "#,
297         )
298     }
299
300     #[test]
301     fn test_for_each_not_applicable() {
302         check_assist_not_applicable(
303             convert_iter_for_each_to_for,
304             r#"
305 //- minicore: iterators
306 fn main() {
307     ().$0for_each(|x| println!("{}", x));
308 }"#,
309         )
310     }
311
312     #[test]
313     fn test_for_each_not_applicable_invalid_cursor_pos() {
314         cov_mark::check!(test_for_each_not_applicable_invalid_cursor_pos);
315         check_assist_not_applicable(
316             convert_iter_for_each_to_for,
317             r#"
318 //- minicore: iterators
319 fn main() {
320     core::iter::repeat(92).for_each(|(x, y)| $0println!("x: {}, y: {}", x, y));
321 }"#,
322         )
323     }
324
325     #[test]
326     fn each_to_for_not_for() {
327         check_assist_not_applicable(
328             convert_for_loop_with_for_each,
329             r"
330 let mut x = vec![1, 2, 3];
331 x.iter_mut().$0for_each(|v| *v *= 2);
332         ",
333         )
334     }
335
336     #[test]
337     fn each_to_for_simple_for() {
338         check_assist(
339             convert_for_loop_with_for_each,
340             r"
341 fn main() {
342     let x = vec![1, 2, 3];
343     for $0v in x {
344         v *= 2;
345     }
346 }",
347             r"
348 fn main() {
349     let x = vec![1, 2, 3];
350     x.into_iter().for_each(|v| {
351         v *= 2;
352     });
353 }",
354         )
355     }
356
357     #[test]
358     fn each_to_for_for_in_range() {
359         check_assist(
360             convert_for_loop_with_for_each,
361             r#"
362 //- minicore: range, iterators
363 impl<T> core::iter::Iterator for core::ops::Range<T> {
364     type Item = T;
365
366     fn next(&mut self) -> Option<Self::Item> {
367         None
368     }
369 }
370
371 fn main() {
372     for $0x in 0..92 {
373         print!("{}", x);
374     }
375 }"#,
376             r#"
377 impl<T> core::iter::Iterator for core::ops::Range<T> {
378     type Item = T;
379
380     fn next(&mut self) -> Option<Self::Item> {
381         None
382     }
383 }
384
385 fn main() {
386     (0..92).for_each(|x| {
387         print!("{}", x);
388     });
389 }"#,
390         )
391     }
392
393     #[test]
394     fn each_to_for_not_available_in_body() {
395         cov_mark::check!(not_available_in_body);
396         check_assist_not_applicable(
397             convert_for_loop_with_for_each,
398             r"
399 fn main() {
400     let x = vec![1, 2, 3];
401     for v in x {
402         $0v *= 2;
403     }
404 }",
405         )
406     }
407
408     #[test]
409     fn each_to_for_for_borrowed() {
410         check_assist(
411             convert_for_loop_with_for_each,
412             r#"
413 //- minicore: iterators
414 use core::iter::{Repeat, repeat};
415
416 struct S;
417 impl S {
418     fn iter(&self) -> Repeat<i32> { repeat(92) }
419     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
420 }
421
422 fn main() {
423     let x = S;
424     for $0v in &x {
425         let a = v * 2;
426     }
427 }
428 "#,
429             r#"
430 use core::iter::{Repeat, repeat};
431
432 struct S;
433 impl S {
434     fn iter(&self) -> Repeat<i32> { repeat(92) }
435     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
436 }
437
438 fn main() {
439     let x = S;
440     x.iter().for_each(|v| {
441         let a = v * 2;
442     });
443 }
444 "#,
445         )
446     }
447
448     #[test]
449     fn each_to_for_for_borrowed_no_iter_method() {
450         check_assist(
451             convert_for_loop_with_for_each,
452             r"
453 struct NoIterMethod;
454 fn main() {
455     let x = NoIterMethod;
456     for $0v in &x {
457         let a = v * 2;
458     }
459 }
460 ",
461             r"
462 struct NoIterMethod;
463 fn main() {
464     let x = NoIterMethod;
465     (&x).into_iter().for_each(|v| {
466         let a = v * 2;
467     });
468 }
469 ",
470         )
471     }
472
473     #[test]
474     fn each_to_for_for_borrowed_mut() {
475         check_assist(
476             convert_for_loop_with_for_each,
477             r#"
478 //- minicore: iterators
479 use core::iter::{Repeat, repeat};
480
481 struct S;
482 impl S {
483     fn iter(&self) -> Repeat<i32> { repeat(92) }
484     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
485 }
486
487 fn main() {
488     let x = S;
489     for $0v in &mut x {
490         let a = v * 2;
491     }
492 }
493 "#,
494             r#"
495 use core::iter::{Repeat, repeat};
496
497 struct S;
498 impl S {
499     fn iter(&self) -> Repeat<i32> { repeat(92) }
500     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
501 }
502
503 fn main() {
504     let x = S;
505     x.iter_mut().for_each(|v| {
506         let a = v * 2;
507     });
508 }
509 "#,
510         )
511     }
512
513     #[test]
514     fn each_to_for_for_borrowed_mut_behind_var() {
515         check_assist(
516             convert_for_loop_with_for_each,
517             r"
518 fn main() {
519     let x = vec![1, 2, 3];
520     let y = &mut x;
521     for $0v in y {
522         *v *= 2;
523     }
524 }",
525             r"
526 fn main() {
527     let x = vec![1, 2, 3];
528     let y = &mut x;
529     y.into_iter().for_each(|v| {
530         *v *= 2;
531     });
532 }",
533         )
534     }
535
536     #[test]
537     fn each_to_for_already_impls_iterator() {
538         cov_mark::check!(test_already_impls_iterator);
539         check_assist(
540             convert_for_loop_with_for_each,
541             r#"
542 //- minicore: iterators
543 fn main() {
544     for$0 a in core::iter::repeat(92).take(1) {
545         println!("{}", a);
546     }
547 }
548 "#,
549             r#"
550 fn main() {
551     core::iter::repeat(92).take(1).for_each(|a| {
552         println!("{}", a);
553     });
554 }
555 "#,
556         );
557     }
558 }