]> git.lizzy.rs Git - rust.git/blob - crates/ra_assists/src/introduce_variable.rs
ra_assists: assist "providers" can produce multiple assists
[rust.git] / crates / ra_assists / src / introduce_variable.rs
1 use hir::db::HirDatabase;
2 use ra_syntax::{
3     ast::{self, AstNode},
4     SyntaxKind::{
5         WHITESPACE, MATCH_ARM, LAMBDA_EXPR, PATH_EXPR, BREAK_EXPR, LOOP_EXPR, RETURN_EXPR, COMMENT
6     }, SyntaxNode, TextUnit,
7 };
8
9 use crate::{AssistCtx, Assist};
10
11 pub(crate) fn introduce_variable(mut ctx: AssistCtx<impl HirDatabase>) -> Option<Assist> {
12     let node = ctx.covering_node();
13     if !valid_covering_node(node) {
14         return None;
15     }
16     let expr = node.ancestors().filter_map(valid_target_expr).next()?;
17     let (anchor_stmt, wrap_in_block) = anchor_stmt(expr)?;
18     let indent = anchor_stmt.prev_sibling()?;
19     if indent.kind() != WHITESPACE {
20         return None;
21     }
22     ctx.add_action("introduce variable", move |edit| {
23         let mut buf = String::new();
24
25         let cursor_offset = if wrap_in_block {
26             buf.push_str("{ let var_name = ");
27             TextUnit::of_str("{ let ")
28         } else {
29             buf.push_str("let var_name = ");
30             TextUnit::of_str("let ")
31         };
32
33         expr.syntax().text().push_to(&mut buf);
34         let full_stmt = ast::ExprStmt::cast(anchor_stmt);
35         let is_full_stmt = if let Some(expr_stmt) = full_stmt {
36             Some(expr.syntax()) == expr_stmt.expr().map(|e| e.syntax())
37         } else {
38             false
39         };
40         if is_full_stmt {
41             if !full_stmt.unwrap().has_semi() {
42                 buf.push_str(";");
43             }
44             edit.replace(expr.syntax().range(), buf);
45         } else {
46             buf.push_str(";");
47
48             // We want to maintain the indent level,
49             // but we do not want to duplicate possible
50             // extra newlines in the indent block
51             for chunk in indent.text().chunks() {
52                 if chunk.starts_with("\r\n") {
53                     buf.push_str("\r\n");
54                     buf.push_str(chunk.trim_start_matches("\r\n"));
55                 } else if chunk.starts_with("\n") {
56                     buf.push_str("\n");
57                     buf.push_str(chunk.trim_start_matches("\n"));
58                 } else {
59                     buf.push_str(chunk);
60                 }
61             }
62
63             edit.target(expr.syntax().range());
64             edit.replace(expr.syntax().range(), "var_name".to_string());
65             edit.insert(anchor_stmt.range().start(), buf);
66             if wrap_in_block {
67                 edit.insert(anchor_stmt.range().end(), " }");
68             }
69         }
70         edit.set_cursor(anchor_stmt.range().start() + cursor_offset);
71     });
72
73     ctx.build()
74 }
75
76 fn valid_covering_node(node: &SyntaxNode) -> bool {
77     node.kind() != COMMENT
78 }
79 /// Check whether the node is a valid expression which can be extracted to a variable.
80 /// In general that's true for any expression, but in some cases that would produce invalid code.
81 fn valid_target_expr(node: &SyntaxNode) -> Option<&ast::Expr> {
82     match node.kind() {
83         PATH_EXPR => None,
84         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
85         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
86         LOOP_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
87         _ => ast::Expr::cast(node),
88     }
89 }
90
91 /// Returns the syntax node which will follow the freshly introduced var
92 /// and a boolean indicating whether we have to wrap it within a { } block
93 /// to produce correct code.
94 /// It can be a statement, the last in a block expression or a wanna be block
95 /// expression like a lambda or match arm.
96 fn anchor_stmt(expr: &ast::Expr) -> Option<(&SyntaxNode, bool)> {
97     expr.syntax().ancestors().find_map(|node| {
98         if ast::Stmt::cast(node).is_some() {
99             return Some((node, false));
100         }
101
102         if let Some(expr) = node.parent().and_then(ast::Block::cast).and_then(|it| it.expr()) {
103             if expr.syntax() == node {
104                 return Some((node, false));
105             }
106         }
107
108         if let Some(parent) = node.parent() {
109             if parent.kind() == MATCH_ARM || parent.kind() == LAMBDA_EXPR {
110                 return Some((node, true));
111             }
112         }
113
114         None
115     })
116 }
117
118 #[cfg(test)]
119 mod tests {
120     use super::*;
121     use crate::helpers::{check_assist, check_assist_not_applicable, check_assist_range, check_assist_target, check_assist_range_target};
122
123     #[test]
124     fn test_introduce_var_simple() {
125         check_assist_range(
126             introduce_variable,
127             "
128 fn foo() {
129     foo(<|>1 + 1<|>);
130 }",
131             "
132 fn foo() {
133     let <|>var_name = 1 + 1;
134     foo(var_name);
135 }",
136         );
137     }
138
139     #[test]
140     fn test_introduce_var_expr_stmt() {
141         check_assist_range(
142             introduce_variable,
143             "
144 fn foo() {
145     <|>1 + 1<|>;
146 }",
147             "
148 fn foo() {
149     let <|>var_name = 1 + 1;
150 }",
151         );
152     }
153
154     #[test]
155     fn test_introduce_var_part_of_expr_stmt() {
156         check_assist_range(
157             introduce_variable,
158             "
159 fn foo() {
160     <|>1<|> + 1;
161 }",
162             "
163 fn foo() {
164     let <|>var_name = 1;
165     var_name + 1;
166 }",
167         );
168     }
169
170     #[test]
171     fn test_introduce_var_last_expr() {
172         check_assist_range(
173             introduce_variable,
174             "
175 fn foo() {
176     bar(<|>1 + 1<|>)
177 }",
178             "
179 fn foo() {
180     let <|>var_name = 1 + 1;
181     bar(var_name)
182 }",
183         );
184     }
185
186     #[test]
187     fn test_introduce_var_last_full_expr() {
188         check_assist_range(
189             introduce_variable,
190             "
191 fn foo() {
192     <|>bar(1 + 1)<|>
193 }",
194             "
195 fn foo() {
196     let <|>var_name = bar(1 + 1);
197     var_name
198 }",
199         );
200     }
201
202     #[test]
203     fn test_introduce_var_block_expr_second_to_last() {
204         check_assist_range(
205             introduce_variable,
206             "
207 fn foo() {
208     <|>{ let x = 0; x }<|>
209     something_else();
210 }",
211             "
212 fn foo() {
213     let <|>var_name = { let x = 0; x };
214     something_else();
215 }",
216         );
217     }
218
219     #[test]
220     fn test_introduce_var_in_match_arm_no_block() {
221         check_assist_range(
222             introduce_variable,
223             "
224 fn main() {
225     let x = true;
226     let tuple = match x {
227         true => (<|>2 + 2<|>, true)
228         _ => (0, false)
229     };
230 }
231 ",
232             "
233 fn main() {
234     let x = true;
235     let tuple = match x {
236         true => { let <|>var_name = 2 + 2; (var_name, true) }
237         _ => (0, false)
238     };
239 }
240 ",
241         );
242     }
243
244     #[test]
245     fn test_introduce_var_in_match_arm_with_block() {
246         check_assist_range(
247             introduce_variable,
248             "
249 fn main() {
250     let x = true;
251     let tuple = match x {
252         true => {
253             let y = 1;
254             (<|>2 + y<|>, true)
255         }
256         _ => (0, false)
257     };
258 }
259 ",
260             "
261 fn main() {
262     let x = true;
263     let tuple = match x {
264         true => {
265             let y = 1;
266             let <|>var_name = 2 + y;
267             (var_name, true)
268         }
269         _ => (0, false)
270     };
271 }
272 ",
273         );
274     }
275
276     #[test]
277     fn test_introduce_var_in_closure_no_block() {
278         check_assist_range(
279             introduce_variable,
280             "
281 fn main() {
282     let lambda = |x: u32| <|>x * 2<|>;
283 }
284 ",
285             "
286 fn main() {
287     let lambda = |x: u32| { let <|>var_name = x * 2; var_name };
288 }
289 ",
290         );
291     }
292
293     #[test]
294     fn test_introduce_var_in_closure_with_block() {
295         check_assist_range(
296             introduce_variable,
297             "
298 fn main() {
299     let lambda = |x: u32| { <|>x * 2<|> };
300 }
301 ",
302             "
303 fn main() {
304     let lambda = |x: u32| { let <|>var_name = x * 2; var_name };
305 }
306 ",
307         );
308     }
309
310     #[test]
311     fn test_introduce_var_path_simple() {
312         check_assist(
313             introduce_variable,
314             "
315 fn main() {
316     let o = S<|>ome(true);
317 }
318 ",
319             "
320 fn main() {
321     let <|>var_name = Some(true);
322     let o = var_name;
323 }
324 ",
325         );
326     }
327
328     #[test]
329     fn test_introduce_var_path_method() {
330         check_assist(
331             introduce_variable,
332             "
333 fn main() {
334     let v = b<|>ar.foo();
335 }
336 ",
337             "
338 fn main() {
339     let <|>var_name = bar.foo();
340     let v = var_name;
341 }
342 ",
343         );
344     }
345
346     #[test]
347     fn test_introduce_var_return() {
348         check_assist(
349             introduce_variable,
350             "
351 fn foo() -> u32 {
352     r<|>eturn 2 + 2;
353 }
354 ",
355             "
356 fn foo() -> u32 {
357     let <|>var_name = 2 + 2;
358     return var_name;
359 }
360 ",
361         );
362     }
363
364     #[test]
365     fn test_introduce_var_does_not_add_extra_whitespace() {
366         check_assist(
367             introduce_variable,
368             "
369 fn foo() -> u32 {
370
371
372     r<|>eturn 2 + 2;
373 }
374 ",
375             "
376 fn foo() -> u32 {
377
378
379     let <|>var_name = 2 + 2;
380     return var_name;
381 }
382 ",
383         );
384
385         check_assist(
386             introduce_variable,
387             "
388 fn foo() -> u32 {
389
390         r<|>eturn 2 + 2;
391 }
392 ",
393             "
394 fn foo() -> u32 {
395
396         let <|>var_name = 2 + 2;
397         return var_name;
398 }
399 ",
400         );
401
402         check_assist(
403             introduce_variable,
404             "
405 fn foo() -> u32 {
406     let foo = 1;
407
408     // bar
409
410
411     r<|>eturn 2 + 2;
412 }
413 ",
414             "
415 fn foo() -> u32 {
416     let foo = 1;
417
418     // bar
419
420
421     let <|>var_name = 2 + 2;
422     return var_name;
423 }
424 ",
425         );
426     }
427
428     #[test]
429     fn test_introduce_var_break() {
430         check_assist(
431             introduce_variable,
432             "
433 fn main() {
434     let result = loop {
435         b<|>reak 2 + 2;
436     };
437 }
438 ",
439             "
440 fn main() {
441     let result = loop {
442         let <|>var_name = 2 + 2;
443         break var_name;
444     };
445 }
446 ",
447         );
448     }
449
450     #[test]
451     fn test_introduce_var_for_cast() {
452         check_assist(
453             introduce_variable,
454             "
455 fn main() {
456     let v = 0f32 a<|>s u32;
457 }
458 ",
459             "
460 fn main() {
461     let <|>var_name = 0f32 as u32;
462     let v = var_name;
463 }
464 ",
465         );
466     }
467
468     #[test]
469     fn test_introduce_var_for_return_not_applicable() {
470         check_assist_not_applicable(
471             introduce_variable,
472             "
473 fn foo() {
474     r<|>eturn;
475 }
476 ",
477         );
478     }
479
480     #[test]
481     fn test_introduce_var_for_break_not_applicable() {
482         check_assist_not_applicable(
483             introduce_variable,
484             "
485 fn main() {
486     loop {
487         b<|>reak;
488     };
489 }
490 ",
491         );
492     }
493
494     #[test]
495     fn test_introduce_var_in_comment_not_applicable() {
496         check_assist_not_applicable(
497             introduce_variable,
498             "
499 fn main() {
500     let x = true;
501     let tuple = match x {
502         // c<|>omment
503         true => (2 + 2, true)
504         _ => (0, false)
505     };
506 }
507 ",
508         );
509     }
510
511     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
512     #[test]
513     fn introduce_var_target() {
514         check_assist_target(
515             introduce_variable,
516             "
517 fn foo() -> u32 {
518     r<|>eturn 2 + 2;
519 }
520 ",
521             "2 + 2",
522         );
523
524         check_assist_range_target(
525             introduce_variable,
526             "
527 fn main() {
528     let x = true;
529     let tuple = match x {
530         true => (<|>2 + 2<|>, true)
531         _ => (0, false)
532     };
533 }
534 ",
535             "2 + 2",
536         );
537     }
538 }