]> git.lizzy.rs Git - rust.git/blob - crates/ra_assists/src/introduce_variable.rs
954b97b055db3e465bfa3a3db43ff95ab34e82a0
[rust.git] / crates / ra_assists / src / introduce_variable.rs
1 use hir::db::HirDatabase;
2 use ra_syntax::{
3     ast::{self, AstNode},
4     SyntaxKind::{
5         WHITESPACE, MATCH_ARM, LAMBDA_EXPR, PATH_EXPR, BREAK_EXPR, LOOP_EXPR, RETURN_EXPR, COMMENT
6     }, SyntaxNode, TextUnit,
7 };
8
9 use crate::{AssistCtx, Assist};
10
11 pub(crate) fn introduce_variable(ctx: AssistCtx<impl HirDatabase>) -> Option<Assist> {
12     let node = ctx.covering_node();
13     if !valid_covering_node(node) {
14         return None;
15     }
16     let expr = node.ancestors().filter_map(valid_target_expr).next()?;
17     let (anchor_stmt, wrap_in_block) = anchor_stmt(expr)?;
18     let indent = anchor_stmt.prev_sibling()?;
19     if indent.kind() != WHITESPACE {
20         return None;
21     }
22     ctx.build("introduce variable", move |edit| {
23         let mut buf = String::new();
24
25         let cursor_offset = if wrap_in_block {
26             buf.push_str("{ let var_name = ");
27             TextUnit::of_str("{ let ")
28         } else {
29             buf.push_str("let var_name = ");
30             TextUnit::of_str("let ")
31         };
32
33         expr.syntax().text().push_to(&mut buf);
34         let full_stmt = ast::ExprStmt::cast(anchor_stmt);
35         let is_full_stmt = if let Some(expr_stmt) = full_stmt {
36             Some(expr.syntax()) == expr_stmt.expr().map(|e| e.syntax())
37         } else {
38             false
39         };
40         if is_full_stmt {
41             if !full_stmt.unwrap().has_semi() {
42                 buf.push_str(";");
43             }
44             edit.replace(expr.syntax().range(), buf);
45         } else {
46             buf.push_str(";");
47
48             // We want to maintain the indent level,
49             // but we do not want to duplicate possible
50             // extra newlines in the indent block
51             for chunk in indent.text().chunks() {
52                 if chunk.starts_with("\r\n") {
53                     buf.push_str("\r\n");
54                     buf.push_str(chunk.trim_start_matches("\r\n"));
55                 } else if chunk.starts_with("\n") {
56                     buf.push_str("\n");
57                     buf.push_str(chunk.trim_start_matches("\n"));
58                 } else {
59                     buf.push_str(chunk);
60                 }
61             }
62
63             edit.target(expr.syntax().range());
64             edit.replace(expr.syntax().range(), "var_name".to_string());
65             edit.insert(anchor_stmt.range().start(), buf);
66             if wrap_in_block {
67                 edit.insert(anchor_stmt.range().end(), " }");
68             }
69         }
70         edit.set_cursor(anchor_stmt.range().start() + cursor_offset);
71     })
72 }
73
74 fn valid_covering_node(node: &SyntaxNode) -> bool {
75     node.kind() != COMMENT
76 }
77 /// Check whether the node is a valid expression which can be extracted to a variable.
78 /// In general that's true for any expression, but in some cases that would produce invalid code.
79 fn valid_target_expr(node: &SyntaxNode) -> Option<&ast::Expr> {
80     match node.kind() {
81         PATH_EXPR => None,
82         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
83         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
84         LOOP_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
85         _ => ast::Expr::cast(node),
86     }
87 }
88
89 /// Returns the syntax node which will follow the freshly introduced var
90 /// and a boolean indicating whether we have to wrap it within a { } block
91 /// to produce correct code.
92 /// It can be a statement, the last in a block expression or a wanna be block
93 /// expression like a lambda or match arm.
94 fn anchor_stmt(expr: &ast::Expr) -> Option<(&SyntaxNode, bool)> {
95     expr.syntax().ancestors().find_map(|node| {
96         if ast::Stmt::cast(node).is_some() {
97             return Some((node, false));
98         }
99
100         if let Some(expr) = node.parent().and_then(ast::Block::cast).and_then(|it| it.expr()) {
101             if expr.syntax() == node {
102                 return Some((node, false));
103             }
104         }
105
106         if let Some(parent) = node.parent() {
107             if parent.kind() == MATCH_ARM || parent.kind() == LAMBDA_EXPR {
108                 return Some((node, true));
109             }
110         }
111
112         None
113     })
114 }
115
116 #[cfg(test)]
117 mod tests {
118     use super::*;
119     use crate::helpers::{check_assist, check_assist_not_applicable, check_assist_range, check_assist_target, check_assist_range_target};
120
121     #[test]
122     fn test_introduce_var_simple() {
123         check_assist_range(
124             introduce_variable,
125             "
126 fn foo() {
127     foo(<|>1 + 1<|>);
128 }",
129             "
130 fn foo() {
131     let <|>var_name = 1 + 1;
132     foo(var_name);
133 }",
134         );
135     }
136
137     #[test]
138     fn test_introduce_var_expr_stmt() {
139         check_assist_range(
140             introduce_variable,
141             "
142 fn foo() {
143     <|>1 + 1<|>;
144 }",
145             "
146 fn foo() {
147     let <|>var_name = 1 + 1;
148 }",
149         );
150     }
151
152     #[test]
153     fn test_introduce_var_part_of_expr_stmt() {
154         check_assist_range(
155             introduce_variable,
156             "
157 fn foo() {
158     <|>1<|> + 1;
159 }",
160             "
161 fn foo() {
162     let <|>var_name = 1;
163     var_name + 1;
164 }",
165         );
166     }
167
168     #[test]
169     fn test_introduce_var_last_expr() {
170         check_assist_range(
171             introduce_variable,
172             "
173 fn foo() {
174     bar(<|>1 + 1<|>)
175 }",
176             "
177 fn foo() {
178     let <|>var_name = 1 + 1;
179     bar(var_name)
180 }",
181         );
182     }
183
184     #[test]
185     fn test_introduce_var_last_full_expr() {
186         check_assist_range(
187             introduce_variable,
188             "
189 fn foo() {
190     <|>bar(1 + 1)<|>
191 }",
192             "
193 fn foo() {
194     let <|>var_name = bar(1 + 1);
195     var_name
196 }",
197         );
198     }
199
200     #[test]
201     fn test_introduce_var_block_expr_second_to_last() {
202         check_assist_range(
203             introduce_variable,
204             "
205 fn foo() {
206     <|>{ let x = 0; x }<|>
207     something_else();
208 }",
209             "
210 fn foo() {
211     let <|>var_name = { let x = 0; x };
212     something_else();
213 }",
214         );
215     }
216
217     #[test]
218     fn test_introduce_var_in_match_arm_no_block() {
219         check_assist_range(
220             introduce_variable,
221             "
222 fn main() {
223     let x = true;
224     let tuple = match x {
225         true => (<|>2 + 2<|>, true)
226         _ => (0, false)
227     };
228 }
229 ",
230             "
231 fn main() {
232     let x = true;
233     let tuple = match x {
234         true => { let <|>var_name = 2 + 2; (var_name, true) }
235         _ => (0, false)
236     };
237 }
238 ",
239         );
240     }
241
242     #[test]
243     fn test_introduce_var_in_match_arm_with_block() {
244         check_assist_range(
245             introduce_variable,
246             "
247 fn main() {
248     let x = true;
249     let tuple = match x {
250         true => {
251             let y = 1;
252             (<|>2 + y<|>, true)
253         }
254         _ => (0, false)
255     };
256 }
257 ",
258             "
259 fn main() {
260     let x = true;
261     let tuple = match x {
262         true => {
263             let y = 1;
264             let <|>var_name = 2 + y;
265             (var_name, true)
266         }
267         _ => (0, false)
268     };
269 }
270 ",
271         );
272     }
273
274     #[test]
275     fn test_introduce_var_in_closure_no_block() {
276         check_assist_range(
277             introduce_variable,
278             "
279 fn main() {
280     let lambda = |x: u32| <|>x * 2<|>;
281 }
282 ",
283             "
284 fn main() {
285     let lambda = |x: u32| { let <|>var_name = x * 2; var_name };
286 }
287 ",
288         );
289     }
290
291     #[test]
292     fn test_introduce_var_in_closure_with_block() {
293         check_assist_range(
294             introduce_variable,
295             "
296 fn main() {
297     let lambda = |x: u32| { <|>x * 2<|> };
298 }
299 ",
300             "
301 fn main() {
302     let lambda = |x: u32| { let <|>var_name = x * 2; var_name };
303 }
304 ",
305         );
306     }
307
308     #[test]
309     fn test_introduce_var_path_simple() {
310         check_assist(
311             introduce_variable,
312             "
313 fn main() {
314     let o = S<|>ome(true);
315 }
316 ",
317             "
318 fn main() {
319     let <|>var_name = Some(true);
320     let o = var_name;
321 }
322 ",
323         );
324     }
325
326     #[test]
327     fn test_introduce_var_path_method() {
328         check_assist(
329             introduce_variable,
330             "
331 fn main() {
332     let v = b<|>ar.foo();
333 }
334 ",
335             "
336 fn main() {
337     let <|>var_name = bar.foo();
338     let v = var_name;
339 }
340 ",
341         );
342     }
343
344     #[test]
345     fn test_introduce_var_return() {
346         check_assist(
347             introduce_variable,
348             "
349 fn foo() -> u32 {
350     r<|>eturn 2 + 2;
351 }
352 ",
353             "
354 fn foo() -> u32 {
355     let <|>var_name = 2 + 2;
356     return var_name;
357 }
358 ",
359         );
360     }
361
362     #[test]
363     fn test_introduce_var_does_not_add_extra_whitespace() {
364         check_assist(
365             introduce_variable,
366             "
367 fn foo() -> u32 {
368
369
370     r<|>eturn 2 + 2;
371 }
372 ",
373             "
374 fn foo() -> u32 {
375
376
377     let <|>var_name = 2 + 2;
378     return var_name;
379 }
380 ",
381         );
382
383         check_assist(
384             introduce_variable,
385             "
386 fn foo() -> u32 {
387
388         r<|>eturn 2 + 2;
389 }
390 ",
391             "
392 fn foo() -> u32 {
393
394         let <|>var_name = 2 + 2;
395         return var_name;
396 }
397 ",
398         );
399
400         check_assist(
401             introduce_variable,
402             "
403 fn foo() -> u32 {
404     let foo = 1;
405
406     // bar
407
408
409     r<|>eturn 2 + 2;
410 }
411 ",
412             "
413 fn foo() -> u32 {
414     let foo = 1;
415
416     // bar
417
418
419     let <|>var_name = 2 + 2;
420     return var_name;
421 }
422 ",
423         );
424     }
425
426     #[test]
427     fn test_introduce_var_break() {
428         check_assist(
429             introduce_variable,
430             "
431 fn main() {
432     let result = loop {
433         b<|>reak 2 + 2;
434     };
435 }
436 ",
437             "
438 fn main() {
439     let result = loop {
440         let <|>var_name = 2 + 2;
441         break var_name;
442     };
443 }
444 ",
445         );
446     }
447
448     #[test]
449     fn test_introduce_var_for_cast() {
450         check_assist(
451             introduce_variable,
452             "
453 fn main() {
454     let v = 0f32 a<|>s u32;
455 }
456 ",
457             "
458 fn main() {
459     let <|>var_name = 0f32 as u32;
460     let v = var_name;
461 }
462 ",
463         );
464     }
465
466     #[test]
467     fn test_introduce_var_for_return_not_applicable() {
468         check_assist_not_applicable(
469             introduce_variable,
470             "
471 fn foo() {
472     r<|>eturn;
473 }
474 ",
475         );
476     }
477
478     #[test]
479     fn test_introduce_var_for_break_not_applicable() {
480         check_assist_not_applicable(
481             introduce_variable,
482             "
483 fn main() {
484     loop {
485         b<|>reak;
486     };
487 }
488 ",
489         );
490     }
491
492     #[test]
493     fn test_introduce_var_in_comment_not_applicable() {
494         check_assist_not_applicable(
495             introduce_variable,
496             "
497 fn main() {
498     let x = true;
499     let tuple = match x {
500         // c<|>omment
501         true => (2 + 2, true)
502         _ => (0, false)
503     };
504 }
505 ",
506         );
507     }
508
509     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
510     #[test]
511     fn introduce_var_target() {
512         check_assist_target(
513             introduce_variable,
514             "
515 fn foo() -> u32 {
516     r<|>eturn 2 + 2;
517 }
518 ",
519             "2 + 2",
520         );
521
522         check_assist_range_target(
523             introduce_variable,
524             "
525 fn main() {
526     let x = true;
527     let tuple = match x {
528         true => (<|>2 + 2<|>, true)
529         _ => (0, false)
530     };
531 }
532 ",
533             "2 + 2",
534         );
535     }
536 }