]> git.lizzy.rs Git - rust.git/blob - crates/hir_ty/src/tests.rs
Merge #9027
[rust.git] / crates / hir_ty / src / tests.rs
1 mod never_type;
2 mod coercion;
3 mod regression;
4 mod simple;
5 mod patterns;
6 mod traits;
7 mod method_resolution;
8 mod macros;
9 mod display_source_code;
10 mod incremental;
11
12 use std::{env, sync::Arc};
13
14 use base_db::{fixture::WithFixture, FileRange, SourceDatabase, SourceDatabaseExt};
15 use expect_test::Expect;
16 use hir_def::{
17     body::{Body, BodySourceMap, SyntheticSyntax},
18     child_by_source::ChildBySource,
19     db::DefDatabase,
20     item_scope::ItemScope,
21     keys,
22     nameres::DefMap,
23     src::HasSource,
24     AssocItemId, DefWithBodyId, LocalModuleId, Lookup, ModuleDefId,
25 };
26 use hir_expand::{db::AstDatabase, InFile};
27 use once_cell::race::OnceBool;
28 use stdx::format_to;
29 use syntax::{
30     algo,
31     ast::{self, AstNode, NameOwner},
32     SyntaxNode,
33 };
34 use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry};
35 use tracing_tree::HierarchicalLayer;
36
37 use crate::{
38     db::HirDatabase, display::HirDisplay, infer::TypeMismatch, test_db::TestDB, InferenceResult, Ty,
39 };
40
41 // These tests compare the inference results for all expressions in a file
42 // against snapshots of the expected results using expect. Use
43 // `env UPDATE_EXPECT=1 cargo test -p hir_ty` to update the snapshots.
44
45 fn setup_tracing() -> Option<tracing::subscriber::DefaultGuard> {
46     static ENABLE: OnceBool = OnceBool::new();
47     if !ENABLE.get_or_init(|| env::var("CHALK_DEBUG").is_ok()) {
48         return None;
49     }
50
51     let filter = EnvFilter::from_env("CHALK_DEBUG");
52     let layer = HierarchicalLayer::default()
53         .with_indent_lines(true)
54         .with_ansi(false)
55         .with_indent_amount(2)
56         .with_writer(std::io::stderr);
57     let subscriber = Registry::default().with(filter).with(layer);
58     Some(tracing::subscriber::set_default(subscriber))
59 }
60
61 fn check_types(ra_fixture: &str) {
62     check_types_impl(ra_fixture, false)
63 }
64
65 fn check_types_source_code(ra_fixture: &str) {
66     check_types_impl(ra_fixture, true)
67 }
68
69 fn check_types_impl(ra_fixture: &str, display_source: bool) {
70     let _tracing = setup_tracing();
71     let db = TestDB::with_files(ra_fixture);
72     let mut checked_one = false;
73     for (file_id, annotations) in db.extract_annotations() {
74         for (range, expected) in annotations {
75             let ty = type_at_range(&db, FileRange { file_id, range });
76             let actual = if display_source {
77                 let module = db.module_for_file(file_id);
78                 ty.display_source_code(&db, module).unwrap()
79             } else {
80                 ty.display_test(&db).to_string()
81             };
82             assert_eq!(expected, actual);
83             checked_one = true;
84         }
85     }
86     assert!(checked_one, "no `//^` annotations found");
87 }
88
89 fn type_at_range(db: &TestDB, pos: FileRange) -> Ty {
90     let file = db.parse(pos.file_id).ok().unwrap();
91     let expr = algo::find_node_at_range::<ast::Expr>(file.syntax(), pos.range).unwrap();
92     let fn_def = expr.syntax().ancestors().find_map(ast::Fn::cast).unwrap();
93     let module = db.module_for_file(pos.file_id);
94     let func = *module.child_by_source(db)[keys::FUNCTION]
95         .get(&InFile::new(pos.file_id.into(), fn_def))
96         .unwrap();
97
98     let (_body, source_map) = db.body_with_source_map(func.into());
99     if let Some(expr_id) = source_map.node_expr(InFile::new(pos.file_id.into(), &expr)) {
100         let infer = db.infer(func.into());
101         return infer[expr_id].clone();
102     }
103     panic!("Can't find expression")
104 }
105
106 fn infer(ra_fixture: &str) -> String {
107     infer_with_mismatches(ra_fixture, false)
108 }
109
110 fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
111     let _tracing = setup_tracing();
112     let (db, file_id) = TestDB::with_single_file(content);
113
114     let mut buf = String::new();
115
116     let mut infer_def = |inference_result: Arc<InferenceResult>,
117                          body_source_map: Arc<BodySourceMap>| {
118         let mut types: Vec<(InFile<SyntaxNode>, &Ty)> = Vec::new();
119         let mut mismatches: Vec<(InFile<SyntaxNode>, &TypeMismatch)> = Vec::new();
120
121         for (pat, ty) in inference_result.type_of_pat.iter() {
122             let syntax_ptr = match body_source_map.pat_syntax(pat) {
123                 Ok(sp) => {
124                     let root = db.parse_or_expand(sp.file_id).unwrap();
125                     sp.map(|ptr| {
126                         ptr.either(
127                             |it| it.to_node(&root).syntax().clone(),
128                             |it| it.to_node(&root).syntax().clone(),
129                         )
130                     })
131                 }
132                 Err(SyntheticSyntax) => continue,
133             };
134             types.push((syntax_ptr.clone(), ty));
135             if let Some(mismatch) = inference_result.type_mismatch_for_pat(pat) {
136                 mismatches.push((syntax_ptr, mismatch));
137             }
138         }
139
140         for (expr, ty) in inference_result.type_of_expr.iter() {
141             let node = match body_source_map.expr_syntax(expr) {
142                 Ok(sp) => {
143                     let root = db.parse_or_expand(sp.file_id).unwrap();
144                     sp.map(|ptr| ptr.to_node(&root).syntax().clone())
145                 }
146                 Err(SyntheticSyntax) => continue,
147             };
148             types.push((node.clone(), ty));
149             if let Some(mismatch) = inference_result.type_mismatch_for_expr(expr) {
150                 mismatches.push((node, mismatch));
151             }
152         }
153
154         // sort ranges for consistency
155         types.sort_by_key(|(node, _)| {
156             let range = node.value.text_range();
157             (range.start(), range.end())
158         });
159         for (node, ty) in &types {
160             let (range, text) = if let Some(self_param) = ast::SelfParam::cast(node.value.clone()) {
161                 (self_param.name().unwrap().syntax().text_range(), "self".to_string())
162             } else {
163                 (node.value.text_range(), node.value.text().to_string().replace("\n", " "))
164             };
165             let macro_prefix = if node.file_id != file_id.into() { "!" } else { "" };
166             format_to!(
167                 buf,
168                 "{}{:?} '{}': {}\n",
169                 macro_prefix,
170                 range,
171                 ellipsize(text, 15),
172                 ty.display_test(&db)
173             );
174         }
175         if include_mismatches {
176             mismatches.sort_by_key(|(node, _)| {
177                 let range = node.value.text_range();
178                 (range.start(), range.end())
179             });
180             for (src_ptr, mismatch) in &mismatches {
181                 let range = src_ptr.value.text_range();
182                 let macro_prefix = if src_ptr.file_id != file_id.into() { "!" } else { "" };
183                 format_to!(
184                     buf,
185                     "{}{:?}: expected {}, got {}\n",
186                     macro_prefix,
187                     range,
188                     mismatch.expected.display_test(&db),
189                     mismatch.actual.display_test(&db),
190                 );
191             }
192         }
193     };
194
195     let module = db.module_for_file(file_id);
196     let def_map = module.def_map(&db);
197
198     let mut defs: Vec<DefWithBodyId> = Vec::new();
199     visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
200     defs.sort_by_key(|def| match def {
201         DefWithBodyId::FunctionId(it) => {
202             let loc = it.lookup(&db);
203             loc.source(&db).value.syntax().text_range().start()
204         }
205         DefWithBodyId::ConstId(it) => {
206             let loc = it.lookup(&db);
207             loc.source(&db).value.syntax().text_range().start()
208         }
209         DefWithBodyId::StaticId(it) => {
210             let loc = it.lookup(&db);
211             loc.source(&db).value.syntax().text_range().start()
212         }
213     });
214     for def in defs {
215         let (_body, source_map) = db.body_with_source_map(def);
216         let infer = db.infer(def);
217         infer_def(infer, source_map);
218     }
219
220     buf.truncate(buf.trim_end().len());
221     buf
222 }
223
224 fn visit_module(
225     db: &TestDB,
226     crate_def_map: &DefMap,
227     module_id: LocalModuleId,
228     cb: &mut dyn FnMut(DefWithBodyId),
229 ) {
230     visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb);
231     for impl_id in crate_def_map[module_id].scope.impls() {
232         let impl_data = db.impl_data(impl_id);
233         for &item in impl_data.items.iter() {
234             match item {
235                 AssocItemId::FunctionId(it) => {
236                     let def = it.into();
237                     cb(def);
238                     let body = db.body(def);
239                     visit_body(db, &body, cb);
240                 }
241                 AssocItemId::ConstId(it) => {
242                     let def = it.into();
243                     cb(def);
244                     let body = db.body(def);
245                     visit_body(db, &body, cb);
246                 }
247                 AssocItemId::TypeAliasId(_) => (),
248             }
249         }
250     }
251
252     fn visit_scope(
253         db: &TestDB,
254         crate_def_map: &DefMap,
255         scope: &ItemScope,
256         cb: &mut dyn FnMut(DefWithBodyId),
257     ) {
258         for decl in scope.declarations() {
259             match decl {
260                 ModuleDefId::FunctionId(it) => {
261                     let def = it.into();
262                     cb(def);
263                     let body = db.body(def);
264                     visit_body(db, &body, cb);
265                 }
266                 ModuleDefId::ConstId(it) => {
267                     let def = it.into();
268                     cb(def);
269                     let body = db.body(def);
270                     visit_body(db, &body, cb);
271                 }
272                 ModuleDefId::StaticId(it) => {
273                     let def = it.into();
274                     cb(def);
275                     let body = db.body(def);
276                     visit_body(db, &body, cb);
277                 }
278                 ModuleDefId::TraitId(it) => {
279                     let trait_data = db.trait_data(it);
280                     for &(_, item) in trait_data.items.iter() {
281                         match item {
282                             AssocItemId::FunctionId(it) => cb(it.into()),
283                             AssocItemId::ConstId(it) => cb(it.into()),
284                             AssocItemId::TypeAliasId(_) => (),
285                         }
286                     }
287                 }
288                 ModuleDefId::ModuleId(it) => visit_module(db, crate_def_map, it.local_id, cb),
289                 _ => (),
290             }
291         }
292     }
293
294     fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
295         for (_, def_map) in body.blocks(db) {
296             for (mod_id, _) in def_map.modules() {
297                 visit_module(db, &def_map, mod_id, cb);
298             }
299         }
300     }
301 }
302
303 fn ellipsize(mut text: String, max_len: usize) -> String {
304     if text.len() <= max_len {
305         return text;
306     }
307     let ellipsis = "...";
308     let e_len = ellipsis.len();
309     let mut prefix_len = (max_len - e_len) / 2;
310     while !text.is_char_boundary(prefix_len) {
311         prefix_len += 1;
312     }
313     let mut suffix_len = max_len - e_len - prefix_len;
314     while !text.is_char_boundary(text.len() - suffix_len) {
315         suffix_len += 1;
316     }
317     text.replace_range(prefix_len..text.len() - suffix_len, ellipsis);
318     text
319 }
320
321 fn check_infer(ra_fixture: &str, expect: Expect) {
322     let mut actual = infer(ra_fixture);
323     actual.push('\n');
324     expect.assert_eq(&actual);
325 }
326
327 fn check_infer_with_mismatches(ra_fixture: &str, expect: Expect) {
328     let mut actual = infer_with_mismatches(ra_fixture, true);
329     actual.push('\n');
330     expect.assert_eq(&actual);
331 }
332
333 #[test]
334 fn salsa_bug() {
335     let (mut db, pos) = TestDB::with_position(
336         "
337         //- /lib.rs
338         trait Index {
339             type Output;
340         }
341
342         type Key<S: UnificationStoreBase> = <S as UnificationStoreBase>::Key;
343
344         pub trait UnificationStoreBase: Index<Output = Key<Self>> {
345             type Key;
346
347             fn len(&self) -> usize;
348         }
349
350         pub trait UnificationStoreMut: UnificationStoreBase {
351             fn push(&mut self, value: Self::Key);
352         }
353
354         fn main() {
355             let x = 1;
356             x.push(1);$0
357         }
358     ",
359     );
360
361     let module = db.module_for_file(pos.file_id);
362     let crate_def_map = module.def_map(&db);
363     visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
364         db.infer(def);
365     });
366
367     let new_text = "
368         //- /lib.rs
369         trait Index {
370             type Output;
371         }
372
373         type Key<S: UnificationStoreBase> = <S as UnificationStoreBase>::Key;
374
375         pub trait UnificationStoreBase: Index<Output = Key<Self>> {
376             type Key;
377
378             fn len(&self) -> usize;
379         }
380
381         pub trait UnificationStoreMut: UnificationStoreBase {
382             fn push(&mut self, value: Self::Key);
383         }
384
385         fn main() {
386
387             let x = 1;
388             x.push(1);
389         }
390     "
391     .to_string();
392
393     db.set_file_text(pos.file_id, Arc::new(new_text));
394
395     let module = db.module_for_file(pos.file_id);
396     let crate_def_map = module.def_map(&db);
397     visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
398         db.infer(def);
399     });
400 }