]> git.lizzy.rs Git - rust.git/commitdiff
extract_type_alias extracts generics correctly
authorLukas Wirth <lukastw97@gmail.com>
Thu, 5 Aug 2021 00:54:06 +0000 (02:54 +0200)
committerLukas Wirth <lukastw97@gmail.com>
Thu, 5 Aug 2021 00:54:06 +0000 (02:54 +0200)
crates/ide_assists/src/handlers/extract_type_alias.rs
crates/syntax/src/ast/node_ext.rs

index eac8857c6788b9421626a54c10f857b828782315..4913ac1e08ea8f413f78d1f4c5e26842efdf6922 100644 (file)
@@ -1,5 +1,7 @@
+use either::Either;
+use itertools::Itertools;
 use syntax::{
-    ast::{self, edit::IndentLevel, AstNode},
+    ast::{self, edit::IndentLevel, AstNode, GenericParamsOwner, NameOwner},
     match_ast,
 };
 
@@ -27,41 +29,158 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext) -> Opti
         return None;
     }
 
-    let node = ctx.find_node_at_range::<ast::Type>()?;
-    let item = ctx.find_node_at_offset::<ast::Item>()?;
-    let insert = match_ast! {
-        match (item.syntax().parent()?) {
-            ast::AssocItemList(it) => it.syntax().parent()?,
-            _ => item.syntax().clone(),
+    let ty = ctx.find_node_at_range::<ast::Type>()?;
+    let item = ty.syntax().ancestors().find_map(ast::Item::cast)?;
+    let assoc_owner = item.syntax().ancestors().nth(2).and_then(|it| {
+        match_ast! {
+            match it {
+                ast::Trait(tr) => Some(Either::Left(tr)),
+                ast::Impl(impl_) => Some(Either::Right(impl_)),
+                _ => None,
+            }
         }
-    };
-    let indent = IndentLevel::from_node(&insert);
-    let insert = insert.text_range().start();
-    let target = node.syntax().text_range();
+    });
+    let node = assoc_owner.as_ref().map_or_else(
+        || item.syntax(),
+        |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax),
+    );
+    let insert_pos = node.text_range().start();
+    let target = ty.syntax().text_range();
 
     acc.add(
         AssistId("extract_type_alias", AssistKind::RefactorExtract),
         "Extract type as type alias",
         target,
         |builder| {
-            builder.edit_file(ctx.frange.file_id);
-            builder.replace(target, "Type");
+            let mut known_generics = match item.generic_param_list() {
+                Some(it) => it.generic_params().collect(),
+                None => Vec::new(),
+            };
+            if let Some(it) = assoc_owner.as_ref().and_then(|it| match it {
+                Either::Left(it) => it.generic_param_list(),
+                Either::Right(it) => it.generic_param_list(),
+            }) {
+                known_generics.extend(it.generic_params());
+            }
+            let generics = collect_used_generics(&ty, &known_generics);
+
+            let replacement = if !generics.is_empty() {
+                format!(
+                    "Type<{}>",
+                    generics.iter().format_with(", ", |generic, f| {
+                        match generic {
+                            ast::GenericParam::ConstParam(cp) => f(&cp.name().unwrap()),
+                            ast::GenericParam::LifetimeParam(lp) => f(&lp.lifetime().unwrap()),
+                            ast::GenericParam::TypeParam(tp) => f(&tp.name().unwrap()),
+                        }
+                    })
+                )
+            } else {
+                String::from("Type")
+            };
+            builder.replace(target, replacement);
+
+            let indent = IndentLevel::from_node(node);
+            let generics = if !generics.is_empty() {
+                format!("<{}>", generics.iter().format(", "))
+            } else {
+                String::new()
+            };
             match ctx.config.snippet_cap {
                 Some(cap) => {
                     builder.insert_snippet(
                         cap,
-                        insert,
-                        format!("type $0Type = {};\n\n{}", node, indent),
+                        insert_pos,
+                        format!("type $0Type{} = {};\n\n{}", generics, ty, indent),
                     );
                 }
                 None => {
-                    builder.insert(insert, format!("type Type = {};\n\n{}", node, indent));
+                    builder.insert(
+                        insert_pos,
+                        format!("type Type{} = {};\n\n{}", generics, ty, indent),
+                    );
                 }
             }
         },
     )
 }
 
+fn collect_used_generics<'gp>(
+    ty: &ast::Type,
+    known_generics: &'gp [ast::GenericParam],
+) -> Vec<&'gp ast::GenericParam> {
+    // can't use a closure -> closure here cause lifetime inference fails for that
+    fn find_lifetime(text: &str) -> impl Fn(&&ast::GenericParam) -> bool + '_ {
+        move |gp: &&ast::GenericParam| match gp {
+            ast::GenericParam::LifetimeParam(lp) => {
+                lp.lifetime().map_or(false, |lt| lt.text() == text)
+            }
+            _ => false,
+        }
+    }
+
+    let mut generics = Vec::new();
+    ty.walk(&mut |ty| match ty {
+        ast::Type::PathType(ty) => {
+            if let Some(path) = ty.path() {
+                if let Some(name_ref) = path.as_single_name_ref() {
+                    if let Some(param) = known_generics.iter().find(|gp| {
+                        match gp {
+                            ast::GenericParam::ConstParam(cp) => cp.name(),
+                            ast::GenericParam::TypeParam(tp) => tp.name(),
+                            _ => None,
+                        }
+                        .map_or(false, |n| n.text() == name_ref.text())
+                    }) {
+                        generics.push(param);
+                    }
+                }
+                generics.extend(
+                    path.segments()
+                        .filter_map(|seg| seg.generic_arg_list())
+                        .flat_map(|it| it.generic_args())
+                        .filter_map(|it| match it {
+                            ast::GenericArg::LifetimeArg(lt) => {
+                                let lt = lt.lifetime()?;
+                                known_generics.iter().find(find_lifetime(&lt.text()))
+                            }
+                            _ => None,
+                        }),
+                );
+            }
+        }
+        ast::Type::ImplTraitType(impl_ty) => {
+            if let Some(it) = impl_ty.type_bound_list() {
+                generics.extend(
+                    it.bounds()
+                        .filter_map(|it| it.lifetime())
+                        .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
+                );
+            }
+        }
+        ast::Type::DynTraitType(dyn_ty) => {
+            if let Some(it) = dyn_ty.type_bound_list() {
+                generics.extend(
+                    it.bounds()
+                        .filter_map(|it| it.lifetime())
+                        .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
+                );
+            }
+        }
+        ast::Type::RefType(ref_) => generics.extend(
+            ref_.lifetime().and_then(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
+        ),
+        _ => (),
+    });
+    // stable resort to lifetime, type, const
+    generics.sort_by_key(|gp| match gp {
+        ast::GenericParam::ConstParam(_) => 2,
+        ast::GenericParam::LifetimeParam(_) => 0,
+        ast::GenericParam::TypeParam(_) => 1,
+    });
+    generics
+}
+
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
@@ -216,4 +335,25 @@ fn f() -> Type {}
             "#,
         );
     }
+
+    #[test]
+    fn generics() {
+        check_assist(
+            extract_type_alias,
+            r#"
+struct Struct<const C: usize>;
+impl<'outer, Outer, const OUTER: usize> () {
+    fn func<'inner, Inner, const INNER: usize>(_: $0&(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ())$0) {}
+}
+"#,
+            r#"
+struct Struct<const C: usize>;
+type $0Type<'inner, 'outer, Outer, Inner, const INNER: usize, const OUTER: usize> = &(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ());
+
+impl<'outer, Outer, const OUTER: usize> () {
+    fn func<'inner, Inner, const INNER: usize>(_: Type<'inner, 'outer, Outer, Inner, INNER, OUTER>) {}
+}
+"#,
+        );
+    }
 }
index 0a540d9cfb3098a191e704d5d17e74523af341a5..3030c881209c7c2edb13a61e6fcec385a5f31f3e 100644 (file)
@@ -8,7 +8,10 @@
 use rowan::{GreenNodeData, GreenTokenData, WalkEvent};
 
 use crate::{
-    ast::{self, support, AstChildren, AstNode, AstToken, AttrsOwner, NameOwner, SyntaxNode},
+    ast::{
+        self, support, AstChildren, AstNode, AstToken, AttrsOwner, GenericParamsOwner, NameOwner,
+        SyntaxNode,
+    },
     NodeOrToken, SmolStr, SyntaxElement, SyntaxToken, TokenText, T,
 };
 
@@ -593,6 +596,21 @@ pub fn kind(&self) -> StructKind {
     }
 }
 
+impl ast::Item {
+    pub fn generic_param_list(&self) -> Option<ast::GenericParamList> {
+        match self {
+            ast::Item::Enum(it) => it.generic_param_list(),
+            ast::Item::Fn(it) => it.generic_param_list(),
+            ast::Item::Impl(it) => it.generic_param_list(),
+            ast::Item::Struct(it) => it.generic_param_list(),
+            ast::Item::Trait(it) => it.generic_param_list(),
+            ast::Item::TypeAlias(it) => it.generic_param_list(),
+            ast::Item::Union(it) => it.generic_param_list(),
+            _ => None,
+        }
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum FieldKind {
     Name(ast::NameRef),