]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs
Rollup merge of #99954 - dingxiangfei2009:break-out-let-else-higher-up, r=oli-obk
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / extract_type_alias.rs
1 use either::Either;
2 use ide_db::syntax_helpers::node_ext::walk_ty;
3 use itertools::Itertools;
4 use syntax::{
5     ast::{self, edit::IndentLevel, AstNode, HasGenericParams, HasName},
6     match_ast,
7 };
8
9 use crate::{AssistContext, AssistId, AssistKind, Assists};
10
11 // Assist: extract_type_alias
12 //
13 // Extracts the selected type as a type alias.
14 //
15 // ```
16 // struct S {
17 //     field: $0(u8, u8, u8)$0,
18 // }
19 // ```
20 // ->
21 // ```
22 // type $0Type = (u8, u8, u8);
23 //
24 // struct S {
25 //     field: Type,
26 // }
27 // ```
28 pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
29     if ctx.has_empty_selection() {
30         return None;
31     }
32
33     let ty = ctx.find_node_at_range::<ast::Type>()?;
34     let item = ty.syntax().ancestors().find_map(ast::Item::cast)?;
35     let assoc_owner = item.syntax().ancestors().nth(2).and_then(|it| {
36         match_ast! {
37             match it {
38                 ast::Trait(tr) => Some(Either::Left(tr)),
39                 ast::Impl(impl_) => Some(Either::Right(impl_)),
40                 _ => None,
41             }
42         }
43     });
44     let node = assoc_owner.as_ref().map_or_else(
45         || item.syntax(),
46         |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax),
47     );
48     let insert_pos = node.text_range().start();
49     let target = ty.syntax().text_range();
50
51     acc.add(
52         AssistId("extract_type_alias", AssistKind::RefactorExtract),
53         "Extract type as type alias",
54         target,
55         |builder| {
56             let mut known_generics = match item.generic_param_list() {
57                 Some(it) => it.generic_params().collect(),
58                 None => Vec::new(),
59             };
60             if let Some(it) = assoc_owner.as_ref().and_then(|it| match it {
61                 Either::Left(it) => it.generic_param_list(),
62                 Either::Right(it) => it.generic_param_list(),
63             }) {
64                 known_generics.extend(it.generic_params());
65             }
66             let generics = collect_used_generics(&ty, &known_generics);
67
68             let replacement = if !generics.is_empty() {
69                 format!(
70                     "Type<{}>",
71                     generics.iter().format_with(", ", |generic, f| {
72                         match generic {
73                             ast::GenericParam::ConstParam(cp) => f(&cp.name().unwrap()),
74                             ast::GenericParam::LifetimeParam(lp) => f(&lp.lifetime().unwrap()),
75                             ast::GenericParam::TypeParam(tp) => f(&tp.name().unwrap()),
76                         }
77                     })
78                 )
79             } else {
80                 String::from("Type")
81             };
82             builder.replace(target, replacement);
83
84             let indent = IndentLevel::from_node(node);
85             let generics = if !generics.is_empty() {
86                 format!("<{}>", generics.iter().format(", "))
87             } else {
88                 String::new()
89             };
90             match ctx.config.snippet_cap {
91                 Some(cap) => {
92                     builder.insert_snippet(
93                         cap,
94                         insert_pos,
95                         format!("type $0Type{} = {};\n\n{}", generics, ty, indent),
96                     );
97                 }
98                 None => {
99                     builder.insert(
100                         insert_pos,
101                         format!("type Type{} = {};\n\n{}", generics, ty, indent),
102                     );
103                 }
104             }
105         },
106     )
107 }
108
109 fn collect_used_generics<'gp>(
110     ty: &ast::Type,
111     known_generics: &'gp [ast::GenericParam],
112 ) -> Vec<&'gp ast::GenericParam> {
113     // can't use a closure -> closure here cause lifetime inference fails for that
114     fn find_lifetime(text: &str) -> impl Fn(&&ast::GenericParam) -> bool + '_ {
115         move |gp: &&ast::GenericParam| match gp {
116             ast::GenericParam::LifetimeParam(lp) => {
117                 lp.lifetime().map_or(false, |lt| lt.text() == text)
118             }
119             _ => false,
120         }
121     }
122
123     let mut generics = Vec::new();
124     walk_ty(ty, &mut |ty| match ty {
125         ast::Type::PathType(ty) => {
126             if let Some(path) = ty.path() {
127                 if let Some(name_ref) = path.as_single_name_ref() {
128                     if let Some(param) = known_generics.iter().find(|gp| {
129                         match gp {
130                             ast::GenericParam::ConstParam(cp) => cp.name(),
131                             ast::GenericParam::TypeParam(tp) => tp.name(),
132                             _ => None,
133                         }
134                         .map_or(false, |n| n.text() == name_ref.text())
135                     }) {
136                         generics.push(param);
137                     }
138                 }
139                 generics.extend(
140                     path.segments()
141                         .filter_map(|seg| seg.generic_arg_list())
142                         .flat_map(|it| it.generic_args())
143                         .filter_map(|it| match it {
144                             ast::GenericArg::LifetimeArg(lt) => {
145                                 let lt = lt.lifetime()?;
146                                 known_generics.iter().find(find_lifetime(&lt.text()))
147                             }
148                             _ => None,
149                         }),
150                 );
151             }
152         }
153         ast::Type::ImplTraitType(impl_ty) => {
154             if let Some(it) = impl_ty.type_bound_list() {
155                 generics.extend(
156                     it.bounds()
157                         .filter_map(|it| it.lifetime())
158                         .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
159                 );
160             }
161         }
162         ast::Type::DynTraitType(dyn_ty) => {
163             if let Some(it) = dyn_ty.type_bound_list() {
164                 generics.extend(
165                     it.bounds()
166                         .filter_map(|it| it.lifetime())
167                         .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
168                 );
169             }
170         }
171         ast::Type::RefType(ref_) => generics.extend(
172             ref_.lifetime().and_then(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
173         ),
174         _ => (),
175     });
176     // stable resort to lifetime, type, const
177     generics.sort_by_key(|gp| match gp {
178         ast::GenericParam::ConstParam(_) => 2,
179         ast::GenericParam::LifetimeParam(_) => 0,
180         ast::GenericParam::TypeParam(_) => 1,
181     });
182     generics
183 }
184
185 #[cfg(test)]
186 mod tests {
187     use crate::tests::{check_assist, check_assist_not_applicable};
188
189     use super::*;
190
191     #[test]
192     fn test_not_applicable_without_selection() {
193         check_assist_not_applicable(
194             extract_type_alias,
195             r"
196 struct S {
197     field: $0(u8, u8, u8),
198 }
199             ",
200         );
201     }
202
203     #[test]
204     fn test_simple_types() {
205         check_assist(
206             extract_type_alias,
207             r"
208 struct S {
209     field: $0u8$0,
210 }
211             ",
212             r#"
213 type $0Type = u8;
214
215 struct S {
216     field: Type,
217 }
218             "#,
219         );
220     }
221
222     #[test]
223     fn test_generic_type_arg() {
224         check_assist(
225             extract_type_alias,
226             r"
227 fn generic<T>() {}
228
229 fn f() {
230     generic::<$0()$0>();
231 }
232             ",
233             r#"
234 fn generic<T>() {}
235
236 type $0Type = ();
237
238 fn f() {
239     generic::<Type>();
240 }
241             "#,
242         );
243     }
244
245     #[test]
246     fn test_inner_type_arg() {
247         check_assist(
248             extract_type_alias,
249             r"
250 struct Vec<T> {}
251 struct S {
252     v: Vec<Vec<$0Vec<u8>$0>>,
253 }
254             ",
255             r#"
256 struct Vec<T> {}
257 type $0Type = Vec<u8>;
258
259 struct S {
260     v: Vec<Vec<Type>>,
261 }
262             "#,
263         );
264     }
265
266     #[test]
267     fn test_extract_inner_type() {
268         check_assist(
269             extract_type_alias,
270             r"
271 struct S {
272     field: ($0u8$0,),
273 }
274             ",
275             r#"
276 type $0Type = u8;
277
278 struct S {
279     field: (Type,),
280 }
281             "#,
282         );
283     }
284
285     #[test]
286     fn extract_from_impl_or_trait() {
287         // When invoked in an impl/trait, extracted type alias should be placed next to the
288         // impl/trait, not inside.
289         check_assist(
290             extract_type_alias,
291             r#"
292 impl S {
293     fn f() -> $0(u8, u8)$0 {}
294 }
295             "#,
296             r#"
297 type $0Type = (u8, u8);
298
299 impl S {
300     fn f() -> Type {}
301 }
302             "#,
303         );
304         check_assist(
305             extract_type_alias,
306             r#"
307 trait Tr {
308     fn f() -> $0(u8, u8)$0 {}
309 }
310             "#,
311             r#"
312 type $0Type = (u8, u8);
313
314 trait Tr {
315     fn f() -> Type {}
316 }
317             "#,
318         );
319     }
320
321     #[test]
322     fn indentation() {
323         check_assist(
324             extract_type_alias,
325             r#"
326 mod m {
327     fn f() -> $0u8$0 {}
328 }
329             "#,
330             r#"
331 mod m {
332     type $0Type = u8;
333
334     fn f() -> Type {}
335 }
336             "#,
337         );
338     }
339
340     #[test]
341     fn generics() {
342         check_assist(
343             extract_type_alias,
344             r#"
345 struct Struct<const C: usize>;
346 impl<'outer, Outer, const OUTER: usize> () {
347     fn func<'inner, Inner, const INNER: usize>(_: $0&(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ())$0) {}
348 }
349 "#,
350             r#"
351 struct Struct<const C: usize>;
352 type $0Type<'inner, 'outer, Outer, Inner, const INNER: usize, const OUTER: usize> = &(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ());
353
354 impl<'outer, Outer, const OUTER: usize> () {
355     fn func<'inner, Inner, const INNER: usize>(_: Type<'inner, 'outer, Outer, Inner, INNER, OUTER>) {}
356 }
357 "#,
358         );
359     }
360 }