]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs
Rollup merge of #104117 - crlf0710:update_feature_gate, r=jackh726
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / extract_type_alias.rs
1 use either::Either;
2 use ide_db::syntax_helpers::node_ext::walk_ty;
3 use syntax::{
4     ast::{self, edit::IndentLevel, make, AstNode, HasGenericParams, HasName},
5     match_ast,
6 };
7
8 use crate::{AssistContext, AssistId, AssistKind, Assists};
9
10 // Assist: extract_type_alias
11 //
12 // Extracts the selected type as a type alias.
13 //
14 // ```
15 // struct S {
16 //     field: $0(u8, u8, u8)$0,
17 // }
18 // ```
19 // ->
20 // ```
21 // type $0Type = (u8, u8, u8);
22 //
23 // struct S {
24 //     field: Type,
25 // }
26 // ```
27 pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
28     if ctx.has_empty_selection() {
29         return None;
30     }
31
32     let ty = ctx.find_node_at_range::<ast::Type>()?;
33     let item = ty.syntax().ancestors().find_map(ast::Item::cast)?;
34     let assoc_owner = item.syntax().ancestors().nth(2).and_then(|it| {
35         match_ast! {
36             match it {
37                 ast::Trait(tr) => Some(Either::Left(tr)),
38                 ast::Impl(impl_) => Some(Either::Right(impl_)),
39                 _ => None,
40             }
41         }
42     });
43     let node = assoc_owner.as_ref().map_or_else(
44         || item.syntax(),
45         |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax),
46     );
47     let insert_pos = node.text_range().start();
48     let target = ty.syntax().text_range();
49
50     acc.add(
51         AssistId("extract_type_alias", AssistKind::RefactorExtract),
52         "Extract type as type alias",
53         target,
54         |builder| {
55             let mut known_generics = match item.generic_param_list() {
56                 Some(it) => it.generic_params().collect(),
57                 None => Vec::new(),
58             };
59             if let Some(it) = assoc_owner.as_ref().and_then(|it| match it {
60                 Either::Left(it) => it.generic_param_list(),
61                 Either::Right(it) => it.generic_param_list(),
62             }) {
63                 known_generics.extend(it.generic_params());
64             }
65             let generics = collect_used_generics(&ty, &known_generics);
66             let generic_params =
67                 generics.map(|it| make::generic_param_list(it.into_iter().cloned()));
68
69             let ty_args = generic_params
70                 .as_ref()
71                 .map_or(String::new(), |it| it.to_generic_args().to_string());
72             let replacement = format!("Type{ty_args}");
73             builder.replace(target, replacement);
74
75             let indent = IndentLevel::from_node(node);
76             let generic_params = generic_params.map_or(String::new(), |it| it.to_string());
77             match ctx.config.snippet_cap {
78                 Some(cap) => {
79                     builder.insert_snippet(
80                         cap,
81                         insert_pos,
82                         format!("type $0Type{generic_params} = {ty};\n\n{indent}"),
83                     );
84                 }
85                 None => {
86                     builder.insert(
87                         insert_pos,
88                         format!("type Type{generic_params} = {ty};\n\n{indent}"),
89                     );
90                 }
91             }
92         },
93     )
94 }
95
96 fn collect_used_generics<'gp>(
97     ty: &ast::Type,
98     known_generics: &'gp [ast::GenericParam],
99 ) -> Option<Vec<&'gp ast::GenericParam>> {
100     // can't use a closure -> closure here cause lifetime inference fails for that
101     fn find_lifetime(text: &str) -> impl Fn(&&ast::GenericParam) -> bool + '_ {
102         move |gp: &&ast::GenericParam| match gp {
103             ast::GenericParam::LifetimeParam(lp) => {
104                 lp.lifetime().map_or(false, |lt| lt.text() == text)
105             }
106             _ => false,
107         }
108     }
109
110     let mut generics = Vec::new();
111     walk_ty(ty, &mut |ty| match ty {
112         ast::Type::PathType(ty) => {
113             if let Some(path) = ty.path() {
114                 if let Some(name_ref) = path.as_single_name_ref() {
115                     if let Some(param) = known_generics.iter().find(|gp| {
116                         match gp {
117                             ast::GenericParam::ConstParam(cp) => cp.name(),
118                             ast::GenericParam::TypeParam(tp) => tp.name(),
119                             _ => None,
120                         }
121                         .map_or(false, |n| n.text() == name_ref.text())
122                     }) {
123                         generics.push(param);
124                     }
125                 }
126                 generics.extend(
127                     path.segments()
128                         .filter_map(|seg| seg.generic_arg_list())
129                         .flat_map(|it| it.generic_args())
130                         .filter_map(|it| match it {
131                             ast::GenericArg::LifetimeArg(lt) => {
132                                 let lt = lt.lifetime()?;
133                                 known_generics.iter().find(find_lifetime(&lt.text()))
134                             }
135                             _ => None,
136                         }),
137                 );
138             }
139         }
140         ast::Type::ImplTraitType(impl_ty) => {
141             if let Some(it) = impl_ty.type_bound_list() {
142                 generics.extend(
143                     it.bounds()
144                         .filter_map(|it| it.lifetime())
145                         .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
146                 );
147             }
148         }
149         ast::Type::DynTraitType(dyn_ty) => {
150             if let Some(it) = dyn_ty.type_bound_list() {
151                 generics.extend(
152                     it.bounds()
153                         .filter_map(|it| it.lifetime())
154                         .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
155                 );
156             }
157         }
158         ast::Type::RefType(ref_) => generics.extend(
159             ref_.lifetime().and_then(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
160         ),
161         ast::Type::ArrayType(ar) => {
162             if let Some(expr) = ar.expr() {
163                 if let ast::Expr::PathExpr(p) = expr {
164                     if let Some(path) = p.path() {
165                         if let Some(name_ref) = path.as_single_name_ref() {
166                             if let Some(param) = known_generics.iter().find(|gp| {
167                                 if let ast::GenericParam::ConstParam(cp) = gp {
168                                     cp.name().map_or(false, |n| n.text() == name_ref.text())
169                                 } else {
170                                     false
171                                 }
172                             }) {
173                                 generics.push(param);
174                             }
175                         }
176                     }
177                 }
178             }
179         }
180         _ => (),
181     });
182     // stable resort to lifetime, type, const
183     generics.sort_by_key(|gp| match gp {
184         ast::GenericParam::ConstParam(_) => 2,
185         ast::GenericParam::LifetimeParam(_) => 0,
186         ast::GenericParam::TypeParam(_) => 1,
187     });
188
189     Some(generics).filter(|it| it.len() > 0)
190 }
191
192 #[cfg(test)]
193 mod tests {
194     use crate::tests::{check_assist, check_assist_not_applicable};
195
196     use super::*;
197
198     #[test]
199     fn test_not_applicable_without_selection() {
200         check_assist_not_applicable(
201             extract_type_alias,
202             r"
203 struct S {
204     field: $0(u8, u8, u8),
205 }
206             ",
207         );
208     }
209
210     #[test]
211     fn test_simple_types() {
212         check_assist(
213             extract_type_alias,
214             r"
215 struct S {
216     field: $0u8$0,
217 }
218             ",
219             r#"
220 type $0Type = u8;
221
222 struct S {
223     field: Type,
224 }
225             "#,
226         );
227     }
228
229     #[test]
230     fn test_generic_type_arg() {
231         check_assist(
232             extract_type_alias,
233             r"
234 fn generic<T>() {}
235
236 fn f() {
237     generic::<$0()$0>();
238 }
239             ",
240             r#"
241 fn generic<T>() {}
242
243 type $0Type = ();
244
245 fn f() {
246     generic::<Type>();
247 }
248             "#,
249         );
250     }
251
252     #[test]
253     fn test_inner_type_arg() {
254         check_assist(
255             extract_type_alias,
256             r"
257 struct Vec<T> {}
258 struct S {
259     v: Vec<Vec<$0Vec<u8>$0>>,
260 }
261             ",
262             r#"
263 struct Vec<T> {}
264 type $0Type = Vec<u8>;
265
266 struct S {
267     v: Vec<Vec<Type>>,
268 }
269             "#,
270         );
271     }
272
273     #[test]
274     fn test_extract_inner_type() {
275         check_assist(
276             extract_type_alias,
277             r"
278 struct S {
279     field: ($0u8$0,),
280 }
281             ",
282             r#"
283 type $0Type = u8;
284
285 struct S {
286     field: (Type,),
287 }
288             "#,
289         );
290     }
291
292     #[test]
293     fn extract_from_impl_or_trait() {
294         // When invoked in an impl/trait, extracted type alias should be placed next to the
295         // impl/trait, not inside.
296         check_assist(
297             extract_type_alias,
298             r#"
299 impl S {
300     fn f() -> $0(u8, u8)$0 {}
301 }
302             "#,
303             r#"
304 type $0Type = (u8, u8);
305
306 impl S {
307     fn f() -> Type {}
308 }
309             "#,
310         );
311         check_assist(
312             extract_type_alias,
313             r#"
314 trait Tr {
315     fn f() -> $0(u8, u8)$0 {}
316 }
317             "#,
318             r#"
319 type $0Type = (u8, u8);
320
321 trait Tr {
322     fn f() -> Type {}
323 }
324             "#,
325         );
326     }
327
328     #[test]
329     fn indentation() {
330         check_assist(
331             extract_type_alias,
332             r#"
333 mod m {
334     fn f() -> $0u8$0 {}
335 }
336             "#,
337             r#"
338 mod m {
339     type $0Type = u8;
340
341     fn f() -> Type {}
342 }
343             "#,
344         );
345     }
346
347     #[test]
348     fn generics() {
349         check_assist(
350             extract_type_alias,
351             r#"
352 struct Struct<const C: usize>;
353 impl<'outer, Outer, const OUTER: usize> () {
354     fn func<'inner, Inner, const INNER: usize>(_: $0&(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ())$0) {}
355 }
356 "#,
357             r#"
358 struct Struct<const C: usize>;
359 type $0Type<'inner, 'outer, Outer, Inner, const INNER: usize, const OUTER: usize> = &(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ());
360
361 impl<'outer, Outer, const OUTER: usize> () {
362     fn func<'inner, Inner, const INNER: usize>(_: Type<'inner, 'outer, Outer, Inner, INNER, OUTER>) {}
363 }
364 "#,
365         );
366     }
367
368     #[test]
369     fn issue_11197() {
370         check_assist(
371             extract_type_alias,
372             r#"
373 struct Foo<T, const N: usize>
374 where
375     [T; N]: Sized,
376 {
377     arr: $0[T; N]$0,
378 }
379             "#,
380             r#"
381 type $0Type<T, const N: usize> = [T; N];
382
383 struct Foo<T, const N: usize>
384 where
385     [T; N]: Sized,
386 {
387     arr: Type<T, N>,
388 }
389             "#,
390         );
391     }
392 }