]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/convert_iter_for_each_to_for.rs
Merge #11391
[rust.git] / crates / ide_assists / src / handlers / convert_iter_for_each_to_for.rs
1 use hir::known;
2 use ide_db::helpers::FamousDefs;
3 use stdx::format_to;
4 use syntax::{
5     ast::{self, edit_in_place::Indent, make, HasArgList, HasLoopBody},
6     AstNode,
7 };
8
9 use crate::{AssistContext, AssistId, AssistKind, Assists};
10
11 // Assist: convert_iter_for_each_to_for
12 //
13 // Converts an Iterator::for_each function into a for loop.
14 //
15 // ```
16 // # //- minicore: iterators
17 // # use core::iter;
18 // fn main() {
19 //     let iter = iter::repeat((9, 2));
20 //     iter.for_each$0(|(x, y)| {
21 //         println!("x: {}, y: {}", x, y);
22 //     });
23 // }
24 // ```
25 // ->
26 // ```
27 // # use core::iter;
28 // fn main() {
29 //     let iter = iter::repeat((9, 2));
30 //     for (x, y) in iter {
31 //         println!("x: {}, y: {}", x, y);
32 //     }
33 // }
34 // ```
35 pub(crate) fn convert_iter_for_each_to_for(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
36     let method = ctx.find_node_at_offset::<ast::MethodCallExpr>()?;
37
38     let closure = match method.arg_list()?.args().next()? {
39         ast::Expr::ClosureExpr(expr) => expr,
40         _ => return None,
41     };
42
43     let (method, receiver) = validate_method_call_expr(ctx, method)?;
44
45     let param_list = closure.param_list()?;
46     let param = param_list.params().next()?.pat()?;
47     let body = closure.body()?;
48
49     let stmt = method.syntax().parent().and_then(ast::ExprStmt::cast);
50     let range = stmt.as_ref().map_or(method.syntax(), AstNode::syntax).text_range();
51
52     acc.add(
53         AssistId("convert_iter_for_each_to_for", AssistKind::RefactorRewrite),
54         "Replace this `Iterator::for_each` with a for loop",
55         range,
56         |builder| {
57             let indent =
58                 stmt.as_ref().map_or_else(|| method.indent_level(), ast::ExprStmt::indent_level);
59
60             let block = match body {
61                 ast::Expr::BlockExpr(block) => block,
62                 _ => make::block_expr(Vec::new(), Some(body)),
63             }
64             .clone_for_update();
65             block.reindent_to(indent);
66
67             let expr_for_loop = make::expr_for_loop(param, receiver, block);
68             builder.replace(range, expr_for_loop.to_string())
69         },
70     )
71 }
72
73 // Assist: convert_for_loop_with_for_each
74 //
75 // Converts a for loop into a for_each loop on the Iterator.
76 //
77 // ```
78 // fn main() {
79 //     let x = vec![1, 2, 3];
80 //     for$0 v in x {
81 //         let y = v * 2;
82 //     }
83 // }
84 // ```
85 // ->
86 // ```
87 // fn main() {
88 //     let x = vec![1, 2, 3];
89 //     x.into_iter().for_each(|v| {
90 //         let y = v * 2;
91 //     });
92 // }
93 // ```
94 pub(crate) fn convert_for_loop_with_for_each(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
95     let for_loop = ctx.find_node_at_offset::<ast::ForExpr>()?;
96     let iterable = for_loop.iterable()?;
97     let pat = for_loop.pat()?;
98     let body = for_loop.loop_body()?;
99     if body.syntax().text_range().start() < ctx.offset() {
100         cov_mark::hit!(not_available_in_body);
101         return None;
102     }
103
104     acc.add(
105         AssistId("convert_for_loop_with_for_each", AssistKind::RefactorRewrite),
106         "Replace this for loop with `Iterator::for_each`",
107         for_loop.syntax().text_range(),
108         |builder| {
109             let mut buf = String::new();
110
111             if let Some((expr_behind_ref, method)) =
112                 is_ref_and_impls_iter_method(&ctx.sema, &iterable)
113             {
114                 // We have either "for x in &col" and col implements a method called iter
115                 //             or "for x in &mut col" and col implements a method called iter_mut
116                 format_to!(buf, "{}.{}()", expr_behind_ref, method);
117             } else if let ast::Expr::RangeExpr(..) = iterable {
118                 // range expressions need to be parenthesized for the syntax to be correct
119                 format_to!(buf, "({})", iterable);
120             } else if impls_core_iter(&ctx.sema, &iterable) {
121                 format_to!(buf, "{}", iterable);
122             } else if let ast::Expr::RefExpr(_) = iterable {
123                 format_to!(buf, "({}).into_iter()", iterable);
124             } else {
125                 format_to!(buf, "{}.into_iter()", iterable);
126             }
127
128             format_to!(buf, ".for_each(|{}| {});", pat, body);
129
130             builder.replace(for_loop.syntax().text_range(), buf)
131         },
132     )
133 }
134
135 /// If iterable is a reference where the expression behind the reference implements a method
136 /// returning an Iterator called iter or iter_mut (depending on the type of reference) then return
137 /// the expression behind the reference and the method name
138 fn is_ref_and_impls_iter_method(
139     sema: &hir::Semantics<ide_db::RootDatabase>,
140     iterable: &ast::Expr,
141 ) -> Option<(ast::Expr, hir::Name)> {
142     let ref_expr = match iterable {
143         ast::Expr::RefExpr(r) => r,
144         _ => return None,
145     };
146     let wanted_method = if ref_expr.mut_token().is_some() { known::iter_mut } else { known::iter };
147     let expr_behind_ref = ref_expr.expr()?;
148     let ty = sema.type_of_expr(&expr_behind_ref)?.adjusted();
149     let scope = sema.scope(iterable.syntax());
150     let krate = scope.module()?.krate();
151     let traits_in_scope = scope.visible_traits();
152     let iter_trait = FamousDefs(sema, Some(krate)).core_iter_Iterator()?;
153
154     let has_wanted_method = ty
155         .iterate_method_candidates(
156             sema.db,
157             krate,
158             &traits_in_scope,
159             Some(&wanted_method),
160             |_, func| {
161                 if func.ret_type(sema.db).impls_trait(sema.db, iter_trait, &[]) {
162                     return Some(());
163                 }
164                 None
165             },
166         )
167         .is_some();
168     if !has_wanted_method {
169         return None;
170     }
171
172     Some((expr_behind_ref, wanted_method))
173 }
174
175 /// Whether iterable implements core::Iterator
176 fn impls_core_iter(sema: &hir::Semantics<ide_db::RootDatabase>, iterable: &ast::Expr) -> bool {
177     let it_typ = match sema.type_of_expr(iterable) {
178         Some(it) => it.adjusted(),
179         None => return false,
180     };
181
182     let module = match sema.scope(iterable.syntax()).module() {
183         Some(it) => it,
184         None => return false,
185     };
186
187     let krate = module.krate();
188     match FamousDefs(sema, Some(krate)).core_iter_Iterator() {
189         Some(iter_trait) => {
190             cov_mark::hit!(test_already_impls_iterator);
191             it_typ.impls_trait(sema.db, iter_trait, &[])
192         }
193         None => false,
194     }
195 }
196
197 fn validate_method_call_expr(
198     ctx: &AssistContext,
199     expr: ast::MethodCallExpr,
200 ) -> Option<(ast::Expr, ast::Expr)> {
201     let name_ref = expr.name_ref()?;
202     if !name_ref.syntax().text_range().contains_range(ctx.selection_trimmed()) {
203         cov_mark::hit!(test_for_each_not_applicable_invalid_cursor_pos);
204         return None;
205     }
206     if name_ref.text() != "for_each" {
207         return None;
208     }
209
210     let sema = &ctx.sema;
211
212     let receiver = expr.receiver()?;
213     let expr = ast::Expr::MethodCallExpr(expr);
214
215     let it_type = sema.type_of_expr(&receiver)?.adjusted();
216     let module = sema.scope(receiver.syntax()).module()?;
217     let krate = module.krate();
218
219     let iter_trait = FamousDefs(sema, Some(krate)).core_iter_Iterator()?;
220     it_type.impls_trait(sema.db, iter_trait, &[]).then(|| (expr, receiver))
221 }
222
223 #[cfg(test)]
224 mod tests {
225     use crate::tests::{check_assist, check_assist_not_applicable};
226
227     use super::*;
228
229     #[test]
230     fn test_for_each_in_method_stmt() {
231         check_assist(
232             convert_iter_for_each_to_for,
233             r#"
234 //- minicore: iterators
235 fn main() {
236     let it = core::iter::repeat(92);
237     it.$0for_each(|(x, y)| {
238         println!("x: {}, y: {}", x, y);
239     });
240 }
241 "#,
242             r#"
243 fn main() {
244     let it = core::iter::repeat(92);
245     for (x, y) in it {
246         println!("x: {}, y: {}", x, y);
247     }
248 }
249 "#,
250         )
251     }
252
253     #[test]
254     fn test_for_each_in_method() {
255         check_assist(
256             convert_iter_for_each_to_for,
257             r#"
258 //- minicore: iterators
259 fn main() {
260     let it = core::iter::repeat(92);
261     it.$0for_each(|(x, y)| {
262         println!("x: {}, y: {}", x, y);
263     })
264 }
265 "#,
266             r#"
267 fn main() {
268     let it = core::iter::repeat(92);
269     for (x, y) in it {
270         println!("x: {}, y: {}", x, y);
271     }
272 }
273 "#,
274         )
275     }
276
277     #[test]
278     fn test_for_each_without_braces_stmt() {
279         check_assist(
280             convert_iter_for_each_to_for,
281             r#"
282 //- minicore: iterators
283 fn main() {
284     let it = core::iter::repeat(92);
285     it.$0for_each(|(x, y)| println!("x: {}, y: {}", x, y));
286 }
287 "#,
288             r#"
289 fn main() {
290     let it = core::iter::repeat(92);
291     for (x, y) in it {
292         println!("x: {}, y: {}", x, y)
293     }
294 }
295 "#,
296         )
297     }
298
299     #[test]
300     fn test_for_each_not_applicable() {
301         check_assist_not_applicable(
302             convert_iter_for_each_to_for,
303             r#"
304 //- minicore: iterators
305 fn main() {
306     ().$0for_each(|x| println!("{}", x));
307 }"#,
308         )
309     }
310
311     #[test]
312     fn test_for_each_not_applicable_invalid_cursor_pos() {
313         cov_mark::check!(test_for_each_not_applicable_invalid_cursor_pos);
314         check_assist_not_applicable(
315             convert_iter_for_each_to_for,
316             r#"
317 //- minicore: iterators
318 fn main() {
319     core::iter::repeat(92).for_each(|(x, y)| $0println!("x: {}, y: {}", x, y));
320 }"#,
321         )
322     }
323
324     #[test]
325     fn each_to_for_not_for() {
326         check_assist_not_applicable(
327             convert_for_loop_with_for_each,
328             r"
329 let mut x = vec![1, 2, 3];
330 x.iter_mut().$0for_each(|v| *v *= 2);
331         ",
332         )
333     }
334
335     #[test]
336     fn each_to_for_simple_for() {
337         check_assist(
338             convert_for_loop_with_for_each,
339             r"
340 fn main() {
341     let x = vec![1, 2, 3];
342     for $0v in x {
343         v *= 2;
344     }
345 }",
346             r"
347 fn main() {
348     let x = vec![1, 2, 3];
349     x.into_iter().for_each(|v| {
350         v *= 2;
351     });
352 }",
353         )
354     }
355
356     #[test]
357     fn each_to_for_for_in_range() {
358         check_assist(
359             convert_for_loop_with_for_each,
360             r#"
361 //- minicore: range, iterators
362 impl<T> core::iter::Iterator for core::ops::Range<T> {
363     type Item = T;
364
365     fn next(&mut self) -> Option<Self::Item> {
366         None
367     }
368 }
369
370 fn main() {
371     for $0x in 0..92 {
372         print!("{}", x);
373     }
374 }"#,
375             r#"
376 impl<T> core::iter::Iterator for core::ops::Range<T> {
377     type Item = T;
378
379     fn next(&mut self) -> Option<Self::Item> {
380         None
381     }
382 }
383
384 fn main() {
385     (0..92).for_each(|x| {
386         print!("{}", x);
387     });
388 }"#,
389         )
390     }
391
392     #[test]
393     fn each_to_for_not_available_in_body() {
394         cov_mark::check!(not_available_in_body);
395         check_assist_not_applicable(
396             convert_for_loop_with_for_each,
397             r"
398 fn main() {
399     let x = vec![1, 2, 3];
400     for v in x {
401         $0v *= 2;
402     }
403 }",
404         )
405     }
406
407     #[test]
408     fn each_to_for_for_borrowed() {
409         check_assist(
410             convert_for_loop_with_for_each,
411             r#"
412 //- minicore: iterators
413 use core::iter::{Repeat, repeat};
414
415 struct S;
416 impl S {
417     fn iter(&self) -> Repeat<i32> { repeat(92) }
418     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
419 }
420
421 fn main() {
422     let x = S;
423     for $0v in &x {
424         let a = v * 2;
425     }
426 }
427 "#,
428             r#"
429 use core::iter::{Repeat, repeat};
430
431 struct S;
432 impl S {
433     fn iter(&self) -> Repeat<i32> { repeat(92) }
434     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
435 }
436
437 fn main() {
438     let x = S;
439     x.iter().for_each(|v| {
440         let a = v * 2;
441     });
442 }
443 "#,
444         )
445     }
446
447     #[test]
448     fn each_to_for_for_borrowed_no_iter_method() {
449         check_assist(
450             convert_for_loop_with_for_each,
451             r"
452 struct NoIterMethod;
453 fn main() {
454     let x = NoIterMethod;
455     for $0v in &x {
456         let a = v * 2;
457     }
458 }
459 ",
460             r"
461 struct NoIterMethod;
462 fn main() {
463     let x = NoIterMethod;
464     (&x).into_iter().for_each(|v| {
465         let a = v * 2;
466     });
467 }
468 ",
469         )
470     }
471
472     #[test]
473     fn each_to_for_for_borrowed_mut() {
474         check_assist(
475             convert_for_loop_with_for_each,
476             r#"
477 //- minicore: iterators
478 use core::iter::{Repeat, repeat};
479
480 struct S;
481 impl S {
482     fn iter(&self) -> Repeat<i32> { repeat(92) }
483     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
484 }
485
486 fn main() {
487     let x = S;
488     for $0v in &mut x {
489         let a = v * 2;
490     }
491 }
492 "#,
493             r#"
494 use core::iter::{Repeat, repeat};
495
496 struct S;
497 impl S {
498     fn iter(&self) -> Repeat<i32> { repeat(92) }
499     fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
500 }
501
502 fn main() {
503     let x = S;
504     x.iter_mut().for_each(|v| {
505         let a = v * 2;
506     });
507 }
508 "#,
509         )
510     }
511
512     #[test]
513     fn each_to_for_for_borrowed_mut_behind_var() {
514         check_assist(
515             convert_for_loop_with_for_each,
516             r"
517 fn main() {
518     let x = vec![1, 2, 3];
519     let y = &mut x;
520     for $0v in y {
521         *v *= 2;
522     }
523 }",
524             r"
525 fn main() {
526     let x = vec![1, 2, 3];
527     let y = &mut x;
528     y.into_iter().for_each(|v| {
529         *v *= 2;
530     });
531 }",
532         )
533     }
534
535     #[test]
536     fn each_to_for_already_impls_iterator() {
537         cov_mark::check!(test_already_impls_iterator);
538         check_assist(
539             convert_for_loop_with_for_each,
540             r#"
541 //- minicore: iterators
542 fn main() {
543     for$0 a in core::iter::repeat(92).take(1) {
544         println!("{}", a);
545     }
546 }
547 "#,
548             r#"
549 fn main() {
550     core::iter::repeat(92).take(1).for_each(|a| {
551         println!("{}", a);
552     });
553 }
554 "#,
555         );
556     }
557 }