]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide/src/runnables.rs
Unnest ide::display::navigation_target module
[rust.git] / crates / ide / src / runnables.rs
index 21130e06075846281c9190a9487685b77168e2fe..3078789d1236fe309b83813b7a3f9fce14f73059 100644 (file)
@@ -2,7 +2,7 @@
 
 use ast::HasName;
 use cfg::CfgExpr;
-use hir::{AsAssocItem, HasAttrs, HasSource, HirDisplay, InFile, Semantics};
+use hir::{AsAssocItem, HasAttrs, HasSource, HirDisplay, Semantics};
 use ide_assists::utils::test_related_attribute;
 use ide_db::{
     base_db::{FilePosition, FileRange},
 use itertools::Itertools;
 use rustc_hash::{FxHashMap, FxHashSet};
 use stdx::{always, format_to};
-use syntax::ast::{self, AstNode, HasAttrs as _};
-
-use crate::{
-    display::{ToNav, TryToNav},
-    references, FileId, NavigationTarget,
+use syntax::{
+    ast::{self, AstNode, HasAttrs as _},
+    SmolStr, SyntaxNode,
 };
 
+use crate::{references, FileId, NavigationTarget, ToNav, TryToNav};
+
 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub struct Runnable {
     pub use_name_in_title: bool,
@@ -31,7 +31,7 @@ pub struct Runnable {
 
 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub enum TestId {
-    Name(String),
+    Name(SmolStr),
     Path(String),
 }
 
@@ -206,68 +206,71 @@ pub(crate) fn related_tests(
 ) -> Vec<Runnable> {
     let sema = Semantics::new(db);
     let mut res: FxHashSet<Runnable> = FxHashSet::default();
+    let syntax = sema.parse(position.file_id).syntax().clone();
 
-    find_related_tests(&sema, position, search_scope, &mut res);
+    find_related_tests(&sema, &syntax, position, search_scope, &mut res);
 
-    res.into_iter().collect_vec()
+    res.into_iter().collect()
 }
 
 fn find_related_tests(
     sema: &Semantics<RootDatabase>,
+    syntax: &SyntaxNode,
     position: FilePosition,
     search_scope: Option<SearchScope>,
     tests: &mut FxHashSet<Runnable>,
 ) {
-    if let Some(refs) = references::find_all_refs(sema, position, search_scope) {
-        for (file_id, refs) in refs.into_iter().flat_map(|refs| refs.references) {
-            let file = sema.parse(file_id);
-            let file = file.syntax();
-
-            // create flattened vec of tokens
-            let tokens = refs.iter().flat_map(|(range, _)| {
-                match file.token_at_offset(range.start()).next() {
-                    Some(token) => sema.descend_into_macros(token),
-                    None => Default::default(),
-                }
-            });
-
-            // find first suitable ancestor
-            let functions = tokens
-                .filter_map(|token| token.ancestors().find_map(ast::Fn::cast))
-                .map(|f| hir::InFile::new(sema.hir_file_for(f.syntax()), f));
-
-            for fn_def in functions {
-                let InFile { value: fn_def, .. } = &fn_def;
-                if let Some(runnable) = as_test_runnable(sema, fn_def) {
+    let defs = references::find_defs(sema, syntax, position.offset);
+    for def in defs {
+        let defs = def
+            .usages(sema)
+            .set_scope(search_scope.clone())
+            .all()
+            .references
+            .into_values()
+            .flatten();
+        for ref_ in defs {
+            let name_ref = match ref_.name {
+                ast::NameLike::NameRef(name_ref) => name_ref,
+                _ => continue,
+            };
+            if let Some(fn_def) =
+                sema.ancestors_with_macros(name_ref.syntax().clone()).find_map(ast::Fn::cast)
+            {
+                if let Some(runnable) = as_test_runnable(sema, &fn_def) {
                     // direct test
                     tests.insert(runnable);
-                } else if let Some(module) = parent_test_module(sema, fn_def) {
+                } else if let Some(module) = parent_test_module(sema, &fn_def) {
                     // indirect test
-                    find_related_tests_in_module(sema, fn_def, &module, tests);
+                    find_related_tests_in_module(sema, syntax, &fn_def, &module, tests);
                 }
             }
         }
     }
 }
+
 fn find_related_tests_in_module(
     sema: &Semantics<RootDatabase>,
+    syntax: &SyntaxNode,
     fn_def: &ast::Fn,
     parent_module: &hir::Module,
     tests: &mut FxHashSet<Runnable>,
 ) {
-    if let Some(fn_name) = fn_def.name() {
-        let mod_source = parent_module.definition_source(sema.db);
-        let range = match mod_source.value {
-            hir::ModuleSource::Module(m) => m.syntax().text_range(),
-            hir::ModuleSource::BlockExpr(b) => b.syntax().text_range(),
-            hir::ModuleSource::SourceFile(f) => f.syntax().text_range(),
-        };
+    let fn_name = match fn_def.name() {
+        Some(it) => it,
+        _ => return,
+    };
+    let mod_source = parent_module.definition_source(sema.db);
+    let range = match &mod_source.value {
+        hir::ModuleSource::Module(m) => m.syntax().text_range(),
+        hir::ModuleSource::BlockExpr(b) => b.syntax().text_range(),
+        hir::ModuleSource::SourceFile(f) => f.syntax().text_range(),
+    };
 
-        let file_id = mod_source.file_id.original_file(sema.db);
-        let mod_scope = SearchScope::file_range(FileRange { file_id, range });
-        let fn_pos = FilePosition { file_id, offset: fn_name.syntax().text_range().start() };
-        find_related_tests(sema, fn_pos, Some(mod_scope), tests)
-    }
+    let file_id = mod_source.file_id.original_file(sema.db);
+    let mod_scope = SearchScope::file_range(FileRange { file_id, range });
+    let fn_pos = FilePosition { file_id, offset: fn_name.syntax().text_range().start() };
+    find_related_tests(sema, syntax, fn_pos, Some(mod_scope), tests)
 }
 
 fn as_test_runnable(sema: &Semantics<RootDatabase>, fn_def: &ast::Fn) -> Option<Runnable> {
@@ -294,24 +297,26 @@ fn parent_test_module(sema: &Semantics<RootDatabase>, fn_def: &ast::Fn) -> Optio
 
 pub(crate) fn runnable_fn(sema: &Semantics<RootDatabase>, def: hir::Function) -> Option<Runnable> {
     let func = def.source(sema.db)?;
-    let name_string = def.name(sema.db).to_string();
+    let name = def.name(sema.db).to_smol_str();
 
     let root = def.module(sema.db).krate().root_module(sema.db);
 
-    let kind = if name_string == "main" && def.module(sema.db) == root {
+    let kind = if name == "main" && def.module(sema.db) == root {
         RunnableKind::Bin
     } else {
-        let canonical_path = {
-            let def: hir::ModuleDef = def.into();
-            def.canonical_path(sema.db)
+        let test_id = || {
+            let canonical_path = {
+                let def: hir::ModuleDef = def.into();
+                def.canonical_path(sema.db)
+            };
+            canonical_path.map(TestId::Path).unwrap_or(TestId::Name(name))
         };
-        let test_id = canonical_path.map(TestId::Path).unwrap_or(TestId::Name(name_string));
 
         if test_related_attribute(&func.value).is_some() {
             let attr = TestAttr::from_fn(&func.value);
-            RunnableKind::Test { test_id, attr }
+            RunnableKind::Test { test_id: test_id(), attr }
         } else if func.value.has_atom_attr("bench") {
-            RunnableKind::Bench { test_id }
+            RunnableKind::Bench { test_id: test_id() }
         } else {
             return None;
         }
@@ -430,7 +435,7 @@ fn module_def_doctest(db: &RootDatabase, def: Definition) -> Option<Runnable> {
         Some(path)
     })();
 
-    let test_id = path.map_or_else(|| TestId::Name(def_name.to_string()), TestId::Path);
+    let test_id = path.map_or_else(|| TestId::Name(def_name.to_smol_str()), TestId::Path);
 
     let mut nav = match def {
         Definition::Module(def) => NavigationTarget::from_module_to_decl(db, def),