]> git.lizzy.rs Git - rust.git/blob - crates/ra_assists/src/handlers/add_function.rs
Refactor assists API to be more convenient for adding new assists
[rust.git] / crates / ra_assists / src / handlers / add_function.rs
1 use hir::HirDisplay;
2 use ra_db::FileId;
3 use ra_syntax::{
4     ast::{self, edit::IndentLevel, ArgListOwner, AstNode, ModuleItemOwner},
5     SyntaxKind, SyntaxNode, TextSize,
6 };
7 use rustc_hash::{FxHashMap, FxHashSet};
8
9 use crate::{AssistContext, AssistId, Assists};
10
11 // Assist: add_function
12 //
13 // Adds a stub function with a signature matching the function under the cursor.
14 //
15 // ```
16 // struct Baz;
17 // fn baz() -> Baz { Baz }
18 // fn foo() {
19 //     bar<|>("", baz());
20 // }
21 //
22 // ```
23 // ->
24 // ```
25 // struct Baz;
26 // fn baz() -> Baz { Baz }
27 // fn foo() {
28 //     bar("", baz());
29 // }
30 //
31 // fn bar(arg: &str, baz: Baz) {
32 //     todo!()
33 // }
34 //
35 // ```
36 pub(crate) fn add_function(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
37     let path_expr: ast::PathExpr = ctx.find_node_at_offset()?;
38     let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?;
39     let path = path_expr.path()?;
40
41     if ctx.sema.resolve_path(&path).is_some() {
42         // The function call already resolves, no need to add a function
43         return None;
44     }
45
46     let target_module = if let Some(qualifier) = path.qualifier() {
47         if let Some(hir::PathResolution::Def(hir::ModuleDef::Module(module))) =
48             ctx.sema.resolve_path(&qualifier)
49         {
50             Some(module.definition_source(ctx.sema.db))
51         } else {
52             return None;
53         }
54     } else {
55         None
56     };
57
58     let function_builder = FunctionBuilder::from_call(&ctx, &call, &path, target_module)?;
59
60     let target = call.syntax().text_range();
61     acc.add(AssistId("add_function"), "Add function", target, |edit| {
62         let function_template = function_builder.render();
63         edit.set_file(function_template.file);
64         edit.set_cursor(function_template.cursor_offset);
65         edit.insert(function_template.insert_offset, function_template.fn_def.to_string());
66     })
67 }
68
69 struct FunctionTemplate {
70     insert_offset: TextSize,
71     cursor_offset: TextSize,
72     fn_def: ast::SourceFile,
73     file: FileId,
74 }
75
76 struct FunctionBuilder {
77     target: GeneratedFunctionTarget,
78     fn_name: ast::Name,
79     type_params: Option<ast::TypeParamList>,
80     params: ast::ParamList,
81     file: FileId,
82     needs_pub: bool,
83 }
84
85 impl FunctionBuilder {
86     /// Prepares a generated function that matches `call` in `generate_in`
87     /// (or as close to `call` as possible, if `generate_in` is `None`)
88     fn from_call(
89         ctx: &AssistContext,
90         call: &ast::CallExpr,
91         path: &ast::Path,
92         target_module: Option<hir::InFile<hir::ModuleSource>>,
93     ) -> Option<Self> {
94         let needs_pub = target_module.is_some();
95         let mut file = ctx.frange.file_id;
96         let target = if let Some(target_module) = target_module {
97             let (in_file, target) = next_space_for_fn_in_module(ctx.sema.db, target_module)?;
98             file = in_file;
99             target
100         } else {
101             next_space_for_fn_after_call_site(&call)?
102         };
103         let fn_name = fn_name(&path)?;
104         let (type_params, params) = fn_args(ctx, &call)?;
105         Some(Self { target, fn_name, type_params, params, file, needs_pub })
106     }
107
108     fn render(self) -> FunctionTemplate {
109         let placeholder_expr = ast::make::expr_todo();
110         let fn_body = ast::make::block_expr(vec![], Some(placeholder_expr));
111         let mut fn_def = ast::make::fn_def(self.fn_name, self.type_params, self.params, fn_body);
112         if self.needs_pub {
113             fn_def = ast::make::add_pub_crate_modifier(fn_def);
114         }
115
116         let (fn_def, insert_offset) = match self.target {
117             GeneratedFunctionTarget::BehindItem(it) => {
118                 let with_leading_blank_line = ast::make::add_leading_newlines(2, fn_def);
119                 let indented = IndentLevel::from_node(&it).increase_indent(with_leading_blank_line);
120                 (indented, it.text_range().end())
121             }
122             GeneratedFunctionTarget::InEmptyItemList(it) => {
123                 let indent_once = IndentLevel(1);
124                 let indent = IndentLevel::from_node(it.syntax());
125
126                 let fn_def = ast::make::add_leading_newlines(1, fn_def);
127                 let fn_def = indent_once.increase_indent(fn_def);
128                 let fn_def = ast::make::add_trailing_newlines(1, fn_def);
129                 let fn_def = indent.increase_indent(fn_def);
130                 (fn_def, it.syntax().text_range().start() + TextSize::of('{'))
131             }
132         };
133
134         let placeholder_expr =
135             fn_def.syntax().descendants().find_map(ast::MacroCall::cast).unwrap();
136         let cursor_offset_from_fn_start = placeholder_expr.syntax().text_range().start();
137         let cursor_offset = insert_offset + cursor_offset_from_fn_start;
138         FunctionTemplate { insert_offset, cursor_offset, fn_def, file: self.file }
139     }
140 }
141
142 enum GeneratedFunctionTarget {
143     BehindItem(SyntaxNode),
144     InEmptyItemList(ast::ItemList),
145 }
146
147 fn fn_name(call: &ast::Path) -> Option<ast::Name> {
148     let name = call.segment()?.syntax().to_string();
149     Some(ast::make::name(&name))
150 }
151
152 /// Computes the type variables and arguments required for the generated function
153 fn fn_args(
154     ctx: &AssistContext,
155     call: &ast::CallExpr,
156 ) -> Option<(Option<ast::TypeParamList>, ast::ParamList)> {
157     let mut arg_names = Vec::new();
158     let mut arg_types = Vec::new();
159     for arg in call.arg_list()?.args() {
160         let arg_name = match fn_arg_name(&arg) {
161             Some(name) => name,
162             None => String::from("arg"),
163         };
164         arg_names.push(arg_name);
165         arg_types.push(match fn_arg_type(ctx, &arg) {
166             Some(ty) => ty,
167             None => String::from("()"),
168         });
169     }
170     deduplicate_arg_names(&mut arg_names);
171     let params = arg_names.into_iter().zip(arg_types).map(|(name, ty)| ast::make::param(name, ty));
172     Some((None, ast::make::param_list(params)))
173 }
174
175 /// Makes duplicate argument names unique by appending incrementing numbers.
176 ///
177 /// ```
178 /// let mut names: Vec<String> =
179 ///     vec!["foo".into(), "foo".into(), "bar".into(), "baz".into(), "bar".into()];
180 /// deduplicate_arg_names(&mut names);
181 /// let expected: Vec<String> =
182 ///     vec!["foo_1".into(), "foo_2".into(), "bar_1".into(), "baz".into(), "bar_2".into()];
183 /// assert_eq!(names, expected);
184 /// ```
185 fn deduplicate_arg_names(arg_names: &mut Vec<String>) {
186     let arg_name_counts = arg_names.iter().fold(FxHashMap::default(), |mut m, name| {
187         *m.entry(name).or_insert(0) += 1;
188         m
189     });
190     let duplicate_arg_names: FxHashSet<String> = arg_name_counts
191         .into_iter()
192         .filter(|(_, count)| *count >= 2)
193         .map(|(name, _)| name.clone())
194         .collect();
195
196     let mut counter_per_name = FxHashMap::default();
197     for arg_name in arg_names.iter_mut() {
198         if duplicate_arg_names.contains(arg_name) {
199             let counter = counter_per_name.entry(arg_name.clone()).or_insert(1);
200             arg_name.push('_');
201             arg_name.push_str(&counter.to_string());
202             *counter += 1;
203         }
204     }
205 }
206
207 fn fn_arg_name(fn_arg: &ast::Expr) -> Option<String> {
208     match fn_arg {
209         ast::Expr::CastExpr(cast_expr) => fn_arg_name(&cast_expr.expr()?),
210         _ => Some(
211             fn_arg
212                 .syntax()
213                 .descendants()
214                 .filter(|d| ast::NameRef::can_cast(d.kind()))
215                 .last()?
216                 .to_string(),
217         ),
218     }
219 }
220
221 fn fn_arg_type(ctx: &AssistContext, fn_arg: &ast::Expr) -> Option<String> {
222     let ty = ctx.sema.type_of_expr(fn_arg)?;
223     if ty.is_unknown() {
224         return None;
225     }
226     Some(ty.display(ctx.sema.db).to_string())
227 }
228
229 /// Returns the position inside the current mod or file
230 /// directly after the current block
231 /// We want to write the generated function directly after
232 /// fns, impls or macro calls, but inside mods
233 fn next_space_for_fn_after_call_site(expr: &ast::CallExpr) -> Option<GeneratedFunctionTarget> {
234     let mut ancestors = expr.syntax().ancestors().peekable();
235     let mut last_ancestor: Option<SyntaxNode> = None;
236     while let Some(next_ancestor) = ancestors.next() {
237         match next_ancestor.kind() {
238             SyntaxKind::SOURCE_FILE => {
239                 break;
240             }
241             SyntaxKind::ITEM_LIST => {
242                 if ancestors.peek().map(|a| a.kind()) == Some(SyntaxKind::MODULE) {
243                     break;
244                 }
245             }
246             _ => {}
247         }
248         last_ancestor = Some(next_ancestor);
249     }
250     last_ancestor.map(GeneratedFunctionTarget::BehindItem)
251 }
252
253 fn next_space_for_fn_in_module(
254     db: &dyn hir::db::AstDatabase,
255     module: hir::InFile<hir::ModuleSource>,
256 ) -> Option<(FileId, GeneratedFunctionTarget)> {
257     let file = module.file_id.original_file(db);
258     let assist_item = match module.value {
259         hir::ModuleSource::SourceFile(it) => {
260             if let Some(last_item) = it.items().last() {
261                 GeneratedFunctionTarget::BehindItem(last_item.syntax().clone())
262             } else {
263                 GeneratedFunctionTarget::BehindItem(it.syntax().clone())
264             }
265         }
266         hir::ModuleSource::Module(it) => {
267             if let Some(last_item) = it.item_list().and_then(|it| it.items().last()) {
268                 GeneratedFunctionTarget::BehindItem(last_item.syntax().clone())
269             } else {
270                 GeneratedFunctionTarget::InEmptyItemList(it.item_list()?)
271             }
272         }
273     };
274     Some((file, assist_item))
275 }
276
277 #[cfg(test)]
278 mod tests {
279     use crate::tests::{check_assist, check_assist_not_applicable};
280
281     use super::*;
282
283     #[test]
284     fn add_function_with_no_args() {
285         check_assist(
286             add_function,
287             r"
288 fn foo() {
289     bar<|>();
290 }
291 ",
292             r"
293 fn foo() {
294     bar();
295 }
296
297 fn bar() {
298     <|>todo!()
299 }
300 ",
301         )
302     }
303
304     #[test]
305     fn add_function_from_method() {
306         // This ensures that the function is correctly generated
307         // in the next outer mod or file
308         check_assist(
309             add_function,
310             r"
311 impl Foo {
312     fn foo() {
313         bar<|>();
314     }
315 }
316 ",
317             r"
318 impl Foo {
319     fn foo() {
320         bar();
321     }
322 }
323
324 fn bar() {
325     <|>todo!()
326 }
327 ",
328         )
329     }
330
331     #[test]
332     fn add_function_directly_after_current_block() {
333         // The new fn should not be created at the end of the file or module
334         check_assist(
335             add_function,
336             r"
337 fn foo1() {
338     bar<|>();
339 }
340
341 fn foo2() {}
342 ",
343             r"
344 fn foo1() {
345     bar();
346 }
347
348 fn bar() {
349     <|>todo!()
350 }
351
352 fn foo2() {}
353 ",
354         )
355     }
356
357     #[test]
358     fn add_function_with_no_args_in_same_module() {
359         check_assist(
360             add_function,
361             r"
362 mod baz {
363     fn foo() {
364         bar<|>();
365     }
366 }
367 ",
368             r"
369 mod baz {
370     fn foo() {
371         bar();
372     }
373
374     fn bar() {
375         <|>todo!()
376     }
377 }
378 ",
379         )
380     }
381
382     #[test]
383     fn add_function_with_function_call_arg() {
384         check_assist(
385             add_function,
386             r"
387 struct Baz;
388 fn baz() -> Baz { todo!() }
389 fn foo() {
390     bar<|>(baz());
391 }
392 ",
393             r"
394 struct Baz;
395 fn baz() -> Baz { todo!() }
396 fn foo() {
397     bar(baz());
398 }
399
400 fn bar(baz: Baz) {
401     <|>todo!()
402 }
403 ",
404         );
405     }
406
407     #[test]
408     fn add_function_with_method_call_arg() {
409         check_assist(
410             add_function,
411             r"
412 struct Baz;
413 impl Baz {
414     fn foo(&self) -> Baz {
415         ba<|>r(self.baz())
416     }
417     fn baz(&self) -> Baz {
418         Baz
419     }
420 }
421 ",
422             r"
423 struct Baz;
424 impl Baz {
425     fn foo(&self) -> Baz {
426         bar(self.baz())
427     }
428     fn baz(&self) -> Baz {
429         Baz
430     }
431 }
432
433 fn bar(baz: Baz) {
434     <|>todo!()
435 }
436 ",
437         )
438     }
439
440     #[test]
441     fn add_function_with_string_literal_arg() {
442         check_assist(
443             add_function,
444             r#"
445 fn foo() {
446     <|>bar("bar")
447 }
448 "#,
449             r#"
450 fn foo() {
451     bar("bar")
452 }
453
454 fn bar(arg: &str) {
455     <|>todo!()
456 }
457 "#,
458         )
459     }
460
461     #[test]
462     fn add_function_with_char_literal_arg() {
463         check_assist(
464             add_function,
465             r#"
466 fn foo() {
467     <|>bar('x')
468 }
469 "#,
470             r#"
471 fn foo() {
472     bar('x')
473 }
474
475 fn bar(arg: char) {
476     <|>todo!()
477 }
478 "#,
479         )
480     }
481
482     #[test]
483     fn add_function_with_int_literal_arg() {
484         check_assist(
485             add_function,
486             r"
487 fn foo() {
488     <|>bar(42)
489 }
490 ",
491             r"
492 fn foo() {
493     bar(42)
494 }
495
496 fn bar(arg: i32) {
497     <|>todo!()
498 }
499 ",
500         )
501     }
502
503     #[test]
504     fn add_function_with_cast_int_literal_arg() {
505         check_assist(
506             add_function,
507             r"
508 fn foo() {
509     <|>bar(42 as u8)
510 }
511 ",
512             r"
513 fn foo() {
514     bar(42 as u8)
515 }
516
517 fn bar(arg: u8) {
518     <|>todo!()
519 }
520 ",
521         )
522     }
523
524     #[test]
525     fn name_of_cast_variable_is_used() {
526         // Ensures that the name of the cast type isn't used
527         // in the generated function signature.
528         check_assist(
529             add_function,
530             r"
531 fn foo() {
532     let x = 42;
533     bar<|>(x as u8)
534 }
535 ",
536             r"
537 fn foo() {
538     let x = 42;
539     bar(x as u8)
540 }
541
542 fn bar(x: u8) {
543     <|>todo!()
544 }
545 ",
546         )
547     }
548
549     #[test]
550     fn add_function_with_variable_arg() {
551         check_assist(
552             add_function,
553             r"
554 fn foo() {
555     let worble = ();
556     <|>bar(worble)
557 }
558 ",
559             r"
560 fn foo() {
561     let worble = ();
562     bar(worble)
563 }
564
565 fn bar(worble: ()) {
566     <|>todo!()
567 }
568 ",
569         )
570     }
571
572     #[test]
573     fn add_function_with_impl_trait_arg() {
574         check_assist(
575             add_function,
576             r"
577 trait Foo {}
578 fn foo() -> impl Foo {
579     todo!()
580 }
581 fn baz() {
582     <|>bar(foo())
583 }
584 ",
585             r"
586 trait Foo {}
587 fn foo() -> impl Foo {
588     todo!()
589 }
590 fn baz() {
591     bar(foo())
592 }
593
594 fn bar(foo: impl Foo) {
595     <|>todo!()
596 }
597 ",
598         )
599     }
600
601     #[test]
602     #[ignore]
603     // FIXME print paths properly to make this test pass
604     fn add_function_with_qualified_path_arg() {
605         check_assist(
606             add_function,
607             r"
608 mod Baz {
609     pub struct Bof;
610     pub fn baz() -> Bof { Bof }
611 }
612 mod Foo {
613     fn foo() {
614         <|>bar(super::Baz::baz())
615     }
616 }
617 ",
618             r"
619 mod Baz {
620     pub struct Bof;
621     pub fn baz() -> Bof { Bof }
622 }
623 mod Foo {
624     fn foo() {
625         bar(super::Baz::baz())
626     }
627
628     fn bar(baz: super::Baz::Bof) {
629         <|>todo!()
630     }
631 }
632 ",
633         )
634     }
635
636     #[test]
637     #[ignore]
638     // FIXME fix printing the generics of a `Ty` to make this test pass
639     fn add_function_with_generic_arg() {
640         check_assist(
641             add_function,
642             r"
643 fn foo<T>(t: T) {
644     <|>bar(t)
645 }
646 ",
647             r"
648 fn foo<T>(t: T) {
649     bar(t)
650 }
651
652 fn bar<T>(t: T) {
653     <|>todo!()
654 }
655 ",
656         )
657     }
658
659     #[test]
660     #[ignore]
661     // FIXME Fix function type printing to make this test pass
662     fn add_function_with_fn_arg() {
663         check_assist(
664             add_function,
665             r"
666 struct Baz;
667 impl Baz {
668     fn new() -> Self { Baz }
669 }
670 fn foo() {
671     <|>bar(Baz::new);
672 }
673 ",
674             r"
675 struct Baz;
676 impl Baz {
677     fn new() -> Self { Baz }
678 }
679 fn foo() {
680     bar(Baz::new);
681 }
682
683 fn bar(arg: fn() -> Baz) {
684     <|>todo!()
685 }
686 ",
687         )
688     }
689
690     #[test]
691     #[ignore]
692     // FIXME Fix closure type printing to make this test pass
693     fn add_function_with_closure_arg() {
694         check_assist(
695             add_function,
696             r"
697 fn foo() {
698     let closure = |x: i64| x - 1;
699     <|>bar(closure)
700 }
701 ",
702             r"
703 fn foo() {
704     let closure = |x: i64| x - 1;
705     bar(closure)
706 }
707
708 fn bar(closure: impl Fn(i64) -> i64) {
709     <|>todo!()
710 }
711 ",
712         )
713     }
714
715     #[test]
716     fn unresolveable_types_default_to_unit() {
717         check_assist(
718             add_function,
719             r"
720 fn foo() {
721     <|>bar(baz)
722 }
723 ",
724             r"
725 fn foo() {
726     bar(baz)
727 }
728
729 fn bar(baz: ()) {
730     <|>todo!()
731 }
732 ",
733         )
734     }
735
736     #[test]
737     fn arg_names_dont_overlap() {
738         check_assist(
739             add_function,
740             r"
741 struct Baz;
742 fn baz() -> Baz { Baz }
743 fn foo() {
744     <|>bar(baz(), baz())
745 }
746 ",
747             r"
748 struct Baz;
749 fn baz() -> Baz { Baz }
750 fn foo() {
751     bar(baz(), baz())
752 }
753
754 fn bar(baz_1: Baz, baz_2: Baz) {
755     <|>todo!()
756 }
757 ",
758         )
759     }
760
761     #[test]
762     fn arg_name_counters_start_at_1_per_name() {
763         check_assist(
764             add_function,
765             r#"
766 struct Baz;
767 fn baz() -> Baz { Baz }
768 fn foo() {
769     <|>bar(baz(), baz(), "foo", "bar")
770 }
771 "#,
772             r#"
773 struct Baz;
774 fn baz() -> Baz { Baz }
775 fn foo() {
776     bar(baz(), baz(), "foo", "bar")
777 }
778
779 fn bar(baz_1: Baz, baz_2: Baz, arg_1: &str, arg_2: &str) {
780     <|>todo!()
781 }
782 "#,
783         )
784     }
785
786     #[test]
787     fn add_function_in_module() {
788         check_assist(
789             add_function,
790             r"
791 mod bar {}
792
793 fn foo() {
794     bar::my_fn<|>()
795 }
796 ",
797             r"
798 mod bar {
799     pub(crate) fn my_fn() {
800         <|>todo!()
801     }
802 }
803
804 fn foo() {
805     bar::my_fn()
806 }
807 ",
808         )
809     }
810
811     #[test]
812     fn add_function_in_module_containing_other_items() {
813         check_assist(
814             add_function,
815             r"
816 mod bar {
817     fn something_else() {}
818 }
819
820 fn foo() {
821     bar::my_fn<|>()
822 }
823 ",
824             r"
825 mod bar {
826     fn something_else() {}
827
828     pub(crate) fn my_fn() {
829         <|>todo!()
830     }
831 }
832
833 fn foo() {
834     bar::my_fn()
835 }
836 ",
837         )
838     }
839
840     #[test]
841     fn add_function_in_nested_module() {
842         check_assist(
843             add_function,
844             r"
845 mod bar {
846     mod baz {}
847 }
848
849 fn foo() {
850     bar::baz::my_fn<|>()
851 }
852 ",
853             r"
854 mod bar {
855     mod baz {
856         pub(crate) fn my_fn() {
857             <|>todo!()
858         }
859     }
860 }
861
862 fn foo() {
863     bar::baz::my_fn()
864 }
865 ",
866         )
867     }
868
869     #[test]
870     fn add_function_in_another_file() {
871         check_assist(
872             add_function,
873             r"
874 //- /main.rs
875 mod foo;
876
877 fn main() {
878     foo::bar<|>()
879 }
880 //- /foo.rs
881 ",
882             r"
883
884
885 pub(crate) fn bar() {
886     <|>todo!()
887 }",
888         )
889     }
890
891     #[test]
892     fn add_function_not_applicable_if_function_already_exists() {
893         check_assist_not_applicable(
894             add_function,
895             r"
896 fn foo() {
897     bar<|>();
898 }
899
900 fn bar() {}
901 ",
902         )
903     }
904
905     #[test]
906     fn add_function_not_applicable_if_unresolved_variable_in_call_is_selected() {
907         check_assist_not_applicable(
908             // bar is resolved, but baz isn't.
909             // The assist is only active if the cursor is on an unresolved path,
910             // but the assist should only be offered if the path is a function call.
911             add_function,
912             r"
913 fn foo() {
914     bar(b<|>az);
915 }
916
917 fn bar(baz: ()) {}
918 ",
919         )
920     }
921
922     #[test]
923     fn add_function_not_applicable_if_function_path_not_singleton() {
924         // In the future this assist could be extended to generate functions
925         // if the path is in the same crate (or even the same workspace).
926         // For the beginning, I think this is fine.
927         check_assist_not_applicable(
928             add_function,
929             r"
930 fn foo() {
931     other_crate::bar<|>();
932 }
933         ",
934         )
935     }
936
937     #[test]
938     #[ignore]
939     fn create_method_with_no_args() {
940         check_assist(
941             add_function,
942             r"
943 struct Foo;
944 impl Foo {
945     fn foo(&self) {
946         self.bar()<|>;
947     }
948 }
949         ",
950             r"
951 struct Foo;
952 impl Foo {
953     fn foo(&self) {
954         self.bar();
955     }
956     fn bar(&self) {
957         todo!();
958     }
959 }
960         ",
961         )
962     }
963 }