]> git.lizzy.rs Git - rust.git/blob - crates/hir_def/src/test_db.rs
Merge #8549
[rust.git] / crates / hir_def / src / test_db.rs
1 //! Database used for testing `hir_def`.
2
3 use std::{
4     fmt, panic,
5     sync::{Arc, Mutex},
6 };
7
8 use base_db::{salsa, CrateId, FileId, FileLoader, FileLoaderDelegate, FilePosition, Upcast};
9 use base_db::{AnchoredPath, SourceDatabase};
10 use hir_expand::diagnostics::Diagnostic;
11 use hir_expand::diagnostics::DiagnosticSinkBuilder;
12 use hir_expand::{db::AstDatabase, InFile};
13 use rustc_hash::FxHashMap;
14 use rustc_hash::FxHashSet;
15 use syntax::{algo, ast, AstNode, TextRange, TextSize};
16 use test_utils::extract_annotations;
17
18 use crate::{
19     db::DefDatabase,
20     nameres::{DefMap, ModuleSource},
21     src::HasSource,
22     LocalModuleId, Lookup, ModuleDefId, ModuleId,
23 };
24
25 #[salsa::database(
26     base_db::SourceDatabaseExtStorage,
27     base_db::SourceDatabaseStorage,
28     hir_expand::db::AstDatabaseStorage,
29     crate::db::InternDatabaseStorage,
30     crate::db::DefDatabaseStorage
31 )]
32 #[derive(Default)]
33 pub(crate) struct TestDB {
34     storage: salsa::Storage<TestDB>,
35     events: Mutex<Option<Vec<salsa::Event>>>,
36 }
37
38 impl Upcast<dyn AstDatabase> for TestDB {
39     fn upcast(&self) -> &(dyn AstDatabase + 'static) {
40         &*self
41     }
42 }
43
44 impl Upcast<dyn DefDatabase> for TestDB {
45     fn upcast(&self) -> &(dyn DefDatabase + 'static) {
46         &*self
47     }
48 }
49
50 impl salsa::Database for TestDB {
51     fn salsa_event(&self, event: salsa::Event) {
52         let mut events = self.events.lock().unwrap();
53         if let Some(events) = &mut *events {
54             events.push(event);
55         }
56     }
57 }
58
59 impl fmt::Debug for TestDB {
60     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61         f.debug_struct("TestDB").finish()
62     }
63 }
64
65 impl panic::RefUnwindSafe for TestDB {}
66
67 impl FileLoader for TestDB {
68     fn file_text(&self, file_id: FileId) -> Arc<String> {
69         FileLoaderDelegate(self).file_text(file_id)
70     }
71     fn resolve_path(&self, path: AnchoredPath) -> Option<FileId> {
72         FileLoaderDelegate(self).resolve_path(path)
73     }
74     fn relevant_crates(&self, file_id: FileId) -> Arc<FxHashSet<CrateId>> {
75         FileLoaderDelegate(self).relevant_crates(file_id)
76     }
77 }
78
79 impl TestDB {
80     pub(crate) fn module_for_file(&self, file_id: FileId) -> ModuleId {
81         for &krate in self.relevant_crates(file_id).iter() {
82             let crate_def_map = self.crate_def_map(krate);
83             for (local_id, data) in crate_def_map.modules() {
84                 if data.origin.file_id() == Some(file_id) {
85                     return crate_def_map.module_id(local_id);
86                 }
87             }
88         }
89         panic!("Can't find module for file")
90     }
91
92     pub(crate) fn module_at_position(&self, position: FilePosition) -> ModuleId {
93         let file_module = self.module_for_file(position.file_id);
94         let mut def_map = file_module.def_map(self);
95         let module = self.mod_at_position(&def_map, position);
96
97         def_map = match self.block_at_position(&def_map, position) {
98             Some(it) => it,
99             None => return def_map.module_id(module),
100         };
101         loop {
102             let new_map = self.block_at_position(&def_map, position);
103             match new_map {
104                 Some(new_block) if !Arc::ptr_eq(&new_block, &def_map) => {
105                     def_map = new_block;
106                 }
107                 _ => {
108                     // FIXME: handle `mod` inside block expression
109                     return def_map.module_id(def_map.root());
110                 }
111             }
112         }
113     }
114
115     /// Finds the smallest/innermost module in `def_map` containing `position`.
116     fn mod_at_position(&self, def_map: &DefMap, position: FilePosition) -> LocalModuleId {
117         let mut size = None;
118         let mut res = def_map.root();
119         for (module, data) in def_map.modules() {
120             let src = data.definition_source(self);
121             if src.file_id != position.file_id.into() {
122                 continue;
123             }
124
125             let range = match src.value {
126                 ModuleSource::SourceFile(it) => it.syntax().text_range(),
127                 ModuleSource::Module(it) => it.syntax().text_range(),
128                 ModuleSource::BlockExpr(it) => it.syntax().text_range(),
129             };
130
131             if !range.contains(position.offset) {
132                 continue;
133             }
134
135             let new_size = match size {
136                 None => range.len(),
137                 Some(size) => {
138                     if range.len() < size {
139                         range.len()
140                     } else {
141                         size
142                     }
143                 }
144             };
145
146             if size != Some(new_size) {
147                 cov_mark::hit!(submodule_in_testdb);
148                 size = Some(new_size);
149                 res = module;
150             }
151         }
152
153         res
154     }
155
156     fn block_at_position(&self, def_map: &DefMap, position: FilePosition) -> Option<Arc<DefMap>> {
157         // Find the smallest (innermost) function in `def_map` containing the cursor.
158         let mut size = None;
159         let mut fn_def = None;
160         for (_, module) in def_map.modules() {
161             let file_id = module.definition_source(self).file_id;
162             if file_id != position.file_id.into() {
163                 continue;
164             }
165             for decl in module.scope.declarations() {
166                 if let ModuleDefId::FunctionId(it) = decl {
167                     let range = it.lookup(self).source(self).value.syntax().text_range();
168
169                     if !range.contains(position.offset) {
170                         continue;
171                     }
172
173                     let new_size = match size {
174                         None => range.len(),
175                         Some(size) => {
176                             if range.len() < size {
177                                 range.len()
178                             } else {
179                                 size
180                             }
181                         }
182                     };
183                     if size != Some(new_size) {
184                         size = Some(new_size);
185                         fn_def = Some(it);
186                     }
187                 }
188             }
189         }
190
191         // Find the innermost block expression that has a `DefMap`.
192         let def_with_body = fn_def?.into();
193         let (_, source_map) = self.body_with_source_map(def_with_body);
194         let scopes = self.expr_scopes(def_with_body);
195         let root = self.parse(position.file_id);
196
197         let scope_iter = algo::ancestors_at_offset(&root.syntax_node(), position.offset)
198             .filter_map(|node| {
199                 let block = ast::BlockExpr::cast(node)?;
200                 let expr = ast::Expr::from(block);
201                 let expr_id = source_map.node_expr(InFile::new(position.file_id.into(), &expr))?;
202                 let scope = scopes.scope_for(expr_id).unwrap();
203                 Some(scope)
204             });
205
206         for scope in scope_iter {
207             let containing_blocks =
208                 scopes.scope_chain(Some(scope)).filter_map(|scope| scopes.block(scope));
209
210             for block in containing_blocks {
211                 if let Some(def_map) = self.block_def_map(block) {
212                     return Some(def_map);
213                 }
214             }
215         }
216
217         None
218     }
219
220     pub(crate) fn log(&self, f: impl FnOnce()) -> Vec<salsa::Event> {
221         *self.events.lock().unwrap() = Some(Vec::new());
222         f();
223         self.events.lock().unwrap().take().unwrap()
224     }
225
226     pub(crate) fn log_executed(&self, f: impl FnOnce()) -> Vec<String> {
227         let events = self.log(f);
228         events
229             .into_iter()
230             .filter_map(|e| match e.kind {
231                 // This pretty horrible, but `Debug` is the only way to inspect
232                 // QueryDescriptor at the moment.
233                 salsa::EventKind::WillExecute { database_key } => {
234                     Some(format!("{:?}", database_key.debug(self)))
235                 }
236                 _ => None,
237             })
238             .collect()
239     }
240
241     pub(crate) fn extract_annotations(&self) -> FxHashMap<FileId, Vec<(TextRange, String)>> {
242         let mut files = Vec::new();
243         let crate_graph = self.crate_graph();
244         for krate in crate_graph.iter() {
245             let crate_def_map = self.crate_def_map(krate);
246             for (module_id, _) in crate_def_map.modules() {
247                 let file_id = crate_def_map[module_id].origin.file_id();
248                 files.extend(file_id)
249             }
250         }
251         assert!(!files.is_empty());
252         files
253             .into_iter()
254             .filter_map(|file_id| {
255                 let text = self.file_text(file_id);
256                 let annotations = extract_annotations(&text);
257                 if annotations.is_empty() {
258                     return None;
259                 }
260                 Some((file_id, annotations))
261             })
262             .collect()
263     }
264
265     pub(crate) fn diagnostics<F: FnMut(&dyn Diagnostic)>(&self, mut cb: F) {
266         let crate_graph = self.crate_graph();
267         for krate in crate_graph.iter() {
268             let crate_def_map = self.crate_def_map(krate);
269
270             let mut sink = DiagnosticSinkBuilder::new().build(&mut cb);
271             for (module_id, module) in crate_def_map.modules() {
272                 crate_def_map.add_diagnostics(self, module_id, &mut sink);
273
274                 for decl in module.scope.declarations() {
275                     if let ModuleDefId::FunctionId(it) = decl {
276                         let source_map = self.body_with_source_map(it.into()).1;
277                         source_map.add_diagnostics(self, &mut sink);
278                     }
279                 }
280             }
281         }
282     }
283
284     pub(crate) fn check_diagnostics(&self) {
285         let db: &TestDB = self;
286         let annotations = db.extract_annotations();
287         assert!(!annotations.is_empty());
288
289         let mut actual: FxHashMap<FileId, Vec<(TextRange, String)>> = FxHashMap::default();
290         db.diagnostics(|d| {
291             let src = d.display_source();
292             let root = db.parse_or_expand(src.file_id).unwrap();
293
294             let node = src.map(|ptr| ptr.to_node(&root));
295             let frange = node.as_ref().original_file_range(db);
296
297             let message = d.message();
298             actual.entry(frange.file_id).or_default().push((frange.range, message));
299         });
300
301         for (file_id, diags) in actual.iter_mut() {
302             diags.sort_by_key(|it| it.0.start());
303             let text = db.file_text(*file_id);
304             // For multiline spans, place them on line start
305             for (range, content) in diags {
306                 if text[*range].contains('\n') {
307                     *range = TextRange::new(range.start(), range.start() + TextSize::from(1));
308                     *content = format!("... {}", content);
309                 }
310             }
311         }
312
313         assert_eq!(annotations, actual);
314     }
315
316     pub(crate) fn check_no_diagnostics(&self) {
317         let db: &TestDB = self;
318         let annotations = db.extract_annotations();
319         assert!(annotations.is_empty());
320
321         let mut has_diagnostics = false;
322         db.diagnostics(|_| {
323             has_diagnostics = true;
324         });
325
326         assert!(!has_diagnostics);
327     }
328 }