]> git.lizzy.rs Git - rust.git/blobdiff - src/tools/rust-analyzer/crates/hir-ty/src/tests.rs
:arrow_up: rust-analyzer
[rust.git] / src / tools / rust-analyzer / crates / hir-ty / src / tests.rs
index d2f13e4351c73b17cca725446ddfcbd772988e4f..ebbc5410147c6b8d63e8b4df5623fba06db5b4d5 100644 (file)
@@ -16,7 +16,7 @@
 use expect_test::Expect;
 use hir_def::{
     body::{Body, BodySourceMap, SyntheticSyntax},
-    db::DefDatabase,
+    db::{DefDatabase, InternDatabase},
     expr::{ExprId, PatId},
     item_scope::ItemScope,
     nameres::DefMap,
@@ -135,6 +135,10 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour
             let loc = it.lookup(&db);
             loc.source(&db).value.syntax().text_range().start()
         }
+        DefWithBodyId::VariantId(it) => {
+            let loc = db.lookup_intern_enum(it.parent);
+            loc.source(&db).value.syntax().text_range().start()
+        }
     });
     let mut unexpected_type_mismatches = String::new();
     for def in defs {
@@ -388,6 +392,10 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
             let loc = it.lookup(&db);
             loc.source(&db).value.syntax().text_range().start()
         }
+        DefWithBodyId::VariantId(it) => {
+            let loc = db.lookup_intern_enum(it.parent);
+            loc.source(&db).value.syntax().text_range().start()
+        }
     });
     for def in defs {
         let (_body, source_map) = db.body_with_source_map(def);
@@ -453,6 +461,18 @@ fn visit_scope(
                     let body = db.body(def);
                     visit_body(db, &body, cb);
                 }
+                ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => {
+                    db.enum_data(it)
+                        .variants
+                        .iter()
+                        .map(|(id, _)| hir_def::EnumVariantId { parent: it, local_id: id })
+                        .for_each(|it| {
+                            let def = it.into();
+                            cb(def);
+                            let body = db.body(def);
+                            visit_body(db, &body, cb);
+                        });
+                }
                 ModuleDefId::TraitId(it) => {
                     let trait_data = db.trait_data(it);
                     for &(_, item) in trait_data.items.iter() {