]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs
Auto merge of #101969 - reez12g:issue-101306, r=reez12g
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / extract_type_alias.rs
1 use either::Either;
2 use ide_db::syntax_helpers::node_ext::walk_ty;
3 use itertools::Itertools;
4 use syntax::{
5     ast::{self, edit::IndentLevel, AstNode, HasGenericParams, HasName},
6     match_ast,
7 };
8
9 use crate::{AssistContext, AssistId, AssistKind, Assists};
10
11 // Assist: extract_type_alias
12 //
13 // Extracts the selected type as a type alias.
14 //
15 // ```
16 // struct S {
17 //     field: $0(u8, u8, u8)$0,
18 // }
19 // ```
20 // ->
21 // ```
22 // type $0Type = (u8, u8, u8);
23 //
24 // struct S {
25 //     field: Type,
26 // }
27 // ```
28 pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
29     if ctx.has_empty_selection() {
30         return None;
31     }
32
33     let ty = ctx.find_node_at_range::<ast::Type>()?;
34     let item = ty.syntax().ancestors().find_map(ast::Item::cast)?;
35     let assoc_owner = item.syntax().ancestors().nth(2).and_then(|it| {
36         match_ast! {
37             match it {
38                 ast::Trait(tr) => Some(Either::Left(tr)),
39                 ast::Impl(impl_) => Some(Either::Right(impl_)),
40                 _ => None,
41             }
42         }
43     });
44     let node = assoc_owner.as_ref().map_or_else(
45         || item.syntax(),
46         |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax),
47     );
48     let insert_pos = node.text_range().start();
49     let target = ty.syntax().text_range();
50
51     acc.add(
52         AssistId("extract_type_alias", AssistKind::RefactorExtract),
53         "Extract type as type alias",
54         target,
55         |builder| {
56             let mut known_generics = match item.generic_param_list() {
57                 Some(it) => it.generic_params().collect(),
58                 None => Vec::new(),
59             };
60             if let Some(it) = assoc_owner.as_ref().and_then(|it| match it {
61                 Either::Left(it) => it.generic_param_list(),
62                 Either::Right(it) => it.generic_param_list(),
63             }) {
64                 known_generics.extend(it.generic_params());
65             }
66             let generics = collect_used_generics(&ty, &known_generics);
67
68             let replacement = if !generics.is_empty() {
69                 format!(
70                     "Type<{}>",
71                     generics.iter().format_with(", ", |generic, f| {
72                         match generic {
73                             ast::GenericParam::ConstParam(cp) => f(&cp.name().unwrap()),
74                             ast::GenericParam::LifetimeParam(lp) => f(&lp.lifetime().unwrap()),
75                             ast::GenericParam::TypeParam(tp) => f(&tp.name().unwrap()),
76                         }
77                     })
78                 )
79             } else {
80                 String::from("Type")
81             };
82             builder.replace(target, replacement);
83
84             let indent = IndentLevel::from_node(node);
85             let generics = if !generics.is_empty() {
86                 format!("<{}>", generics.iter().format(", "))
87             } else {
88                 String::new()
89             };
90             match ctx.config.snippet_cap {
91                 Some(cap) => {
92                     builder.insert_snippet(
93                         cap,
94                         insert_pos,
95                         format!("type $0Type{} = {};\n\n{}", generics, ty, indent),
96                     );
97                 }
98                 None => {
99                     builder.insert(
100                         insert_pos,
101                         format!("type Type{} = {};\n\n{}", generics, ty, indent),
102                     );
103                 }
104             }
105         },
106     )
107 }
108
109 fn collect_used_generics<'gp>(
110     ty: &ast::Type,
111     known_generics: &'gp [ast::GenericParam],
112 ) -> Vec<&'gp ast::GenericParam> {
113     // can't use a closure -> closure here cause lifetime inference fails for that
114     fn find_lifetime(text: &str) -> impl Fn(&&ast::GenericParam) -> bool + '_ {
115         move |gp: &&ast::GenericParam| match gp {
116             ast::GenericParam::LifetimeParam(lp) => {
117                 lp.lifetime().map_or(false, |lt| lt.text() == text)
118             }
119             _ => false,
120         }
121     }
122
123     let mut generics = Vec::new();
124     walk_ty(ty, &mut |ty| match ty {
125         ast::Type::PathType(ty) => {
126             if let Some(path) = ty.path() {
127                 if let Some(name_ref) = path.as_single_name_ref() {
128                     if let Some(param) = known_generics.iter().find(|gp| {
129                         match gp {
130                             ast::GenericParam::ConstParam(cp) => cp.name(),
131                             ast::GenericParam::TypeParam(tp) => tp.name(),
132                             _ => None,
133                         }
134                         .map_or(false, |n| n.text() == name_ref.text())
135                     }) {
136                         generics.push(param);
137                     }
138                 }
139                 generics.extend(
140                     path.segments()
141                         .filter_map(|seg| seg.generic_arg_list())
142                         .flat_map(|it| it.generic_args())
143                         .filter_map(|it| match it {
144                             ast::GenericArg::LifetimeArg(lt) => {
145                                 let lt = lt.lifetime()?;
146                                 known_generics.iter().find(find_lifetime(&lt.text()))
147                             }
148                             _ => None,
149                         }),
150                 );
151             }
152         }
153         ast::Type::ImplTraitType(impl_ty) => {
154             if let Some(it) = impl_ty.type_bound_list() {
155                 generics.extend(
156                     it.bounds()
157                         .filter_map(|it| it.lifetime())
158                         .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
159                 );
160             }
161         }
162         ast::Type::DynTraitType(dyn_ty) => {
163             if let Some(it) = dyn_ty.type_bound_list() {
164                 generics.extend(
165                     it.bounds()
166                         .filter_map(|it| it.lifetime())
167                         .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
168                 );
169             }
170         }
171         ast::Type::RefType(ref_) => generics.extend(
172             ref_.lifetime().and_then(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
173         ),
174         ast::Type::ArrayType(ar) => {
175             if let Some(expr) = ar.expr() {
176                 if let ast::Expr::PathExpr(p) = expr {
177                     if let Some(path) = p.path() {
178                         if let Some(name_ref) = path.as_single_name_ref() {
179                             if let Some(param) = known_generics.iter().find(|gp| {
180                                 if let ast::GenericParam::ConstParam(cp) = gp {
181                                     cp.name().map_or(false, |n| n.text() == name_ref.text())
182                                 } else {
183                                     false
184                                 }
185                             }) {
186                                 generics.push(param);
187                             }
188                         }
189                     }
190                 }
191             }
192         }
193         _ => (),
194     });
195     // stable resort to lifetime, type, const
196     generics.sort_by_key(|gp| match gp {
197         ast::GenericParam::ConstParam(_) => 2,
198         ast::GenericParam::LifetimeParam(_) => 0,
199         ast::GenericParam::TypeParam(_) => 1,
200     });
201     generics
202 }
203
204 #[cfg(test)]
205 mod tests {
206     use crate::tests::{check_assist, check_assist_not_applicable};
207
208     use super::*;
209
210     #[test]
211     fn test_not_applicable_without_selection() {
212         check_assist_not_applicable(
213             extract_type_alias,
214             r"
215 struct S {
216     field: $0(u8, u8, u8),
217 }
218             ",
219         );
220     }
221
222     #[test]
223     fn test_simple_types() {
224         check_assist(
225             extract_type_alias,
226             r"
227 struct S {
228     field: $0u8$0,
229 }
230             ",
231             r#"
232 type $0Type = u8;
233
234 struct S {
235     field: Type,
236 }
237             "#,
238         );
239     }
240
241     #[test]
242     fn test_generic_type_arg() {
243         check_assist(
244             extract_type_alias,
245             r"
246 fn generic<T>() {}
247
248 fn f() {
249     generic::<$0()$0>();
250 }
251             ",
252             r#"
253 fn generic<T>() {}
254
255 type $0Type = ();
256
257 fn f() {
258     generic::<Type>();
259 }
260             "#,
261         );
262     }
263
264     #[test]
265     fn test_inner_type_arg() {
266         check_assist(
267             extract_type_alias,
268             r"
269 struct Vec<T> {}
270 struct S {
271     v: Vec<Vec<$0Vec<u8>$0>>,
272 }
273             ",
274             r#"
275 struct Vec<T> {}
276 type $0Type = Vec<u8>;
277
278 struct S {
279     v: Vec<Vec<Type>>,
280 }
281             "#,
282         );
283     }
284
285     #[test]
286     fn test_extract_inner_type() {
287         check_assist(
288             extract_type_alias,
289             r"
290 struct S {
291     field: ($0u8$0,),
292 }
293             ",
294             r#"
295 type $0Type = u8;
296
297 struct S {
298     field: (Type,),
299 }
300             "#,
301         );
302     }
303
304     #[test]
305     fn extract_from_impl_or_trait() {
306         // When invoked in an impl/trait, extracted type alias should be placed next to the
307         // impl/trait, not inside.
308         check_assist(
309             extract_type_alias,
310             r#"
311 impl S {
312     fn f() -> $0(u8, u8)$0 {}
313 }
314             "#,
315             r#"
316 type $0Type = (u8, u8);
317
318 impl S {
319     fn f() -> Type {}
320 }
321             "#,
322         );
323         check_assist(
324             extract_type_alias,
325             r#"
326 trait Tr {
327     fn f() -> $0(u8, u8)$0 {}
328 }
329             "#,
330             r#"
331 type $0Type = (u8, u8);
332
333 trait Tr {
334     fn f() -> Type {}
335 }
336             "#,
337         );
338     }
339
340     #[test]
341     fn indentation() {
342         check_assist(
343             extract_type_alias,
344             r#"
345 mod m {
346     fn f() -> $0u8$0 {}
347 }
348             "#,
349             r#"
350 mod m {
351     type $0Type = u8;
352
353     fn f() -> Type {}
354 }
355             "#,
356         );
357     }
358
359     #[test]
360     fn generics() {
361         check_assist(
362             extract_type_alias,
363             r#"
364 struct Struct<const C: usize>;
365 impl<'outer, Outer, const OUTER: usize> () {
366     fn func<'inner, Inner, const INNER: usize>(_: $0&(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ())$0) {}
367 }
368 "#,
369             r#"
370 struct Struct<const C: usize>;
371 type $0Type<'inner, 'outer, Outer, Inner, const INNER: usize, const OUTER: usize> = &(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ());
372
373 impl<'outer, Outer, const OUTER: usize> () {
374     fn func<'inner, Inner, const INNER: usize>(_: Type<'inner, 'outer, Outer, Inner, INNER, OUTER>) {}
375 }
376 "#,
377         );
378     }
379
380     #[test]
381     fn issue_11197() {
382         check_assist(
383             extract_type_alias,
384             r#"
385 struct Foo<T, const N: usize>
386 where
387     [T; N]: Sized,
388 {
389     arr: $0[T; N]$0,
390 }
391             "#,
392             r#"
393 type $0Type<T, const N: usize> = [T; N];
394
395 struct Foo<T, const N: usize>
396 where
397     [T; N]: Sized,
398 {
399     arr: Type<T, N>,
400 }
401             "#,
402         );
403     }
404 }