]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_lifetime_to_type.rs
Auto merge of #103913 - Neutron3529:patch-1, r=thomcc
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / add_lifetime_to_type.rs
1 use syntax::ast::{self, AstNode, HasGenericParams, HasName};
2
3 use crate::{AssistContext, AssistId, AssistKind, Assists};
4
5 // Assist: add_lifetime_to_type
6 //
7 // Adds a new lifetime to a struct, enum or union.
8 //
9 // ```
10 // struct Point {
11 //     x: &$0u32,
12 //     y: u32,
13 // }
14 // ```
15 // ->
16 // ```
17 // struct Point<'a> {
18 //     x: &'a u32,
19 //     y: u32,
20 // }
21 // ```
22 pub(crate) fn add_lifetime_to_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
23     let ref_type_focused = ctx.find_node_at_offset::<ast::RefType>()?;
24     if ref_type_focused.lifetime().is_some() {
25         return None;
26     }
27
28     let node = ctx.find_node_at_offset::<ast::Adt>()?;
29     let has_lifetime = node
30         .generic_param_list()
31         .map_or(false, |gen_list| gen_list.lifetime_params().next().is_some());
32
33     if has_lifetime {
34         return None;
35     }
36
37     let ref_types = fetch_borrowed_types(&node)?;
38     let target = node.syntax().text_range();
39
40     acc.add(
41         AssistId("add_lifetime_to_type", AssistKind::Generate),
42         "Add lifetime",
43         target,
44         |builder| {
45             match node.generic_param_list() {
46                 Some(gen_param) => {
47                     if let Some(left_angle) = gen_param.l_angle_token() {
48                         builder.insert(left_angle.text_range().end(), "'a, ");
49                     }
50                 }
51                 None => {
52                     if let Some(name) = node.name() {
53                         builder.insert(name.syntax().text_range().end(), "<'a>");
54                     }
55                 }
56             }
57
58             for ref_type in ref_types {
59                 if let Some(amp_token) = ref_type.amp_token() {
60                     builder.insert(amp_token.text_range().end(), "'a ");
61                 }
62             }
63         },
64     )
65 }
66
67 fn fetch_borrowed_types(node: &ast::Adt) -> Option<Vec<ast::RefType>> {
68     let ref_types: Vec<ast::RefType> = match node {
69         ast::Adt::Enum(enum_) => {
70             let variant_list = enum_.variant_list()?;
71             variant_list
72                 .variants()
73                 .filter_map(|variant| {
74                     let field_list = variant.field_list()?;
75
76                     find_ref_types_from_field_list(&field_list)
77                 })
78                 .flatten()
79                 .collect()
80         }
81         ast::Adt::Struct(strukt) => {
82             let field_list = strukt.field_list()?;
83             find_ref_types_from_field_list(&field_list)?
84         }
85         ast::Adt::Union(un) => {
86             let record_field_list = un.record_field_list()?;
87             record_field_list
88                 .fields()
89                 .filter_map(|r_field| {
90                     if let ast::Type::RefType(ref_type) = r_field.ty()? {
91                         if ref_type.lifetime().is_none() {
92                             return Some(ref_type);
93                         }
94                     }
95
96                     None
97                 })
98                 .collect()
99         }
100     };
101
102     if ref_types.is_empty() {
103         None
104     } else {
105         Some(ref_types)
106     }
107 }
108
109 fn find_ref_types_from_field_list(field_list: &ast::FieldList) -> Option<Vec<ast::RefType>> {
110     let ref_types: Vec<ast::RefType> = match field_list {
111         ast::FieldList::RecordFieldList(record_list) => record_list
112             .fields()
113             .filter_map(|f| {
114                 if let ast::Type::RefType(ref_type) = f.ty()? {
115                     if ref_type.lifetime().is_none() {
116                         return Some(ref_type);
117                     }
118                 }
119
120                 None
121             })
122             .collect(),
123         ast::FieldList::TupleFieldList(tuple_field_list) => tuple_field_list
124             .fields()
125             .filter_map(|f| {
126                 if let ast::Type::RefType(ref_type) = f.ty()? {
127                     if ref_type.lifetime().is_none() {
128                         return Some(ref_type);
129                     }
130                 }
131
132                 None
133             })
134             .collect(),
135     };
136
137     if ref_types.is_empty() {
138         None
139     } else {
140         Some(ref_types)
141     }
142 }
143
144 #[cfg(test)]
145 mod tests {
146     use crate::tests::{check_assist, check_assist_not_applicable};
147
148     use super::*;
149
150     #[test]
151     fn add_lifetime_to_struct() {
152         check_assist(
153             add_lifetime_to_type,
154             r#"struct Foo { a: &$0i32 }"#,
155             r#"struct Foo<'a> { a: &'a i32 }"#,
156         );
157
158         check_assist(
159             add_lifetime_to_type,
160             r#"struct Foo { a: &$0i32, b: &usize }"#,
161             r#"struct Foo<'a> { a: &'a i32, b: &'a usize }"#,
162         );
163
164         check_assist(
165             add_lifetime_to_type,
166             r#"struct Foo { a: &$0i32, b: usize }"#,
167             r#"struct Foo<'a> { a: &'a i32, b: usize }"#,
168         );
169
170         check_assist(
171             add_lifetime_to_type,
172             r#"struct Foo<T> { a: &$0T, b: usize }"#,
173             r#"struct Foo<'a, T> { a: &'a T, b: usize }"#,
174         );
175
176         check_assist_not_applicable(add_lifetime_to_type, r#"struct Foo<'a> { a: &$0'a i32 }"#);
177         check_assist_not_applicable(add_lifetime_to_type, r#"struct Foo { a: &'a$0 i32 }"#);
178     }
179
180     #[test]
181     fn add_lifetime_to_enum() {
182         check_assist(
183             add_lifetime_to_type,
184             r#"enum Foo { Bar { a: i32 }, Other, Tuple(u32, &$0u32)}"#,
185             r#"enum Foo<'a> { Bar { a: i32 }, Other, Tuple(u32, &'a u32)}"#,
186         );
187
188         check_assist(
189             add_lifetime_to_type,
190             r#"enum Foo { Bar { a: &$0i32 }}"#,
191             r#"enum Foo<'a> { Bar { a: &'a i32 }}"#,
192         );
193
194         check_assist(
195             add_lifetime_to_type,
196             r#"enum Foo<T> { Bar { a: &$0i32, b: &T }}"#,
197             r#"enum Foo<'a, T> { Bar { a: &'a i32, b: &'a T }}"#,
198         );
199
200         check_assist_not_applicable(
201             add_lifetime_to_type,
202             r#"enum Foo<'a> { Bar { a: &$0'a i32 }}"#,
203         );
204         check_assist_not_applicable(add_lifetime_to_type, r#"enum Foo { Bar, $0Misc }"#);
205     }
206
207     #[test]
208     fn add_lifetime_to_union() {
209         check_assist(
210             add_lifetime_to_type,
211             r#"union Foo { a: &$0i32 }"#,
212             r#"union Foo<'a> { a: &'a i32 }"#,
213         );
214
215         check_assist(
216             add_lifetime_to_type,
217             r#"union Foo { a: &$0i32, b: &usize }"#,
218             r#"union Foo<'a> { a: &'a i32, b: &'a usize }"#,
219         );
220
221         check_assist(
222             add_lifetime_to_type,
223             r#"union Foo<T> { a: &$0T, b: usize }"#,
224             r#"union Foo<'a, T> { a: &'a T, b: usize }"#,
225         );
226
227         check_assist_not_applicable(add_lifetime_to_type, r#"struct Foo<'a> { a: &'a $0i32 }"#);
228     }
229 }