]> git.lizzy.rs Git - rust.git/blob - crates/ide_diagnostics/src/handlers/type_mismatch.rs
Merge #11840
[rust.git] / crates / ide_diagnostics / src / handlers / type_mismatch.rs
1 use hir::{db::AstDatabase, HirDisplay, Type, TypeInfo};
2 use ide_db::{
3     famous_defs::FamousDefs, source_change::SourceChange,
4     syntax_helpers::node_ext::for_each_tail_expr,
5 };
6 use syntax::{
7     ast::{BlockExpr, ExprStmt},
8     AstNode,
9 };
10 use text_edit::TextEdit;
11
12 use crate::{fix, Assist, Diagnostic, DiagnosticsContext};
13
14 // Diagnostic: type-mismatch
15 //
16 // This diagnostic is triggered when the type of an expression does not match
17 // the expected type.
18 pub(crate) fn type_mismatch(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Diagnostic {
19     let mut diag = Diagnostic::new(
20         "type-mismatch",
21         format!(
22             "expected {}, found {}",
23             d.expected.display(ctx.sema.db),
24             d.actual.display(ctx.sema.db)
25         ),
26         ctx.sema.diagnostics_display_range(d.expr.clone().map(|it| it.into())).range,
27     )
28     .with_fixes(fixes(ctx, d));
29     if diag.fixes.is_none() {
30         diag.experimental = true;
31     }
32     diag
33 }
34
35 fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option<Vec<Assist>> {
36     let mut fixes = Vec::new();
37
38     add_reference(ctx, d, &mut fixes);
39     add_missing_ok_or_some(ctx, d, &mut fixes);
40     remove_semicolon(ctx, d, &mut fixes);
41
42     if fixes.is_empty() {
43         None
44     } else {
45         Some(fixes)
46     }
47 }
48
49 fn add_reference(
50     ctx: &DiagnosticsContext<'_>,
51     d: &hir::TypeMismatch,
52     acc: &mut Vec<Assist>,
53 ) -> Option<()> {
54     let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?;
55     let expr_node = d.expr.value.to_node(&root);
56
57     let range = ctx.sema.diagnostics_display_range(d.expr.clone().map(|it| it.into())).range;
58
59     let (_, mutability) = d.expected.as_reference()?;
60     let actual_with_ref = Type::reference(&d.actual, mutability);
61     if !actual_with_ref.could_coerce_to(ctx.sema.db, &d.expected) {
62         return None;
63     }
64
65     let ampersands = format!("&{}", mutability.as_keyword_for_ref());
66
67     let edit = TextEdit::insert(expr_node.syntax().text_range().start(), ampersands);
68     let source_change =
69         SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), edit);
70     acc.push(fix("add_reference_here", "Add reference here", source_change, range));
71     Some(())
72 }
73
74 fn add_missing_ok_or_some(
75     ctx: &DiagnosticsContext<'_>,
76     d: &hir::TypeMismatch,
77     acc: &mut Vec<Assist>,
78 ) -> Option<()> {
79     let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?;
80     let tail_expr = d.expr.value.to_node(&root);
81     let tail_expr_range = tail_expr.syntax().text_range();
82     let scope = ctx.sema.scope(tail_expr.syntax());
83
84     let expected_adt = d.expected.as_adt()?;
85     let expected_enum = expected_adt.as_enum()?;
86
87     let famous_defs = FamousDefs(&ctx.sema, scope.krate());
88     let core_result = famous_defs.core_result_Result();
89     let core_option = famous_defs.core_option_Option();
90
91     if Some(expected_enum) != core_result && Some(expected_enum) != core_option {
92         return None;
93     }
94
95     let variant_name = if Some(expected_enum) == core_result { "Ok" } else { "Some" };
96
97     let wrapped_actual_ty = expected_adt.ty_with_args(ctx.sema.db, &[d.actual.clone()]);
98
99     if !d.expected.could_unify_with(ctx.sema.db, &wrapped_actual_ty) {
100         return None;
101     }
102
103     let mut builder = TextEdit::builder();
104     for_each_tail_expr(&tail_expr, &mut |expr| {
105         if ctx.sema.type_of_expr(expr).map(TypeInfo::adjusted).as_ref() != Some(&d.expected) {
106             builder.insert(expr.syntax().text_range().start(), format!("{}(", variant_name));
107             builder.insert(expr.syntax().text_range().end(), ")".to_string());
108         }
109     });
110     let source_change =
111         SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), builder.finish());
112     let name = format!("Wrap in {}", variant_name);
113     acc.push(fix("wrap_tail_expr", &name, source_change, tail_expr_range));
114     Some(())
115 }
116
117 fn remove_semicolon(
118     ctx: &DiagnosticsContext<'_>,
119     d: &hir::TypeMismatch,
120     acc: &mut Vec<Assist>,
121 ) -> Option<()> {
122     let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?;
123     let expr = d.expr.value.to_node(&root);
124     if !d.actual.is_unit() {
125         return None;
126     }
127     let block = BlockExpr::cast(expr.syntax().clone())?;
128     let expr_before_semi =
129         block.statements().last().and_then(|s| ExprStmt::cast(s.syntax().clone()))?;
130     let type_before_semi = ctx.sema.type_of_expr(&expr_before_semi.expr()?)?.original();
131     if !type_before_semi.could_coerce_to(ctx.sema.db, &d.expected) {
132         return None;
133     }
134     let semicolon_range = expr_before_semi.semicolon_token()?.text_range();
135
136     let edit = TextEdit::delete(semicolon_range);
137     let source_change =
138         SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), edit);
139
140     acc.push(fix("remove_semicolon", "Remove this semicolon", source_change, semicolon_range));
141     Some(())
142 }
143
144 #[cfg(test)]
145 mod tests {
146     use crate::tests::{check_diagnostics, check_fix, check_no_fix};
147
148     #[test]
149     fn missing_reference() {
150         check_diagnostics(
151             r#"
152 fn main() {
153     test(123);
154        //^^^ ðŸ’¡ error: expected &i32, found i32
155 }
156 fn test(arg: &i32) {}
157 "#,
158         );
159     }
160
161     #[test]
162     fn test_add_reference_to_int() {
163         check_fix(
164             r#"
165 fn main() {
166     test(123$0);
167 }
168 fn test(arg: &i32) {}
169             "#,
170             r#"
171 fn main() {
172     test(&123);
173 }
174 fn test(arg: &i32) {}
175             "#,
176         );
177     }
178
179     #[test]
180     fn test_add_mutable_reference_to_int() {
181         check_fix(
182             r#"
183 fn main() {
184     test($0123);
185 }
186 fn test(arg: &mut i32) {}
187             "#,
188             r#"
189 fn main() {
190     test(&mut 123);
191 }
192 fn test(arg: &mut i32) {}
193             "#,
194         );
195     }
196
197     #[test]
198     fn test_add_reference_to_array() {
199         check_fix(
200             r#"
201 //- minicore: coerce_unsized
202 fn main() {
203     test($0[1, 2, 3]);
204 }
205 fn test(arg: &[i32]) {}
206             "#,
207             r#"
208 fn main() {
209     test(&[1, 2, 3]);
210 }
211 fn test(arg: &[i32]) {}
212             "#,
213         );
214     }
215
216     #[test]
217     fn test_add_reference_with_autoderef() {
218         check_fix(
219             r#"
220 //- minicore: coerce_unsized, deref
221 struct Foo;
222 struct Bar;
223 impl core::ops::Deref for Foo {
224     type Target = Bar;
225 }
226
227 fn main() {
228     test($0Foo);
229 }
230 fn test(arg: &Bar) {}
231             "#,
232             r#"
233 struct Foo;
234 struct Bar;
235 impl core::ops::Deref for Foo {
236     type Target = Bar;
237 }
238
239 fn main() {
240     test(&Foo);
241 }
242 fn test(arg: &Bar) {}
243             "#,
244         );
245     }
246
247     #[test]
248     fn test_add_reference_to_method_call() {
249         check_fix(
250             r#"
251 fn main() {
252     Test.call_by_ref($0123);
253 }
254 struct Test;
255 impl Test {
256     fn call_by_ref(&self, arg: &i32) {}
257 }
258             "#,
259             r#"
260 fn main() {
261     Test.call_by_ref(&123);
262 }
263 struct Test;
264 impl Test {
265     fn call_by_ref(&self, arg: &i32) {}
266 }
267             "#,
268         );
269     }
270
271     #[test]
272     fn test_add_reference_to_let_stmt() {
273         check_fix(
274             r#"
275 fn main() {
276     let test: &i32 = $0123;
277 }
278             "#,
279             r#"
280 fn main() {
281     let test: &i32 = &123;
282 }
283             "#,
284         );
285     }
286
287     #[test]
288     fn test_add_mutable_reference_to_let_stmt() {
289         check_fix(
290             r#"
291 fn main() {
292     let test: &mut i32 = $0123;
293 }
294             "#,
295             r#"
296 fn main() {
297     let test: &mut i32 = &mut 123;
298 }
299             "#,
300         );
301     }
302
303     #[test]
304     fn test_wrap_return_type_option() {
305         check_fix(
306             r#"
307 //- minicore: option, result
308 fn div(x: i32, y: i32) -> Option<i32> {
309     if y == 0 {
310         return None;
311     }
312     x / y$0
313 }
314 "#,
315             r#"
316 fn div(x: i32, y: i32) -> Option<i32> {
317     if y == 0 {
318         return None;
319     }
320     Some(x / y)
321 }
322 "#,
323         );
324     }
325
326     #[test]
327     fn test_wrap_return_type_option_tails() {
328         check_fix(
329             r#"
330 //- minicore: option, result
331 fn div(x: i32, y: i32) -> Option<i32> {
332     if y == 0 {
333         0
334     } else if true {
335         100
336     } else {
337         None
338     }$0
339 }
340 "#,
341             r#"
342 fn div(x: i32, y: i32) -> Option<i32> {
343     if y == 0 {
344         Some(0)
345     } else if true {
346         Some(100)
347     } else {
348         None
349     }
350 }
351 "#,
352         );
353     }
354
355     #[test]
356     fn test_wrap_return_type() {
357         check_fix(
358             r#"
359 //- minicore: option, result
360 fn div(x: i32, y: i32) -> Result<i32, ()> {
361     if y == 0 {
362         return Err(());
363     }
364     x / y$0
365 }
366 "#,
367             r#"
368 fn div(x: i32, y: i32) -> Result<i32, ()> {
369     if y == 0 {
370         return Err(());
371     }
372     Ok(x / y)
373 }
374 "#,
375         );
376     }
377
378     #[test]
379     fn test_wrap_return_type_handles_generic_functions() {
380         check_fix(
381             r#"
382 //- minicore: option, result
383 fn div<T>(x: T) -> Result<T, i32> {
384     if x == 0 {
385         return Err(7);
386     }
387     $0x
388 }
389 "#,
390             r#"
391 fn div<T>(x: T) -> Result<T, i32> {
392     if x == 0 {
393         return Err(7);
394     }
395     Ok(x)
396 }
397 "#,
398         );
399     }
400
401     #[test]
402     fn test_wrap_return_type_handles_type_aliases() {
403         check_fix(
404             r#"
405 //- minicore: option, result
406 type MyResult<T> = Result<T, ()>;
407
408 fn div(x: i32, y: i32) -> MyResult<i32> {
409     if y == 0 {
410         return Err(());
411     }
412     x $0/ y
413 }
414 "#,
415             r#"
416 type MyResult<T> = Result<T, ()>;
417
418 fn div(x: i32, y: i32) -> MyResult<i32> {
419     if y == 0 {
420         return Err(());
421     }
422     Ok(x / y)
423 }
424 "#,
425         );
426     }
427
428     #[test]
429     fn test_in_const_and_static() {
430         check_fix(
431             r#"
432 //- minicore: option, result
433 static A: Option<()> = {($0)};
434             "#,
435             r#"
436 static A: Option<()> = {Some(())};
437             "#,
438         );
439         check_fix(
440             r#"
441 //- minicore: option, result
442 const _: Option<()> = {($0)};
443             "#,
444             r#"
445 const _: Option<()> = {Some(())};
446             "#,
447         );
448     }
449
450     #[test]
451     fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() {
452         check_no_fix(
453             r#"
454 //- minicore: option, result
455 fn foo() -> Result<(), i32> { 0$0 }
456 "#,
457         );
458     }
459
460     #[test]
461     fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() {
462         check_no_fix(
463             r#"
464 //- minicore: option, result
465 enum SomeOtherEnum { Ok(i32), Err(String) }
466
467 fn foo() -> SomeOtherEnum { 0$0 }
468 "#,
469         );
470     }
471
472     #[test]
473     fn remove_semicolon() {
474         check_fix(r#"fn f() -> i32 { 92$0; }"#, r#"fn f() -> i32 { 92 }"#);
475     }
476 }