]> git.lizzy.rs Git - rust.git/blob - crates/hir_ty/src/tests.rs
Merge #9334
[rust.git] / crates / hir_ty / src / tests.rs
1 mod never_type;
2 mod coercion;
3 mod regression;
4 mod simple;
5 mod patterns;
6 mod traits;
7 mod method_resolution;
8 mod macros;
9 mod display_source_code;
10 mod incremental;
11
12 use std::{collections::HashMap, env, sync::Arc};
13
14 use base_db::{fixture::WithFixture, FileRange, SourceDatabase, SourceDatabaseExt};
15 use expect_test::Expect;
16 use hir_def::{
17     body::{Body, BodySourceMap, SyntheticSyntax},
18     child_by_source::ChildBySource,
19     db::DefDatabase,
20     item_scope::ItemScope,
21     keys,
22     nameres::DefMap,
23     src::HasSource,
24     AssocItemId, DefWithBodyId, LocalModuleId, Lookup, ModuleDefId,
25 };
26 use hir_expand::{db::AstDatabase, InFile};
27 use once_cell::race::OnceBool;
28 use stdx::format_to;
29 use syntax::{
30     algo,
31     ast::{self, AstNode, NameOwner},
32     SyntaxNode,
33 };
34 use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry};
35 use tracing_tree::HierarchicalLayer;
36
37 use crate::{
38     db::HirDatabase, display::HirDisplay, infer::TypeMismatch, test_db::TestDB, InferenceResult, Ty,
39 };
40
41 // These tests compare the inference results for all expressions in a file
42 // against snapshots of the expected results using expect. Use
43 // `env UPDATE_EXPECT=1 cargo test -p hir_ty` to update the snapshots.
44
45 fn setup_tracing() -> Option<tracing::subscriber::DefaultGuard> {
46     static ENABLE: OnceBool = OnceBool::new();
47     if !ENABLE.get_or_init(|| env::var("CHALK_DEBUG").is_ok()) {
48         return None;
49     }
50
51     let filter = EnvFilter::from_env("CHALK_DEBUG");
52     let layer = HierarchicalLayer::default()
53         .with_indent_lines(true)
54         .with_ansi(false)
55         .with_indent_amount(2)
56         .with_writer(std::io::stderr);
57     let subscriber = Registry::default().with(filter).with(layer);
58     Some(tracing::subscriber::set_default(subscriber))
59 }
60
61 fn check_types(ra_fixture: &str) {
62     check_types_impl(ra_fixture, false)
63 }
64
65 fn check_types_source_code(ra_fixture: &str) {
66     check_types_impl(ra_fixture, true)
67 }
68
69 fn check_types_impl(ra_fixture: &str, display_source: bool) {
70     let _tracing = setup_tracing();
71     let db = TestDB::with_files(ra_fixture);
72     let mut checked_one = false;
73     for (file_id, annotations) in db.extract_annotations() {
74         for (range, expected) in annotations {
75             let ty = type_at_range(&db, FileRange { file_id, range });
76             let actual = if display_source {
77                 let module = db.module_for_file(file_id);
78                 ty.display_source_code(&db, module).unwrap()
79             } else {
80                 ty.display_test(&db).to_string()
81             };
82             assert_eq!(expected, actual);
83             checked_one = true;
84         }
85     }
86
87     assert!(checked_one, "no `//^` annotations found");
88 }
89
90 fn check_no_mismatches(ra_fixture: &str) {
91     check_mismatches_impl(ra_fixture, true)
92 }
93
94 #[allow(unused)]
95 fn check_mismatches(ra_fixture: &str) {
96     check_mismatches_impl(ra_fixture, false)
97 }
98
99 fn check_mismatches_impl(ra_fixture: &str, allow_none: bool) {
100     let _tracing = setup_tracing();
101     let (db, file_id) = TestDB::with_single_file(ra_fixture);
102     let module = db.module_for_file(file_id);
103     let def_map = module.def_map(&db);
104
105     let mut defs: Vec<DefWithBodyId> = Vec::new();
106     visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
107     defs.sort_by_key(|def| match def {
108         DefWithBodyId::FunctionId(it) => {
109             let loc = it.lookup(&db);
110             loc.source(&db).value.syntax().text_range().start()
111         }
112         DefWithBodyId::ConstId(it) => {
113             let loc = it.lookup(&db);
114             loc.source(&db).value.syntax().text_range().start()
115         }
116         DefWithBodyId::StaticId(it) => {
117             let loc = it.lookup(&db);
118             loc.source(&db).value.syntax().text_range().start()
119         }
120     });
121     let mut mismatches = HashMap::new();
122     let mut push_mismatch = |src_ptr: InFile<SyntaxNode>, mismatch: TypeMismatch| {
123         let range = src_ptr.value.text_range();
124         if src_ptr.file_id.call_node(&db).is_some() {
125             panic!("type mismatch in macro expansion");
126         }
127         let file_range = FileRange { file_id: src_ptr.file_id.original_file(&db), range };
128         let actual = format!(
129             "expected {}, got {}",
130             mismatch.expected.display_test(&db),
131             mismatch.actual.display_test(&db)
132         );
133         mismatches.insert(file_range, actual);
134     };
135     for def in defs {
136         let (_body, body_source_map) = db.body_with_source_map(def);
137         let inference_result = db.infer(def);
138         for (pat, mismatch) in inference_result.pat_type_mismatches() {
139             let syntax_ptr = match body_source_map.pat_syntax(pat) {
140                 Ok(sp) => {
141                     let root = db.parse_or_expand(sp.file_id).unwrap();
142                     sp.map(|ptr| {
143                         ptr.either(
144                             |it| it.to_node(&root).syntax().clone(),
145                             |it| it.to_node(&root).syntax().clone(),
146                         )
147                     })
148                 }
149                 Err(SyntheticSyntax) => continue,
150             };
151             push_mismatch(syntax_ptr, mismatch.clone());
152         }
153         for (expr, mismatch) in inference_result.expr_type_mismatches() {
154             let node = match body_source_map.expr_syntax(expr) {
155                 Ok(sp) => {
156                     let root = db.parse_or_expand(sp.file_id).unwrap();
157                     sp.map(|ptr| ptr.to_node(&root).syntax().clone())
158                 }
159                 Err(SyntheticSyntax) => continue,
160             };
161             push_mismatch(node, mismatch.clone());
162         }
163     }
164     let mut checked_one = false;
165     for (file_id, annotations) in db.extract_annotations() {
166         for (range, expected) in annotations {
167             let file_range = FileRange { file_id, range };
168             if let Some(mismatch) = mismatches.remove(&file_range) {
169                 assert_eq!(mismatch, expected);
170             } else {
171                 assert!(false, "Expected mismatch not encountered: {}\n", expected);
172             }
173             checked_one = true;
174         }
175     }
176     let mut buf = String::new();
177     for (range, mismatch) in mismatches {
178         format_to!(buf, "{:?}: {}\n", range.range, mismatch,);
179     }
180     assert!(buf.is_empty(), "Unexpected type mismatches:\n{}", buf);
181
182     assert!(checked_one || allow_none, "no `//^` annotations found");
183 }
184
185 fn type_at_range(db: &TestDB, pos: FileRange) -> Ty {
186     let file = db.parse(pos.file_id).ok().unwrap();
187     let expr = algo::find_node_at_range::<ast::Expr>(file.syntax(), pos.range).unwrap();
188     let fn_def = expr.syntax().ancestors().find_map(ast::Fn::cast).unwrap();
189     let module = db.module_for_file(pos.file_id);
190     let func = *module.child_by_source(db)[keys::FUNCTION]
191         .get(&InFile::new(pos.file_id.into(), fn_def))
192         .unwrap();
193
194     let (_body, source_map) = db.body_with_source_map(func.into());
195     if let Some(expr_id) = source_map.node_expr(InFile::new(pos.file_id.into(), &expr)) {
196         let infer = db.infer(func.into());
197         return infer[expr_id].clone();
198     }
199     panic!("Can't find expression")
200 }
201
202 fn infer(ra_fixture: &str) -> String {
203     infer_with_mismatches(ra_fixture, false)
204 }
205
206 fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
207     let _tracing = setup_tracing();
208     let (db, file_id) = TestDB::with_single_file(content);
209
210     let mut buf = String::new();
211
212     let mut infer_def = |inference_result: Arc<InferenceResult>,
213                          body_source_map: Arc<BodySourceMap>| {
214         let mut types: Vec<(InFile<SyntaxNode>, &Ty)> = Vec::new();
215         let mut mismatches: Vec<(InFile<SyntaxNode>, &TypeMismatch)> = Vec::new();
216
217         for (pat, ty) in inference_result.type_of_pat.iter() {
218             let syntax_ptr = match body_source_map.pat_syntax(pat) {
219                 Ok(sp) => {
220                     let root = db.parse_or_expand(sp.file_id).unwrap();
221                     sp.map(|ptr| {
222                         ptr.either(
223                             |it| it.to_node(&root).syntax().clone(),
224                             |it| it.to_node(&root).syntax().clone(),
225                         )
226                     })
227                 }
228                 Err(SyntheticSyntax) => continue,
229             };
230             types.push((syntax_ptr.clone(), ty));
231             if let Some(mismatch) = inference_result.type_mismatch_for_pat(pat) {
232                 mismatches.push((syntax_ptr, mismatch));
233             }
234         }
235
236         for (expr, ty) in inference_result.type_of_expr.iter() {
237             let node = match body_source_map.expr_syntax(expr) {
238                 Ok(sp) => {
239                     let root = db.parse_or_expand(sp.file_id).unwrap();
240                     sp.map(|ptr| ptr.to_node(&root).syntax().clone())
241                 }
242                 Err(SyntheticSyntax) => continue,
243             };
244             types.push((node.clone(), ty));
245             if let Some(mismatch) = inference_result.type_mismatch_for_expr(expr) {
246                 mismatches.push((node, mismatch));
247             }
248         }
249
250         // sort ranges for consistency
251         types.sort_by_key(|(node, _)| {
252             let range = node.value.text_range();
253             (range.start(), range.end())
254         });
255         for (node, ty) in &types {
256             let (range, text) = if let Some(self_param) = ast::SelfParam::cast(node.value.clone()) {
257                 (self_param.name().unwrap().syntax().text_range(), "self".to_string())
258             } else {
259                 (node.value.text_range(), node.value.text().to_string().replace("\n", " "))
260             };
261             let macro_prefix = if node.file_id != file_id.into() { "!" } else { "" };
262             format_to!(
263                 buf,
264                 "{}{:?} '{}': {}\n",
265                 macro_prefix,
266                 range,
267                 ellipsize(text, 15),
268                 ty.display_test(&db)
269             );
270         }
271         if include_mismatches {
272             mismatches.sort_by_key(|(node, _)| {
273                 let range = node.value.text_range();
274                 (range.start(), range.end())
275             });
276             for (src_ptr, mismatch) in &mismatches {
277                 let range = src_ptr.value.text_range();
278                 let macro_prefix = if src_ptr.file_id != file_id.into() { "!" } else { "" };
279                 format_to!(
280                     buf,
281                     "{}{:?}: expected {}, got {}\n",
282                     macro_prefix,
283                     range,
284                     mismatch.expected.display_test(&db),
285                     mismatch.actual.display_test(&db),
286                 );
287             }
288         }
289     };
290
291     let module = db.module_for_file(file_id);
292     let def_map = module.def_map(&db);
293
294     let mut defs: Vec<DefWithBodyId> = Vec::new();
295     visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
296     defs.sort_by_key(|def| match def {
297         DefWithBodyId::FunctionId(it) => {
298             let loc = it.lookup(&db);
299             loc.source(&db).value.syntax().text_range().start()
300         }
301         DefWithBodyId::ConstId(it) => {
302             let loc = it.lookup(&db);
303             loc.source(&db).value.syntax().text_range().start()
304         }
305         DefWithBodyId::StaticId(it) => {
306             let loc = it.lookup(&db);
307             loc.source(&db).value.syntax().text_range().start()
308         }
309     });
310     for def in defs {
311         let (_body, source_map) = db.body_with_source_map(def);
312         let infer = db.infer(def);
313         infer_def(infer, source_map);
314     }
315
316     buf.truncate(buf.trim_end().len());
317     buf
318 }
319
320 fn visit_module(
321     db: &TestDB,
322     crate_def_map: &DefMap,
323     module_id: LocalModuleId,
324     cb: &mut dyn FnMut(DefWithBodyId),
325 ) {
326     visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb);
327     for impl_id in crate_def_map[module_id].scope.impls() {
328         let impl_data = db.impl_data(impl_id);
329         for &item in impl_data.items.iter() {
330             match item {
331                 AssocItemId::FunctionId(it) => {
332                     let def = it.into();
333                     cb(def);
334                     let body = db.body(def);
335                     visit_body(db, &body, cb);
336                 }
337                 AssocItemId::ConstId(it) => {
338                     let def = it.into();
339                     cb(def);
340                     let body = db.body(def);
341                     visit_body(db, &body, cb);
342                 }
343                 AssocItemId::TypeAliasId(_) => (),
344             }
345         }
346     }
347
348     fn visit_scope(
349         db: &TestDB,
350         crate_def_map: &DefMap,
351         scope: &ItemScope,
352         cb: &mut dyn FnMut(DefWithBodyId),
353     ) {
354         for decl in scope.declarations() {
355             match decl {
356                 ModuleDefId::FunctionId(it) => {
357                     let def = it.into();
358                     cb(def);
359                     let body = db.body(def);
360                     visit_body(db, &body, cb);
361                 }
362                 ModuleDefId::ConstId(it) => {
363                     let def = it.into();
364                     cb(def);
365                     let body = db.body(def);
366                     visit_body(db, &body, cb);
367                 }
368                 ModuleDefId::StaticId(it) => {
369                     let def = it.into();
370                     cb(def);
371                     let body = db.body(def);
372                     visit_body(db, &body, cb);
373                 }
374                 ModuleDefId::TraitId(it) => {
375                     let trait_data = db.trait_data(it);
376                     for &(_, item) in trait_data.items.iter() {
377                         match item {
378                             AssocItemId::FunctionId(it) => cb(it.into()),
379                             AssocItemId::ConstId(it) => cb(it.into()),
380                             AssocItemId::TypeAliasId(_) => (),
381                         }
382                     }
383                 }
384                 ModuleDefId::ModuleId(it) => visit_module(db, crate_def_map, it.local_id, cb),
385                 _ => (),
386             }
387         }
388     }
389
390     fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
391         for (_, def_map) in body.blocks(db) {
392             for (mod_id, _) in def_map.modules() {
393                 visit_module(db, &def_map, mod_id, cb);
394             }
395         }
396     }
397 }
398
399 fn ellipsize(mut text: String, max_len: usize) -> String {
400     if text.len() <= max_len {
401         return text;
402     }
403     let ellipsis = "...";
404     let e_len = ellipsis.len();
405     let mut prefix_len = (max_len - e_len) / 2;
406     while !text.is_char_boundary(prefix_len) {
407         prefix_len += 1;
408     }
409     let mut suffix_len = max_len - e_len - prefix_len;
410     while !text.is_char_boundary(text.len() - suffix_len) {
411         suffix_len += 1;
412     }
413     text.replace_range(prefix_len..text.len() - suffix_len, ellipsis);
414     text
415 }
416
417 fn check_infer(ra_fixture: &str, expect: Expect) {
418     let mut actual = infer(ra_fixture);
419     actual.push('\n');
420     expect.assert_eq(&actual);
421 }
422
423 fn check_infer_with_mismatches(ra_fixture: &str, expect: Expect) {
424     let mut actual = infer_with_mismatches(ra_fixture, true);
425     actual.push('\n');
426     expect.assert_eq(&actual);
427 }
428
429 #[test]
430 fn salsa_bug() {
431     let (mut db, pos) = TestDB::with_position(
432         "
433         //- /lib.rs
434         trait Index {
435             type Output;
436         }
437
438         type Key<S: UnificationStoreBase> = <S as UnificationStoreBase>::Key;
439
440         pub trait UnificationStoreBase: Index<Output = Key<Self>> {
441             type Key;
442
443             fn len(&self) -> usize;
444         }
445
446         pub trait UnificationStoreMut: UnificationStoreBase {
447             fn push(&mut self, value: Self::Key);
448         }
449
450         fn main() {
451             let x = 1;
452             x.push(1);$0
453         }
454     ",
455     );
456
457     let module = db.module_for_file(pos.file_id);
458     let crate_def_map = module.def_map(&db);
459     visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
460         db.infer(def);
461     });
462
463     let new_text = "
464         //- /lib.rs
465         trait Index {
466             type Output;
467         }
468
469         type Key<S: UnificationStoreBase> = <S as UnificationStoreBase>::Key;
470
471         pub trait UnificationStoreBase: Index<Output = Key<Self>> {
472             type Key;
473
474             fn len(&self) -> usize;
475         }
476
477         pub trait UnificationStoreMut: UnificationStoreBase {
478             fn push(&mut self, value: Self::Key);
479         }
480
481         fn main() {
482
483             let x = 1;
484             x.push(1);
485         }
486     "
487     .to_string();
488
489     db.set_file_text(pos.file_id, Arc::new(new_text));
490
491     let module = db.module_for_file(pos.file_id);
492     let crate_def_map = module.def_map(&db);
493     visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
494         db.infer(def);
495     });
496 }