]> git.lizzy.rs Git - rust.git/blob - crates/assists/src/utils/insert_use.rs
6d110aaaf728973db3edd39511b3ad4bfc81f8f1
[rust.git] / crates / assists / src / utils / insert_use.rs
1 //! Handle syntactic aspects of inserting a new `use`.
2 use std::iter::{self, successors};
3
4 use algo::skip_trivia_token;
5 use ast::{
6     edit::{AstNodeEdit, IndentLevel},
7     PathSegmentKind, VisibilityOwner,
8 };
9 use syntax::{
10     algo,
11     ast::{self, make, AstNode},
12     Direction, InsertPosition, SyntaxElement, SyntaxNode, T,
13 };
14 use test_utils::mark;
15
16 #[derive(Debug)]
17 pub enum ImportScope {
18     File(ast::SourceFile),
19     Module(ast::ItemList),
20 }
21
22 impl ImportScope {
23     pub fn from(syntax: SyntaxNode) -> Option<Self> {
24         if let Some(module) = ast::Module::cast(syntax.clone()) {
25             module.item_list().map(ImportScope::Module)
26         } else if let this @ Some(_) = ast::SourceFile::cast(syntax.clone()) {
27             this.map(ImportScope::File)
28         } else {
29             ast::ItemList::cast(syntax).map(ImportScope::Module)
30         }
31     }
32
33     /// Determines the containing syntax node in which to insert a `use` statement affecting `position`.
34     pub(crate) fn find_insert_use_container(
35         position: &SyntaxNode,
36         ctx: &crate::assist_context::AssistContext,
37     ) -> Option<Self> {
38         ctx.sema.ancestors_with_macros(position.clone()).find_map(Self::from)
39     }
40
41     pub(crate) fn as_syntax_node(&self) -> &SyntaxNode {
42         match self {
43             ImportScope::File(file) => file.syntax(),
44             ImportScope::Module(item_list) => item_list.syntax(),
45         }
46     }
47
48     fn indent_level(&self) -> IndentLevel {
49         match self {
50             ImportScope::File(file) => file.indent_level(),
51             ImportScope::Module(item_list) => item_list.indent_level() + 1,
52         }
53     }
54
55     fn first_insert_pos(&self) -> (InsertPosition<SyntaxElement>, AddBlankLine) {
56         match self {
57             ImportScope::File(_) => (InsertPosition::First, AddBlankLine::AfterTwice),
58             // don't insert the imports before the item list's opening curly brace
59             ImportScope::Module(item_list) => item_list
60                 .l_curly_token()
61                 .map(|b| (InsertPosition::After(b.into()), AddBlankLine::Around))
62                 .unwrap_or((InsertPosition::First, AddBlankLine::AfterTwice)),
63         }
64     }
65
66     fn insert_pos_after_inner_attribute(&self) -> (InsertPosition<SyntaxElement>, AddBlankLine) {
67         // check if the scope has inner attributes, we dont want to insert in front of them
68         match self
69             .as_syntax_node()
70             .children()
71             // no flat_map here cause we want to short circuit the iterator
72             .map(ast::Attr::cast)
73             .take_while(|attr| {
74                 attr.as_ref().map(|attr| attr.kind() == ast::AttrKind::Inner).unwrap_or(false)
75             })
76             .last()
77             .flatten()
78         {
79             Some(attr) => {
80                 (InsertPosition::After(attr.syntax().clone().into()), AddBlankLine::BeforeTwice)
81             }
82             None => self.first_insert_pos(),
83         }
84     }
85 }
86
87 /// Insert an import path into the given file/node. A `merge` value of none indicates that no import merging is allowed to occur.
88 pub(crate) fn insert_use(
89     scope: &ImportScope,
90     path: ast::Path,
91     merge: Option<MergeBehaviour>,
92 ) -> SyntaxNode {
93     let use_item = make::use_(make::use_tree(path.clone(), None, None, false));
94     // merge into existing imports if possible
95     if let Some(mb) = merge {
96         for existing_use in scope.as_syntax_node().children().filter_map(ast::Use::cast) {
97             if let Some(merged) = try_merge_imports(&existing_use, &use_item, mb) {
98                 let to_delete: SyntaxElement = existing_use.syntax().clone().into();
99                 let to_delete = to_delete.clone()..=to_delete;
100                 let to_insert = iter::once(merged.syntax().clone().into());
101                 return algo::replace_children(scope.as_syntax_node(), to_delete, to_insert);
102             }
103         }
104     }
105
106     // either we weren't allowed to merge or there is no import that fits the merge conditions
107     // so look for the place we have to insert to
108     let (insert_position, add_blank) = find_insert_position(scope, path);
109
110     let to_insert: Vec<SyntaxElement> = {
111         let mut buf = Vec::new();
112
113         match add_blank {
114             AddBlankLine::Before | AddBlankLine::Around => {
115                 buf.push(make::tokens::single_newline().into())
116             }
117             AddBlankLine::BeforeTwice => buf.push(make::tokens::blank_line().into()),
118             _ => (),
119         }
120
121         if let ident_level @ 1..=usize::MAX = scope.indent_level().0 as usize {
122             // FIXME: this alone doesnt properly re-align all cases
123             buf.push(make::tokens::whitespace(&" ".repeat(4 * ident_level)).into());
124         }
125         buf.push(use_item.syntax().clone().into());
126
127         match add_blank {
128             AddBlankLine::After | AddBlankLine::Around => {
129                 buf.push(make::tokens::single_newline().into())
130             }
131             AddBlankLine::AfterTwice => buf.push(make::tokens::blank_line().into()),
132             _ => (),
133         }
134
135         buf
136     };
137
138     algo::insert_children(scope.as_syntax_node(), insert_position, to_insert)
139 }
140
141 fn eq_visibility(vis0: Option<ast::Visibility>, vis1: Option<ast::Visibility>) -> bool {
142     match (vis0, vis1) {
143         (None, None) => true,
144         // FIXME: Don't use the string representation to check for equality
145         // spaces inside of the node would break this comparison
146         (Some(vis0), Some(vis1)) => vis0.to_string() == vis1.to_string(),
147         _ => false,
148     }
149 }
150
151 pub(crate) fn try_merge_imports(
152     old: &ast::Use,
153     new: &ast::Use,
154     merge_behaviour: MergeBehaviour,
155 ) -> Option<ast::Use> {
156     // don't merge imports with different visibilities
157     if !eq_visibility(old.visibility(), new.visibility()) {
158         return None;
159     }
160     let old_tree = old.use_tree()?;
161     let new_tree = new.use_tree()?;
162     let merged = try_merge_trees(&old_tree, &new_tree, merge_behaviour)?;
163     Some(old.with_use_tree(merged))
164 }
165
166 /// Simple function that checks if a UseTreeList is deeper than one level
167 fn use_tree_list_is_nested(tl: &ast::UseTreeList) -> bool {
168     tl.use_trees().any(|use_tree| {
169         use_tree.use_tree_list().is_some() || use_tree.path().and_then(|p| p.qualifier()).is_some()
170     })
171 }
172
173 // FIXME: currently this merely prepends the new tree into old, ideally it would insert the items in a sorted fashion
174 pub(crate) fn try_merge_trees(
175     old: &ast::UseTree,
176     new: &ast::UseTree,
177     merge_behaviour: MergeBehaviour,
178 ) -> Option<ast::UseTree> {
179     let lhs_path = old.path()?;
180     let rhs_path = new.path()?;
181
182     let (lhs_prefix, rhs_prefix) = common_prefix(&lhs_path, &rhs_path)?;
183     let lhs = old.split_prefix(&lhs_prefix);
184     let rhs = new.split_prefix(&rhs_prefix);
185     let lhs_tl = lhs.use_tree_list()?;
186     let rhs_tl = rhs.use_tree_list()?;
187
188     // if we are only allowed to merge the last level check if the split off paths are only one level deep
189     if merge_behaviour == MergeBehaviour::Last
190         && (use_tree_list_is_nested(&lhs_tl) || use_tree_list_is_nested(&rhs_tl))
191     {
192         mark::hit!(test_last_merge_too_long);
193         return None;
194     }
195
196     let should_insert_comma = lhs_tl
197         .r_curly_token()
198         .and_then(|it| skip_trivia_token(it.prev_token()?, Direction::Prev))
199         .map(|it| it.kind())
200         != Some(T![,]);
201     let mut to_insert: Vec<SyntaxElement> = Vec::new();
202     if should_insert_comma {
203         to_insert.push(make::token(T![,]).into());
204         to_insert.push(make::tokens::single_space().into());
205     }
206     to_insert.extend(
207         rhs_tl.syntax().children_with_tokens().filter(|it| !matches!(it.kind(), T!['{'] | T!['}'])),
208     );
209     let pos = InsertPosition::Before(lhs_tl.r_curly_token()?.into());
210     let use_tree_list = lhs_tl.insert_children(pos, to_insert);
211     Some(lhs.with_use_tree_list(use_tree_list))
212 }
213
214 /// Traverses both paths until they differ, returning the common prefix of both.
215 fn common_prefix(lhs: &ast::Path, rhs: &ast::Path) -> Option<(ast::Path, ast::Path)> {
216     let mut res = None;
217     let mut lhs_curr = first_path(&lhs);
218     let mut rhs_curr = first_path(&rhs);
219     loop {
220         match (lhs_curr.segment(), rhs_curr.segment()) {
221             (Some(lhs), Some(rhs)) if lhs.syntax().text() == rhs.syntax().text() => (),
222             _ => break,
223         }
224         res = Some((lhs_curr.clone(), rhs_curr.clone()));
225
226         match lhs_curr.parent_path().zip(rhs_curr.parent_path()) {
227             Some((lhs, rhs)) => {
228                 lhs_curr = lhs;
229                 rhs_curr = rhs;
230             }
231             _ => break,
232         }
233     }
234
235     res
236 }
237
238 /// What type of merges are allowed.
239 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
240 pub enum MergeBehaviour {
241     /// Merge everything together creating deeply nested imports.
242     Full,
243     /// Only merge the last import level, doesn't allow import nesting.
244     Last,
245 }
246
247 #[derive(Eq, PartialEq, PartialOrd, Ord)]
248 enum ImportGroup {
249     // the order here defines the order of new group inserts
250     Std,
251     ExternCrate,
252     ThisCrate,
253     ThisModule,
254     SuperModule,
255 }
256
257 impl ImportGroup {
258     fn new(path: &ast::Path) -> ImportGroup {
259         let default = ImportGroup::ExternCrate;
260
261         let first_segment = match first_segment(path) {
262             Some(it) => it,
263             None => return default,
264         };
265
266         let kind = first_segment.kind().unwrap_or(PathSegmentKind::SelfKw);
267         match kind {
268             PathSegmentKind::SelfKw => ImportGroup::ThisModule,
269             PathSegmentKind::SuperKw => ImportGroup::SuperModule,
270             PathSegmentKind::CrateKw => ImportGroup::ThisCrate,
271             PathSegmentKind::Name(name) => match name.text().as_str() {
272                 "std" => ImportGroup::Std,
273                 "core" => ImportGroup::Std,
274                 // FIXME: can be ThisModule as well
275                 _ => ImportGroup::ExternCrate,
276             },
277             PathSegmentKind::Type { .. } => unreachable!(),
278         }
279     }
280 }
281
282 fn first_segment(path: &ast::Path) -> Option<ast::PathSegment> {
283     first_path(path).segment()
284 }
285
286 fn first_path(path: &ast::Path) -> ast::Path {
287     successors(Some(path.clone()), ast::Path::qualifier).last().unwrap()
288 }
289
290 fn segment_iter(path: &ast::Path) -> impl Iterator<Item = ast::PathSegment> + Clone {
291     // cant make use of SyntaxNode::siblings, because the returned Iterator is not clone
292     successors(first_segment(path), |p| p.parent_path().parent_path().and_then(|p| p.segment()))
293 }
294
295 #[derive(PartialEq, Eq)]
296 enum AddBlankLine {
297     Before,
298     BeforeTwice,
299     Around,
300     After,
301     AfterTwice,
302 }
303
304 fn find_insert_position(
305     scope: &ImportScope,
306     insert_path: ast::Path,
307 ) -> (InsertPosition<SyntaxElement>, AddBlankLine) {
308     let group = ImportGroup::new(&insert_path);
309     let path_node_iter = scope
310         .as_syntax_node()
311         .children()
312         .filter_map(|node| ast::Use::cast(node.clone()).zip(Some(node)))
313         .flat_map(|(use_, node)| use_.use_tree().and_then(|tree| tree.path()).zip(Some(node)));
314     // Iterator that discards anything thats not in the required grouping
315     // This implementation allows the user to rearrange their import groups as this only takes the first group that fits
316     let group_iter = path_node_iter
317         .clone()
318         .skip_while(|(path, _)| ImportGroup::new(path) != group)
319         .take_while(|(path, _)| ImportGroup::new(path) == group);
320
321     let segments = segment_iter(&insert_path);
322     // track the last element we iterated over, if this is still None after the iteration then that means we never iterated in the first place
323     let mut last = None;
324     // find the element that would come directly after our new import
325     let post_insert =
326         group_iter.inspect(|(_, node)| last = Some(node.clone())).find(|(path, _)| {
327             let check_segments = segment_iter(&path);
328             segments
329                 .clone()
330                 .zip(check_segments)
331                 .flat_map(|(seg, seg2)| seg.name_ref().zip(seg2.name_ref()))
332                 .all(|(l, r)| l.text() <= r.text())
333         });
334     match post_insert {
335         // insert our import before that element
336         Some((_, node)) => (InsertPosition::Before(node.into()), AddBlankLine::After),
337         // there is no element after our new import, so append it to the end of the group
338         None => match last {
339             Some(node) => (InsertPosition::After(node.into()), AddBlankLine::Before),
340             // the group we were looking for actually doesnt exist, so insert
341             None => {
342                 // similar concept here to the `last` from above
343                 let mut last = None;
344                 // find the group that comes after where we want to insert
345                 let post_group = path_node_iter
346                     .inspect(|(_, node)| last = Some(node.clone()))
347                     .find(|(p, _)| ImportGroup::new(p) > group);
348                 match post_group {
349                     Some((_, node)) => {
350                         (InsertPosition::Before(node.into()), AddBlankLine::AfterTwice)
351                     }
352                     // there is no such group, so append after the last one
353                     None => match last {
354                         Some(node) => {
355                             (InsertPosition::After(node.into()), AddBlankLine::BeforeTwice)
356                         }
357                         // there are no imports in this file at all
358                         None => scope.insert_pos_after_inner_attribute(),
359                     },
360                 }
361             }
362         },
363     }
364 }
365
366 #[cfg(test)]
367 mod tests {
368     use super::*;
369
370     use test_utils::assert_eq_text;
371
372     #[test]
373     fn insert_start() {
374         check_none(
375             "std::bar::AA",
376             r"
377 use std::bar::B;
378 use std::bar::D;
379 use std::bar::F;
380 use std::bar::G;",
381             r"
382 use std::bar::AA;
383 use std::bar::B;
384 use std::bar::D;
385 use std::bar::F;
386 use std::bar::G;",
387         )
388     }
389
390     #[test]
391     fn insert_middle() {
392         check_none(
393             "std::bar::EE",
394             r"
395 use std::bar::A;
396 use std::bar::D;
397 use std::bar::F;
398 use std::bar::G;",
399             r"
400 use std::bar::A;
401 use std::bar::D;
402 use std::bar::EE;
403 use std::bar::F;
404 use std::bar::G;",
405         )
406     }
407
408     #[test]
409     fn insert_end() {
410         check_none(
411             "std::bar::ZZ",
412             r"
413 use std::bar::A;
414 use std::bar::D;
415 use std::bar::F;
416 use std::bar::G;",
417             r"
418 use std::bar::A;
419 use std::bar::D;
420 use std::bar::F;
421 use std::bar::G;
422 use std::bar::ZZ;",
423         )
424     }
425
426     #[test]
427     fn insert_middle_nested() {
428         check_none(
429             "std::bar::EE",
430             r"
431 use std::bar::A;
432 use std::bar::{D, Z}; // example of weird imports due to user
433 use std::bar::F;
434 use std::bar::G;",
435             r"
436 use std::bar::A;
437 use std::bar::EE;
438 use std::bar::{D, Z}; // example of weird imports due to user
439 use std::bar::F;
440 use std::bar::G;",
441         )
442     }
443
444     #[test]
445     fn insert_middle_groups() {
446         check_none(
447             "foo::bar::GG",
448             r"
449 use std::bar::A;
450 use std::bar::D;
451
452 use foo::bar::F;
453 use foo::bar::H;",
454             r"
455 use std::bar::A;
456 use std::bar::D;
457
458 use foo::bar::F;
459 use foo::bar::GG;
460 use foo::bar::H;",
461         )
462     }
463
464     #[test]
465     fn insert_first_matching_group() {
466         check_none(
467             "foo::bar::GG",
468             r"
469 use foo::bar::A;
470 use foo::bar::D;
471
472 use std;
473
474 use foo::bar::F;
475 use foo::bar::H;",
476             r"
477 use foo::bar::A;
478 use foo::bar::D;
479 use foo::bar::GG;
480
481 use std;
482
483 use foo::bar::F;
484 use foo::bar::H;",
485         )
486     }
487
488     #[test]
489     fn insert_missing_group_std() {
490         check_none(
491             "std::fmt",
492             r"
493 use foo::bar::A;
494 use foo::bar::D;",
495             r"
496 use std::fmt;
497
498 use foo::bar::A;
499 use foo::bar::D;",
500         )
501     }
502
503     #[test]
504     fn insert_missing_group_self() {
505         check_none(
506             "self::fmt",
507             r"
508 use foo::bar::A;
509 use foo::bar::D;",
510             r"
511 use foo::bar::A;
512 use foo::bar::D;
513
514 use self::fmt;",
515         )
516     }
517
518     #[test]
519     fn insert_no_imports() {
520         check_full(
521             "foo::bar",
522             "fn main() {}",
523             r"use foo::bar;
524
525 fn main() {}",
526         )
527     }
528
529     #[test]
530     fn insert_empty_file() {
531         // empty files will get two trailing newlines
532         // this is due to the test case insert_no_imports above
533         check_full(
534             "foo::bar",
535             "",
536             r"use foo::bar;
537
538 ",
539         )
540     }
541
542     #[test]
543     fn insert_after_inner_attr() {
544         check_full(
545             "foo::bar",
546             r"#![allow(unused_imports)]",
547             r"#![allow(unused_imports)]
548
549 use foo::bar;",
550         )
551     }
552
553     #[test]
554     fn insert_after_inner_attr2() {
555         check_full(
556             "foo::bar",
557             r"#![allow(unused_imports)]
558
559 fn main() {}",
560             r"#![allow(unused_imports)]
561
562 use foo::bar;
563
564 fn main() {}",
565         )
566     }
567
568     #[test]
569     fn merge_groups() {
570         check_last("std::io", r"use std::fmt;", r"use std::{fmt, io};")
571     }
572
573     #[test]
574     fn merge_groups_last() {
575         check_last(
576             "std::io",
577             r"use std::fmt::{Result, Display};",
578             r"use std::fmt::{Result, Display};
579 use std::io;",
580         )
581     }
582
583     #[test]
584     fn merge_groups_full() {
585         check_full(
586             "std::io",
587             r"use std::fmt::{Result, Display};",
588             r"use std::{fmt::{Result, Display}, io};",
589         )
590     }
591
592     #[test]
593     fn merge_groups_long_full() {
594         check_full(
595             "std::foo::bar::Baz",
596             r"use std::foo::bar::Qux;",
597             r"use std::foo::bar::{Qux, Baz};",
598         )
599     }
600
601     #[test]
602     fn merge_groups_long_last() {
603         check_last(
604             "std::foo::bar::Baz",
605             r"use std::foo::bar::Qux;",
606             r"use std::foo::bar::{Qux, Baz};",
607         )
608     }
609
610     #[test]
611     fn merge_groups_long_full_list() {
612         check_full(
613             "std::foo::bar::Baz",
614             r"use std::foo::bar::{Qux, Quux};",
615             r"use std::foo::bar::{Qux, Quux, Baz};",
616         )
617     }
618
619     #[test]
620     fn merge_groups_long_last_list() {
621         check_last(
622             "std::foo::bar::Baz",
623             r"use std::foo::bar::{Qux, Quux};",
624             r"use std::foo::bar::{Qux, Quux, Baz};",
625         )
626     }
627
628     #[test]
629     fn merge_groups_long_full_nested() {
630         check_full(
631             "std::foo::bar::Baz",
632             r"use std::foo::bar::{Qux, quux::{Fez, Fizz}};",
633             r"use std::foo::bar::{Qux, quux::{Fez, Fizz}, Baz};",
634         )
635     }
636
637     #[test]
638     fn merge_groups_long_last_nested() {
639         check_last(
640             "std::foo::bar::Baz",
641             r"use std::foo::bar::{Qux, quux::{Fez, Fizz}};",
642             r"use std::foo::bar::Baz;
643 use std::foo::bar::{Qux, quux::{Fez, Fizz}};",
644         )
645     }
646
647     #[test]
648     fn merge_groups_skip_pub() {
649         check_full(
650             "std::io",
651             r"pub use std::fmt::{Result, Display};",
652             r"pub use std::fmt::{Result, Display};
653 use std::io;",
654         )
655     }
656
657     #[test]
658     fn merge_groups_skip_pub_crate() {
659         check_full(
660             "std::io",
661             r"pub(crate) use std::fmt::{Result, Display};",
662             r"pub(crate) use std::fmt::{Result, Display};
663 use std::io;",
664         )
665     }
666
667     #[test]
668     #[ignore] // FIXME: Support this
669     fn split_out_merge() {
670         check_last(
671             "std::fmt::Result",
672             r"use std::{fmt, io};",
673             r"use std::{self, fmt::Result};
674 use std::io;",
675         )
676     }
677
678     #[test]
679     fn merge_groups_self() {
680         check_full("std::fmt::Debug", r"use std::fmt;", r"use std::fmt::{self, Debug};")
681     }
682
683     #[test]
684     fn merge_self_glob() {
685         check_full(
686             "token::TokenKind",
687             r"use token::TokenKind::*;",
688             r"use token::TokenKind::{self::*, self};",
689         )
690     }
691
692     #[test]
693     fn merge_last_too_long() {
694         mark::check!(test_last_merge_too_long);
695         check_last(
696             "foo::bar",
697             r"use foo::bar::baz::Qux;",
698             r"use foo::bar;
699 use foo::bar::baz::Qux;",
700         );
701     }
702
703     #[test]
704     fn insert_short_before_long() {
705         check_none(
706             "foo::bar",
707             r"use foo::bar::baz::Qux;",
708             r"use foo::bar;
709 use foo::bar::baz::Qux;",
710         );
711     }
712
713     fn check(
714         path: &str,
715         ra_fixture_before: &str,
716         ra_fixture_after: &str,
717         mb: Option<MergeBehaviour>,
718     ) {
719         let file = super::ImportScope::from(
720             ast::SourceFile::parse(ra_fixture_before).tree().syntax().clone(),
721         )
722         .unwrap();
723         let path = ast::SourceFile::parse(&format!("use {};", path))
724             .tree()
725             .syntax()
726             .descendants()
727             .find_map(ast::Path::cast)
728             .unwrap();
729
730         let result = insert_use(&file, path, mb).to_string();
731         assert_eq_text!(&result, ra_fixture_after);
732     }
733
734     fn check_full(path: &str, ra_fixture_before: &str, ra_fixture_after: &str) {
735         check(path, ra_fixture_before, ra_fixture_after, Some(MergeBehaviour::Full))
736     }
737
738     fn check_last(path: &str, ra_fixture_before: &str, ra_fixture_after: &str) {
739         check(path, ra_fixture_before, ra_fixture_after, Some(MergeBehaviour::Last))
740     }
741
742     fn check_none(path: &str, ra_fixture_before: &str, ra_fixture_after: &str) {
743         check(path, ra_fixture_before, ra_fixture_after, None)
744     }
745 }