]> git.lizzy.rs Git - rust.git/blob - crates/ra_assists/src/handlers/add_new.rs
Merge assits::test_helpers and tests
[rust.git] / crates / ra_assists / src / handlers / add_new.rs
1 use hir::Adt;
2 use ra_syntax::{
3     ast::{
4         self, AstNode, NameOwner, StructKind, TypeAscriptionOwner, TypeParamsOwner, VisibilityOwner,
5     },
6     TextSize, T,
7 };
8 use stdx::{format_to, SepBy};
9
10 use crate::{Assist, AssistCtx, AssistId};
11
12 // Assist: add_new
13 //
14 // Adds a new inherent impl for a type.
15 //
16 // ```
17 // struct Ctx<T: Clone> {
18 //      data: T,<|>
19 // }
20 // ```
21 // ->
22 // ```
23 // struct Ctx<T: Clone> {
24 //      data: T,
25 // }
26 //
27 // impl<T: Clone> Ctx<T> {
28 //     fn new(data: T) -> Self { Self { data } }
29 // }
30 //
31 // ```
32 pub(crate) fn add_new(ctx: AssistCtx) -> Option<Assist> {
33     let strukt = ctx.find_node_at_offset::<ast::StructDef>()?;
34
35     // We want to only apply this to non-union structs with named fields
36     let field_list = match strukt.kind() {
37         StructKind::Record(named) => named,
38         _ => return None,
39     };
40
41     // Return early if we've found an existing new fn
42     let impl_def = find_struct_impl(&ctx, &strukt)?;
43
44     ctx.add_assist(AssistId("add_new"), "Add default constructor", |edit| {
45         edit.target(strukt.syntax().text_range());
46
47         let mut buf = String::with_capacity(512);
48
49         if impl_def.is_some() {
50             buf.push('\n');
51         }
52
53         let vis = strukt.visibility().map(|v| format!("{} ", v));
54         let vis = vis.as_deref().unwrap_or("");
55
56         let params = field_list
57             .fields()
58             .filter_map(|f| {
59                 Some(format!(
60                     "{}: {}",
61                     f.name()?.syntax().text(),
62                     f.ascribed_type()?.syntax().text()
63                 ))
64             })
65             .sep_by(", ");
66         let fields = field_list.fields().filter_map(|f| f.name()).sep_by(", ");
67
68         format_to!(buf, "    {}fn new({}) -> Self {{ Self {{ {} }} }}", vis, params, fields);
69
70         let (start_offset, end_offset) = impl_def
71             .and_then(|impl_def| {
72                 buf.push('\n');
73                 let start = impl_def
74                     .syntax()
75                     .descendants_with_tokens()
76                     .find(|t| t.kind() == T!['{'])?
77                     .text_range()
78                     .end();
79
80                 Some((start, TextSize::of("\n")))
81             })
82             .unwrap_or_else(|| {
83                 buf = generate_impl_text(&strukt, &buf);
84                 let start = strukt.syntax().text_range().end();
85
86                 (start, TextSize::of("\n}\n"))
87             });
88
89         edit.set_cursor(start_offset + TextSize::of(&buf) - end_offset);
90         edit.insert(start_offset, buf);
91     })
92 }
93
94 // Generates the surrounding `impl Type { <code> }` including type and lifetime
95 // parameters
96 fn generate_impl_text(strukt: &ast::StructDef, code: &str) -> String {
97     let type_params = strukt.type_param_list();
98     let mut buf = String::with_capacity(code.len());
99     buf.push_str("\n\nimpl");
100     if let Some(type_params) = &type_params {
101         format_to!(buf, "{}", type_params.syntax());
102     }
103     buf.push_str(" ");
104     buf.push_str(strukt.name().unwrap().text().as_str());
105     if let Some(type_params) = type_params {
106         let lifetime_params = type_params
107             .lifetime_params()
108             .filter_map(|it| it.lifetime_token())
109             .map(|it| it.text().clone());
110         let type_params =
111             type_params.type_params().filter_map(|it| it.name()).map(|it| it.text().clone());
112         format_to!(buf, "<{}>", lifetime_params.chain(type_params).sep_by(", "))
113     }
114
115     format_to!(buf, " {{\n{}\n}}\n", code);
116
117     buf
118 }
119
120 // Uses a syntax-driven approach to find any impl blocks for the struct that
121 // exist within the module/file
122 //
123 // Returns `None` if we've found an existing `new` fn
124 //
125 // FIXME: change the new fn checking to a more semantic approach when that's more
126 // viable (e.g. we process proc macros, etc)
127 fn find_struct_impl(ctx: &AssistCtx, strukt: &ast::StructDef) -> Option<Option<ast::ImplDef>> {
128     let db = ctx.db;
129     let module = strukt.syntax().ancestors().find(|node| {
130         ast::Module::can_cast(node.kind()) || ast::SourceFile::can_cast(node.kind())
131     })?;
132
133     let struct_def = ctx.sema.to_def(strukt)?;
134
135     let block = module.descendants().filter_map(ast::ImplDef::cast).find_map(|impl_blk| {
136         let blk = ctx.sema.to_def(&impl_blk)?;
137
138         // FIXME: handle e.g. `struct S<T>; impl<U> S<U> {}`
139         // (we currently use the wrong type parameter)
140         // also we wouldn't want to use e.g. `impl S<u32>`
141         let same_ty = match blk.target_ty(db).as_adt() {
142             Some(def) => def == Adt::Struct(struct_def),
143             None => false,
144         };
145         let not_trait_impl = blk.target_trait(db).is_none();
146
147         if !(same_ty && not_trait_impl) {
148             None
149         } else {
150             Some(impl_blk)
151         }
152     });
153
154     if let Some(ref impl_blk) = block {
155         if has_new_fn(impl_blk) {
156             return None;
157         }
158     }
159
160     Some(block)
161 }
162
163 fn has_new_fn(imp: &ast::ImplDef) -> bool {
164     if let Some(il) = imp.item_list() {
165         for item in il.assoc_items() {
166             if let ast::AssocItem::FnDef(f) = item {
167                 if let Some(name) = f.name() {
168                     if name.text().eq_ignore_ascii_case("new") {
169                         return true;
170                     }
171                 }
172             }
173         }
174     }
175
176     false
177 }
178
179 #[cfg(test)]
180 mod tests {
181     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
182
183     use super::*;
184
185     #[test]
186     #[rustfmt::skip]
187     fn test_add_new() {
188         // Check output of generation
189         check_assist(
190             add_new,
191 "struct Foo {<|>}",
192 "struct Foo {}
193
194 impl Foo {
195     fn new() -> Self { Self {  } }<|>
196 }
197 ",
198         );
199         check_assist(
200             add_new,
201 "struct Foo<T: Clone> {<|>}",
202 "struct Foo<T: Clone> {}
203
204 impl<T: Clone> Foo<T> {
205     fn new() -> Self { Self {  } }<|>
206 }
207 ",
208         );
209         check_assist(
210             add_new,
211 "struct Foo<'a, T: Foo<'a>> {<|>}",
212 "struct Foo<'a, T: Foo<'a>> {}
213
214 impl<'a, T: Foo<'a>> Foo<'a, T> {
215     fn new() -> Self { Self {  } }<|>
216 }
217 ",
218         );
219         check_assist(
220             add_new,
221 "struct Foo { baz: String <|>}",
222 "struct Foo { baz: String }
223
224 impl Foo {
225     fn new(baz: String) -> Self { Self { baz } }<|>
226 }
227 ",
228         );
229         check_assist(
230             add_new,
231 "struct Foo { baz: String, qux: Vec<i32> <|>}",
232 "struct Foo { baz: String, qux: Vec<i32> }
233
234 impl Foo {
235     fn new(baz: String, qux: Vec<i32>) -> Self { Self { baz, qux } }<|>
236 }
237 ",
238         );
239
240         // Check that visibility modifiers don't get brought in for fields
241         check_assist(
242             add_new,
243 "struct Foo { pub baz: String, pub qux: Vec<i32> <|>}",
244 "struct Foo { pub baz: String, pub qux: Vec<i32> }
245
246 impl Foo {
247     fn new(baz: String, qux: Vec<i32>) -> Self { Self { baz, qux } }<|>
248 }
249 ",
250         );
251
252         // Check that it reuses existing impls
253         check_assist(
254             add_new,
255 "struct Foo {<|>}
256
257 impl Foo {}
258 ",
259 "struct Foo {}
260
261 impl Foo {
262     fn new() -> Self { Self {  } }<|>
263 }
264 ",
265         );
266         check_assist(
267             add_new,
268 "struct Foo {<|>}
269
270 impl Foo {
271     fn qux(&self) {}
272 }
273 ",
274 "struct Foo {}
275
276 impl Foo {
277     fn new() -> Self { Self {  } }<|>
278
279     fn qux(&self) {}
280 }
281 ",
282         );
283
284         check_assist(
285             add_new,
286 "struct Foo {<|>}
287
288 impl Foo {
289     fn qux(&self) {}
290     fn baz() -> i32 {
291         5
292     }
293 }
294 ",
295 "struct Foo {}
296
297 impl Foo {
298     fn new() -> Self { Self {  } }<|>
299
300     fn qux(&self) {}
301     fn baz() -> i32 {
302         5
303     }
304 }
305 ",
306         );
307
308         // Check visibility of new fn based on struct
309         check_assist(
310             add_new,
311 "pub struct Foo {<|>}",
312 "pub struct Foo {}
313
314 impl Foo {
315     pub fn new() -> Self { Self {  } }<|>
316 }
317 ",
318         );
319         check_assist(
320             add_new,
321 "pub(crate) struct Foo {<|>}",
322 "pub(crate) struct Foo {}
323
324 impl Foo {
325     pub(crate) fn new() -> Self { Self {  } }<|>
326 }
327 ",
328         );
329     }
330
331     #[test]
332     fn add_new_not_applicable_if_fn_exists() {
333         check_assist_not_applicable(
334             add_new,
335             "
336 struct Foo {<|>}
337
338 impl Foo {
339     fn new() -> Self {
340         Self
341     }
342 }",
343         );
344
345         check_assist_not_applicable(
346             add_new,
347             "
348 struct Foo {<|>}
349
350 impl Foo {
351     fn New() -> Self {
352         Self
353     }
354 }",
355         );
356     }
357
358     #[test]
359     fn add_new_target() {
360         check_assist_target(
361             add_new,
362             "
363 struct SomeThingIrrelevant;
364 /// Has a lifetime parameter
365 struct Foo<'a, T: Foo<'a>> {<|>}
366 struct EvenMoreIrrelevant;
367 ",
368             "/// Has a lifetime parameter
369 struct Foo<'a, T: Foo<'a>> {}",
370         );
371     }
372
373     #[test]
374     fn test_unrelated_new() {
375         check_assist(
376             add_new,
377             r##"
378 pub struct AstId<N: AstNode> {
379     file_id: HirFileId,
380     file_ast_id: FileAstId<N>,
381 }
382
383 impl<N: AstNode> AstId<N> {
384     pub fn new(file_id: HirFileId, file_ast_id: FileAstId<N>) -> AstId<N> {
385         AstId { file_id, file_ast_id }
386     }
387 }
388
389 pub struct Source<T> {
390     pub file_id: HirFileId,<|>
391     pub ast: T,
392 }
393
394 impl<T> Source<T> {
395     pub fn map<F: FnOnce(T) -> U, U>(self, f: F) -> Source<U> {
396         Source { file_id: self.file_id, ast: f(self.ast) }
397     }
398 }
399 "##,
400             r##"
401 pub struct AstId<N: AstNode> {
402     file_id: HirFileId,
403     file_ast_id: FileAstId<N>,
404 }
405
406 impl<N: AstNode> AstId<N> {
407     pub fn new(file_id: HirFileId, file_ast_id: FileAstId<N>) -> AstId<N> {
408         AstId { file_id, file_ast_id }
409     }
410 }
411
412 pub struct Source<T> {
413     pub file_id: HirFileId,
414     pub ast: T,
415 }
416
417 impl<T> Source<T> {
418     pub fn new(file_id: HirFileId, ast: T) -> Self { Self { file_id, ast } }<|>
419
420     pub fn map<F: FnOnce(T) -> U, U>(self, f: F) -> Source<U> {
421         Source { file_id: self.file_id, ast: f(self.ast) }
422     }
423 }
424 "##,
425         );
426     }
427 }