]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_type_alias.rs
Merge #9790
[rust.git] / crates / ide_assists / src / handlers / extract_type_alias.rs
1 use either::Either;
2 use itertools::Itertools;
3 use syntax::{
4     ast::{self, edit::IndentLevel, AstNode, GenericParamsOwner, NameOwner},
5     match_ast,
6 };
7
8 use crate::{AssistContext, AssistId, AssistKind, Assists};
9
10 // Assist: extract_type_alias
11 //
12 // Extracts the selected type as a type alias.
13 //
14 // ```
15 // struct S {
16 //     field: $0(u8, u8, u8)$0,
17 // }
18 // ```
19 // ->
20 // ```
21 // type $0Type = (u8, u8, u8);
22 //
23 // struct S {
24 //     field: Type,
25 // }
26 // ```
27 pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
28     if ctx.frange.range.is_empty() {
29         return None;
30     }
31
32     let ty = ctx.find_node_at_range::<ast::Type>()?;
33     let item = ty.syntax().ancestors().find_map(ast::Item::cast)?;
34     let assoc_owner = item.syntax().ancestors().nth(2).and_then(|it| {
35         match_ast! {
36             match it {
37                 ast::Trait(tr) => Some(Either::Left(tr)),
38                 ast::Impl(impl_) => Some(Either::Right(impl_)),
39                 _ => None,
40             }
41         }
42     });
43     let node = assoc_owner.as_ref().map_or_else(
44         || item.syntax(),
45         |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax),
46     );
47     let insert_pos = node.text_range().start();
48     let target = ty.syntax().text_range();
49
50     acc.add(
51         AssistId("extract_type_alias", AssistKind::RefactorExtract),
52         "Extract type as type alias",
53         target,
54         |builder| {
55             let mut known_generics = match item.generic_param_list() {
56                 Some(it) => it.generic_params().collect(),
57                 None => Vec::new(),
58             };
59             if let Some(it) = assoc_owner.as_ref().and_then(|it| match it {
60                 Either::Left(it) => it.generic_param_list(),
61                 Either::Right(it) => it.generic_param_list(),
62             }) {
63                 known_generics.extend(it.generic_params());
64             }
65             let generics = collect_used_generics(&ty, &known_generics);
66
67             let replacement = if !generics.is_empty() {
68                 format!(
69                     "Type<{}>",
70                     generics.iter().format_with(", ", |generic, f| {
71                         match generic {
72                             ast::GenericParam::ConstParam(cp) => f(&cp.name().unwrap()),
73                             ast::GenericParam::LifetimeParam(lp) => f(&lp.lifetime().unwrap()),
74                             ast::GenericParam::TypeParam(tp) => f(&tp.name().unwrap()),
75                         }
76                     })
77                 )
78             } else {
79                 String::from("Type")
80             };
81             builder.replace(target, replacement);
82
83             let indent = IndentLevel::from_node(node);
84             let generics = if !generics.is_empty() {
85                 format!("<{}>", generics.iter().format(", "))
86             } else {
87                 String::new()
88             };
89             match ctx.config.snippet_cap {
90                 Some(cap) => {
91                     builder.insert_snippet(
92                         cap,
93                         insert_pos,
94                         format!("type $0Type{} = {};\n\n{}", generics, ty, indent),
95                     );
96                 }
97                 None => {
98                     builder.insert(
99                         insert_pos,
100                         format!("type Type{} = {};\n\n{}", generics, ty, indent),
101                     );
102                 }
103             }
104         },
105     )
106 }
107
108 fn collect_used_generics<'gp>(
109     ty: &ast::Type,
110     known_generics: &'gp [ast::GenericParam],
111 ) -> Vec<&'gp ast::GenericParam> {
112     // can't use a closure -> closure here cause lifetime inference fails for that
113     fn find_lifetime(text: &str) -> impl Fn(&&ast::GenericParam) -> bool + '_ {
114         move |gp: &&ast::GenericParam| match gp {
115             ast::GenericParam::LifetimeParam(lp) => {
116                 lp.lifetime().map_or(false, |lt| lt.text() == text)
117             }
118             _ => false,
119         }
120     }
121
122     let mut generics = Vec::new();
123     ty.walk(&mut |ty| match ty {
124         ast::Type::PathType(ty) => {
125             if let Some(path) = ty.path() {
126                 if let Some(name_ref) = path.as_single_name_ref() {
127                     if let Some(param) = known_generics.iter().find(|gp| {
128                         match gp {
129                             ast::GenericParam::ConstParam(cp) => cp.name(),
130                             ast::GenericParam::TypeParam(tp) => tp.name(),
131                             _ => None,
132                         }
133                         .map_or(false, |n| n.text() == name_ref.text())
134                     }) {
135                         generics.push(param);
136                     }
137                 }
138                 generics.extend(
139                     path.segments()
140                         .filter_map(|seg| seg.generic_arg_list())
141                         .flat_map(|it| it.generic_args())
142                         .filter_map(|it| match it {
143                             ast::GenericArg::LifetimeArg(lt) => {
144                                 let lt = lt.lifetime()?;
145                                 known_generics.iter().find(find_lifetime(&lt.text()))
146                             }
147                             _ => None,
148                         }),
149                 );
150             }
151         }
152         ast::Type::ImplTraitType(impl_ty) => {
153             if let Some(it) = impl_ty.type_bound_list() {
154                 generics.extend(
155                     it.bounds()
156                         .filter_map(|it| it.lifetime())
157                         .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
158                 );
159             }
160         }
161         ast::Type::DynTraitType(dyn_ty) => {
162             if let Some(it) = dyn_ty.type_bound_list() {
163                 generics.extend(
164                     it.bounds()
165                         .filter_map(|it| it.lifetime())
166                         .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
167                 );
168             }
169         }
170         ast::Type::RefType(ref_) => generics.extend(
171             ref_.lifetime().and_then(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
172         ),
173         _ => (),
174     });
175     // stable resort to lifetime, type, const
176     generics.sort_by_key(|gp| match gp {
177         ast::GenericParam::ConstParam(_) => 2,
178         ast::GenericParam::LifetimeParam(_) => 0,
179         ast::GenericParam::TypeParam(_) => 1,
180     });
181     generics
182 }
183
184 #[cfg(test)]
185 mod tests {
186     use crate::tests::{check_assist, check_assist_not_applicable};
187
188     use super::*;
189
190     #[test]
191     fn test_not_applicable_without_selection() {
192         check_assist_not_applicable(
193             extract_type_alias,
194             r"
195 struct S {
196     field: $0(u8, u8, u8),
197 }
198             ",
199         );
200     }
201
202     #[test]
203     fn test_simple_types() {
204         check_assist(
205             extract_type_alias,
206             r"
207 struct S {
208     field: $0u8$0,
209 }
210             ",
211             r#"
212 type $0Type = u8;
213
214 struct S {
215     field: Type,
216 }
217             "#,
218         );
219     }
220
221     #[test]
222     fn test_generic_type_arg() {
223         check_assist(
224             extract_type_alias,
225             r"
226 fn generic<T>() {}
227
228 fn f() {
229     generic::<$0()$0>();
230 }
231             ",
232             r#"
233 fn generic<T>() {}
234
235 type $0Type = ();
236
237 fn f() {
238     generic::<Type>();
239 }
240             "#,
241         );
242     }
243
244     #[test]
245     fn test_inner_type_arg() {
246         check_assist(
247             extract_type_alias,
248             r"
249 struct Vec<T> {}
250 struct S {
251     v: Vec<Vec<$0Vec<u8>$0>>,
252 }
253             ",
254             r#"
255 struct Vec<T> {}
256 type $0Type = Vec<u8>;
257
258 struct S {
259     v: Vec<Vec<Type>>,
260 }
261             "#,
262         );
263     }
264
265     #[test]
266     fn test_extract_inner_type() {
267         check_assist(
268             extract_type_alias,
269             r"
270 struct S {
271     field: ($0u8$0,),
272 }
273             ",
274             r#"
275 type $0Type = u8;
276
277 struct S {
278     field: (Type,),
279 }
280             "#,
281         );
282     }
283
284     #[test]
285     fn extract_from_impl_or_trait() {
286         // When invoked in an impl/trait, extracted type alias should be placed next to the
287         // impl/trait, not inside.
288         check_assist(
289             extract_type_alias,
290             r#"
291 impl S {
292     fn f() -> $0(u8, u8)$0 {}
293 }
294             "#,
295             r#"
296 type $0Type = (u8, u8);
297
298 impl S {
299     fn f() -> Type {}
300 }
301             "#,
302         );
303         check_assist(
304             extract_type_alias,
305             r#"
306 trait Tr {
307     fn f() -> $0(u8, u8)$0 {}
308 }
309             "#,
310             r#"
311 type $0Type = (u8, u8);
312
313 trait Tr {
314     fn f() -> Type {}
315 }
316             "#,
317         );
318     }
319
320     #[test]
321     fn indentation() {
322         check_assist(
323             extract_type_alias,
324             r#"
325 mod m {
326     fn f() -> $0u8$0 {}
327 }
328             "#,
329             r#"
330 mod m {
331     type $0Type = u8;
332
333     fn f() -> Type {}
334 }
335             "#,
336         );
337     }
338
339     #[test]
340     fn generics() {
341         check_assist(
342             extract_type_alias,
343             r#"
344 struct Struct<const C: usize>;
345 impl<'outer, Outer, const OUTER: usize> () {
346     fn func<'inner, Inner, const INNER: usize>(_: $0&(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ())$0) {}
347 }
348 "#,
349             r#"
350 struct Struct<const C: usize>;
351 type $0Type<'inner, 'outer, Outer, Inner, const INNER: usize, const OUTER: usize> = &(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ());
352
353 impl<'outer, Outer, const OUTER: usize> () {
354     fn func<'inner, Inner, const INNER: usize>(_: Type<'inner, 'outer, Outer, Inner, INNER, OUTER>) {}
355 }
356 "#,
357         );
358     }
359 }