]> git.lizzy.rs Git - rust.git/blob - crates/hir_ty/src/tests.rs
More cleanups, use `check` for `display_source_code` tests
[rust.git] / crates / hir_ty / src / tests.rs
1 mod never_type;
2 mod coercion;
3 mod regression;
4 mod simple;
5 mod patterns;
6 mod traits;
7 mod method_resolution;
8 mod macros;
9 mod display_source_code;
10 mod incremental;
11
12 use std::{collections::HashMap, env, sync::Arc};
13
14 use base_db::{fixture::WithFixture, FileRange, SourceDatabaseExt};
15 use expect_test::Expect;
16 use hir_def::{
17     body::{Body, BodySourceMap, SyntheticSyntax},
18     db::DefDatabase,
19     expr::{ExprId, PatId},
20     item_scope::ItemScope,
21     nameres::DefMap,
22     src::HasSource,
23     AssocItemId, DefWithBodyId, HasModule, LocalModuleId, Lookup, ModuleDefId,
24 };
25 use hir_expand::{db::AstDatabase, InFile};
26 use once_cell::race::OnceBool;
27 use stdx::format_to;
28 use syntax::{
29     ast::{self, AstNode, NameOwner},
30     SyntaxNode,
31 };
32 use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry};
33 use tracing_tree::HierarchicalLayer;
34
35 use crate::{
36     db::HirDatabase, display::HirDisplay, infer::TypeMismatch, test_db::TestDB, InferenceResult, Ty,
37 };
38
39 // These tests compare the inference results for all expressions in a file
40 // against snapshots of the expected results using expect. Use
41 // `env UPDATE_EXPECT=1 cargo test -p hir_ty` to update the snapshots.
42
43 fn setup_tracing() -> Option<tracing::subscriber::DefaultGuard> {
44     static ENABLE: OnceBool = OnceBool::new();
45     if !ENABLE.get_or_init(|| env::var("CHALK_DEBUG").is_ok()) {
46         return None;
47     }
48
49     let filter = EnvFilter::from_env("CHALK_DEBUG");
50     let layer = HierarchicalLayer::default()
51         .with_indent_lines(true)
52         .with_ansi(false)
53         .with_indent_amount(2)
54         .with_writer(std::io::stderr);
55     let subscriber = Registry::default().with(filter).with(layer);
56     Some(tracing::subscriber::set_default(subscriber))
57 }
58
59 fn check_types(ra_fixture: &str) {
60     check_impl(ra_fixture, false, true, false)
61 }
62
63 fn check_types_source_code(ra_fixture: &str) {
64     check_impl(ra_fixture, false, true, true)
65 }
66
67 fn check_no_mismatches(ra_fixture: &str) {
68     check_impl(ra_fixture, true, false, false)
69 }
70
71 fn check(ra_fixture: &str) {
72     check_impl(ra_fixture, false, false, false)
73 }
74
75 fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_source: bool) {
76     let _tracing = setup_tracing();
77     let (db, files) = TestDB::with_many_files(ra_fixture);
78
79     let mut had_annotations = false;
80     let mut mismatches = HashMap::new();
81     let mut types = HashMap::new();
82     for (file_id, annotations) in db.extract_annotations() {
83         for (range, expected) in annotations {
84             let file_range = FileRange { file_id, range };
85             if only_types {
86                 types.insert(file_range, expected);
87             } else if expected.starts_with("type: ") {
88                 types.insert(file_range, expected.trim_start_matches("type: ").to_string());
89             } else if expected.starts_with("expected") {
90                 mismatches.insert(file_range, expected);
91             } else {
92                 panic!("unexpected annotation: {}", expected);
93             }
94             had_annotations = true;
95         }
96     }
97     assert!(had_annotations || allow_none, "no `//^` annotations found");
98
99     let mut defs: Vec<DefWithBodyId> = Vec::new();
100     for file_id in files {
101         let module = db.module_for_file_opt(file_id);
102         let module = match module {
103             Some(m) => m,
104             None => continue,
105         };
106         let def_map = module.def_map(&db);
107         visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
108     }
109     defs.sort_by_key(|def| match def {
110         DefWithBodyId::FunctionId(it) => {
111             let loc = it.lookup(&db);
112             loc.source(&db).value.syntax().text_range().start()
113         }
114         DefWithBodyId::ConstId(it) => {
115             let loc = it.lookup(&db);
116             loc.source(&db).value.syntax().text_range().start()
117         }
118         DefWithBodyId::StaticId(it) => {
119             let loc = it.lookup(&db);
120             loc.source(&db).value.syntax().text_range().start()
121         }
122     });
123     let mut unexpected_type_mismatches = String::new();
124     for def in defs {
125         let (_body, body_source_map) = db.body_with_source_map(def);
126         let inference_result = db.infer(def);
127
128         for (pat, ty) in inference_result.type_of_pat.iter() {
129             let node = match pat_node(&body_source_map, pat, &db) {
130                 Some(value) => value,
131                 None => continue,
132             };
133             let range = node.as_ref().original_file_range(&db);
134             if let Some(expected) = types.remove(&range) {
135                 let actual = if display_source {
136                     ty.display_source_code(&db, def.module(&db)).unwrap()
137                 } else {
138                     ty.display_test(&db).to_string()
139                 };
140                 assert_eq!(actual, expected);
141             }
142         }
143
144         for (expr, ty) in inference_result.type_of_expr.iter() {
145             let node = match expr_node(&body_source_map, expr, &db) {
146                 Some(value) => value,
147                 None => continue,
148             };
149             let range = node.as_ref().original_file_range(&db);
150             if let Some(expected) = types.remove(&range) {
151                 let actual = if display_source {
152                     ty.display_source_code(&db, def.module(&db)).unwrap()
153                 } else {
154                     ty.display_test(&db).to_string()
155                 };
156                 assert_eq!(actual, expected);
157             }
158         }
159
160         for (pat, mismatch) in inference_result.pat_type_mismatches() {
161             let node = match pat_node(&body_source_map, pat, &db) {
162                 Some(value) => value,
163                 None => continue,
164             };
165             let range = node.as_ref().original_file_range(&db);
166             let actual = format!(
167                 "expected {}, got {}",
168                 mismatch.expected.display_test(&db),
169                 mismatch.actual.display_test(&db)
170             );
171             if let Some(annotation) = mismatches.remove(&range) {
172                 assert_eq!(actual, annotation);
173             } else {
174                 format_to!(unexpected_type_mismatches, "{:?}: {}\n", range.range, actual);
175             }
176         }
177         for (expr, mismatch) in inference_result.expr_type_mismatches() {
178             let node = match body_source_map.expr_syntax(expr) {
179                 Ok(sp) => {
180                     let root = db.parse_or_expand(sp.file_id).unwrap();
181                     sp.map(|ptr| ptr.to_node(&root).syntax().clone())
182                 }
183                 Err(SyntheticSyntax) => continue,
184             };
185             let range = node.as_ref().original_file_range(&db);
186             let actual = format!(
187                 "expected {}, got {}",
188                 mismatch.expected.display_test(&db),
189                 mismatch.actual.display_test(&db)
190             );
191             if let Some(annotation) = mismatches.remove(&range) {
192                 assert_eq!(actual, annotation);
193             } else {
194                 format_to!(unexpected_type_mismatches, "{:?}: {}\n", range.range, actual);
195             }
196         }
197     }
198
199     let mut buf = String::new();
200     if !unexpected_type_mismatches.is_empty() {
201         format_to!(buf, "Unexpected type mismatches:\n{}", unexpected_type_mismatches);
202     }
203     if !mismatches.is_empty() {
204         format_to!(buf, "Unchecked mismatch annotations:\n");
205         for m in mismatches {
206             format_to!(buf, "{:?}: {}\n", m.0.range, m.1);
207         }
208     }
209     if !types.is_empty() {
210         format_to!(buf, "Unchecked type annotations:\n");
211         for t in types {
212             format_to!(buf, "{:?}: type {}\n", t.0.range, t.1);
213         }
214     }
215     assert!(buf.is_empty(), "{}", buf);
216 }
217
218 fn expr_node(
219     body_source_map: &BodySourceMap,
220     expr: ExprId,
221     db: &TestDB,
222 ) -> Option<InFile<SyntaxNode>> {
223     Some(match body_source_map.expr_syntax(expr) {
224         Ok(sp) => {
225             let root = db.parse_or_expand(sp.file_id).unwrap();
226             sp.map(|ptr| ptr.to_node(&root).syntax().clone())
227         }
228         Err(SyntheticSyntax) => return None,
229     })
230 }
231
232 fn pat_node(
233     body_source_map: &BodySourceMap,
234     pat: PatId,
235     db: &TestDB,
236 ) -> Option<InFile<SyntaxNode>> {
237     Some(match body_source_map.pat_syntax(pat) {
238         Ok(sp) => {
239             let root = db.parse_or_expand(sp.file_id).unwrap();
240             sp.map(|ptr| {
241                 ptr.either(
242                     |it| it.to_node(&root).syntax().clone(),
243                     |it| it.to_node(&root).syntax().clone(),
244                 )
245             })
246         }
247         Err(SyntheticSyntax) => return None,
248     })
249 }
250
251 fn infer(ra_fixture: &str) -> String {
252     infer_with_mismatches(ra_fixture, false)
253 }
254
255 fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
256     let _tracing = setup_tracing();
257     let (db, file_id) = TestDB::with_single_file(content);
258
259     let mut buf = String::new();
260
261     let mut infer_def = |inference_result: Arc<InferenceResult>,
262                          body_source_map: Arc<BodySourceMap>| {
263         let mut types: Vec<(InFile<SyntaxNode>, &Ty)> = Vec::new();
264         let mut mismatches: Vec<(InFile<SyntaxNode>, &TypeMismatch)> = Vec::new();
265
266         for (pat, ty) in inference_result.type_of_pat.iter() {
267             let syntax_ptr = match body_source_map.pat_syntax(pat) {
268                 Ok(sp) => {
269                     let root = db.parse_or_expand(sp.file_id).unwrap();
270                     sp.map(|ptr| {
271                         ptr.either(
272                             |it| it.to_node(&root).syntax().clone(),
273                             |it| it.to_node(&root).syntax().clone(),
274                         )
275                     })
276                 }
277                 Err(SyntheticSyntax) => continue,
278             };
279             types.push((syntax_ptr.clone(), ty));
280             if let Some(mismatch) = inference_result.type_mismatch_for_pat(pat) {
281                 mismatches.push((syntax_ptr, mismatch));
282             }
283         }
284
285         for (expr, ty) in inference_result.type_of_expr.iter() {
286             let node = match body_source_map.expr_syntax(expr) {
287                 Ok(sp) => {
288                     let root = db.parse_or_expand(sp.file_id).unwrap();
289                     sp.map(|ptr| ptr.to_node(&root).syntax().clone())
290                 }
291                 Err(SyntheticSyntax) => continue,
292             };
293             types.push((node.clone(), ty));
294             if let Some(mismatch) = inference_result.type_mismatch_for_expr(expr) {
295                 mismatches.push((node, mismatch));
296             }
297         }
298
299         // sort ranges for consistency
300         types.sort_by_key(|(node, _)| {
301             let range = node.value.text_range();
302             (range.start(), range.end())
303         });
304         for (node, ty) in &types {
305             let (range, text) = if let Some(self_param) = ast::SelfParam::cast(node.value.clone()) {
306                 (self_param.name().unwrap().syntax().text_range(), "self".to_string())
307             } else {
308                 (node.value.text_range(), node.value.text().to_string().replace("\n", " "))
309             };
310             let macro_prefix = if node.file_id != file_id.into() { "!" } else { "" };
311             format_to!(
312                 buf,
313                 "{}{:?} '{}': {}\n",
314                 macro_prefix,
315                 range,
316                 ellipsize(text, 15),
317                 ty.display_test(&db)
318             );
319         }
320         if include_mismatches {
321             mismatches.sort_by_key(|(node, _)| {
322                 let range = node.value.text_range();
323                 (range.start(), range.end())
324             });
325             for (src_ptr, mismatch) in &mismatches {
326                 let range = src_ptr.value.text_range();
327                 let macro_prefix = if src_ptr.file_id != file_id.into() { "!" } else { "" };
328                 format_to!(
329                     buf,
330                     "{}{:?}: expected {}, got {}\n",
331                     macro_prefix,
332                     range,
333                     mismatch.expected.display_test(&db),
334                     mismatch.actual.display_test(&db),
335                 );
336             }
337         }
338     };
339
340     let module = db.module_for_file(file_id);
341     let def_map = module.def_map(&db);
342
343     let mut defs: Vec<DefWithBodyId> = Vec::new();
344     visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
345     defs.sort_by_key(|def| match def {
346         DefWithBodyId::FunctionId(it) => {
347             let loc = it.lookup(&db);
348             loc.source(&db).value.syntax().text_range().start()
349         }
350         DefWithBodyId::ConstId(it) => {
351             let loc = it.lookup(&db);
352             loc.source(&db).value.syntax().text_range().start()
353         }
354         DefWithBodyId::StaticId(it) => {
355             let loc = it.lookup(&db);
356             loc.source(&db).value.syntax().text_range().start()
357         }
358     });
359     for def in defs {
360         let (_body, source_map) = db.body_with_source_map(def);
361         let infer = db.infer(def);
362         infer_def(infer, source_map);
363     }
364
365     buf.truncate(buf.trim_end().len());
366     buf
367 }
368
369 fn visit_module(
370     db: &TestDB,
371     crate_def_map: &DefMap,
372     module_id: LocalModuleId,
373     cb: &mut dyn FnMut(DefWithBodyId),
374 ) {
375     visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb);
376     for impl_id in crate_def_map[module_id].scope.impls() {
377         let impl_data = db.impl_data(impl_id);
378         for &item in impl_data.items.iter() {
379             match item {
380                 AssocItemId::FunctionId(it) => {
381                     let def = it.into();
382                     cb(def);
383                     let body = db.body(def);
384                     visit_body(db, &body, cb);
385                 }
386                 AssocItemId::ConstId(it) => {
387                     let def = it.into();
388                     cb(def);
389                     let body = db.body(def);
390                     visit_body(db, &body, cb);
391                 }
392                 AssocItemId::TypeAliasId(_) => (),
393             }
394         }
395     }
396
397     fn visit_scope(
398         db: &TestDB,
399         crate_def_map: &DefMap,
400         scope: &ItemScope,
401         cb: &mut dyn FnMut(DefWithBodyId),
402     ) {
403         for decl in scope.declarations() {
404             match decl {
405                 ModuleDefId::FunctionId(it) => {
406                     let def = it.into();
407                     cb(def);
408                     let body = db.body(def);
409                     visit_body(db, &body, cb);
410                 }
411                 ModuleDefId::ConstId(it) => {
412                     let def = it.into();
413                     cb(def);
414                     let body = db.body(def);
415                     visit_body(db, &body, cb);
416                 }
417                 ModuleDefId::StaticId(it) => {
418                     let def = it.into();
419                     cb(def);
420                     let body = db.body(def);
421                     visit_body(db, &body, cb);
422                 }
423                 ModuleDefId::TraitId(it) => {
424                     let trait_data = db.trait_data(it);
425                     for &(_, item) in trait_data.items.iter() {
426                         match item {
427                             AssocItemId::FunctionId(it) => cb(it.into()),
428                             AssocItemId::ConstId(it) => cb(it.into()),
429                             AssocItemId::TypeAliasId(_) => (),
430                         }
431                     }
432                 }
433                 ModuleDefId::ModuleId(it) => visit_module(db, crate_def_map, it.local_id, cb),
434                 _ => (),
435             }
436         }
437     }
438
439     fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
440         for (_, def_map) in body.blocks(db) {
441             for (mod_id, _) in def_map.modules() {
442                 visit_module(db, &def_map, mod_id, cb);
443             }
444         }
445     }
446 }
447
448 fn ellipsize(mut text: String, max_len: usize) -> String {
449     if text.len() <= max_len {
450         return text;
451     }
452     let ellipsis = "...";
453     let e_len = ellipsis.len();
454     let mut prefix_len = (max_len - e_len) / 2;
455     while !text.is_char_boundary(prefix_len) {
456         prefix_len += 1;
457     }
458     let mut suffix_len = max_len - e_len - prefix_len;
459     while !text.is_char_boundary(text.len() - suffix_len) {
460         suffix_len += 1;
461     }
462     text.replace_range(prefix_len..text.len() - suffix_len, ellipsis);
463     text
464 }
465
466 fn check_infer(ra_fixture: &str, expect: Expect) {
467     let mut actual = infer(ra_fixture);
468     actual.push('\n');
469     expect.assert_eq(&actual);
470 }
471
472 fn check_infer_with_mismatches(ra_fixture: &str, expect: Expect) {
473     let mut actual = infer_with_mismatches(ra_fixture, true);
474     actual.push('\n');
475     expect.assert_eq(&actual);
476 }
477
478 #[test]
479 fn salsa_bug() {
480     let (mut db, pos) = TestDB::with_position(
481         "
482         //- /lib.rs
483         trait Index {
484             type Output;
485         }
486
487         type Key<S: UnificationStoreBase> = <S as UnificationStoreBase>::Key;
488
489         pub trait UnificationStoreBase: Index<Output = Key<Self>> {
490             type Key;
491
492             fn len(&self) -> usize;
493         }
494
495         pub trait UnificationStoreMut: UnificationStoreBase {
496             fn push(&mut self, value: Self::Key);
497         }
498
499         fn main() {
500             let x = 1;
501             x.push(1);$0
502         }
503     ",
504     );
505
506     let module = db.module_for_file(pos.file_id);
507     let crate_def_map = module.def_map(&db);
508     visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
509         db.infer(def);
510     });
511
512     let new_text = "
513         //- /lib.rs
514         trait Index {
515             type Output;
516         }
517
518         type Key<S: UnificationStoreBase> = <S as UnificationStoreBase>::Key;
519
520         pub trait UnificationStoreBase: Index<Output = Key<Self>> {
521             type Key;
522
523             fn len(&self) -> usize;
524         }
525
526         pub trait UnificationStoreMut: UnificationStoreBase {
527             fn push(&mut self, value: Self::Key);
528         }
529
530         fn main() {
531
532             let x = 1;
533             x.push(1);
534         }
535     "
536     .to_string();
537
538     db.set_file_text(pos.file_id, Arc::new(new_text));
539
540     let module = db.module_for_file(pos.file_id);
541     let crate_def_map = module.def_map(&db);
542     visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
543         db.infer(def);
544     });
545 }