]> git.lizzy.rs Git - rust.git/blobdiff - crates/hir_ty/src/tests.rs
Merge #11842
[rust.git] / crates / hir_ty / src / tests.rs
index ad283c1e04ba5ca688fe3ac39ace2a75eaf937f8..d2f13e4351c73b17cca725446ddfcbd772988e4f 100644 (file)
@@ -7,34 +7,38 @@
 mod method_resolution;
 mod macros;
 mod display_source_code;
+mod incremental;
+mod diagnostics;
 
-use std::{env, sync::Arc};
+use std::{collections::HashMap, env, sync::Arc};
 
-use base_db::{fixture::WithFixture, FileRange, SourceDatabase, SourceDatabaseExt};
+use base_db::{fixture::WithFixture, FileRange, SourceDatabaseExt};
 use expect_test::Expect;
 use hir_def::{
     body::{Body, BodySourceMap, SyntheticSyntax},
-    child_by_source::ChildBySource,
     db::DefDatabase,
+    expr::{ExprId, PatId},
     item_scope::ItemScope,
-    keys,
     nameres::DefMap,
     src::HasSource,
-    AssocItemId, DefWithBodyId, LocalModuleId, Lookup, ModuleDefId,
+    AssocItemId, DefWithBodyId, HasModule, LocalModuleId, Lookup, ModuleDefId,
 };
 use hir_expand::{db::AstDatabase, InFile};
 use once_cell::race::OnceBool;
 use stdx::format_to;
 use syntax::{
-    algo,
-    ast::{self, AstNode, NameOwner},
+    ast::{self, AstNode, HasName},
     SyntaxNode,
 };
 use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry};
 use tracing_tree::HierarchicalLayer;
 
 use crate::{
-    db::HirDatabase, display::HirDisplay, infer::TypeMismatch, test_db::TestDB, InferenceResult, Ty,
+    db::HirDatabase,
+    display::HirDisplay,
+    infer::{Adjustment, TypeMismatch},
+    test_db::TestDB,
+    InferenceResult, Ty,
 };
 
 // These tests compare the inference results for all expressions in a file
@@ -58,48 +62,223 @@ fn setup_tracing() -> Option<tracing::subscriber::DefaultGuard> {
 }
 
 fn check_types(ra_fixture: &str) {
-    check_types_impl(ra_fixture, false)
+    check_impl(ra_fixture, false, true, false)
 }
 
 fn check_types_source_code(ra_fixture: &str) {
-    check_types_impl(ra_fixture, true)
+    check_impl(ra_fixture, false, true, true)
 }
 
-fn check_types_impl(ra_fixture: &str, display_source: bool) {
+fn check_no_mismatches(ra_fixture: &str) {
+    check_impl(ra_fixture, true, false, false)
+}
+
+fn check(ra_fixture: &str) {
+    check_impl(ra_fixture, false, false, false)
+}
+
+fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_source: bool) {
     let _tracing = setup_tracing();
-    let db = TestDB::with_files(ra_fixture);
-    let mut checked_one = false;
+    let (db, files) = TestDB::with_many_files(ra_fixture);
+
+    let mut had_annotations = false;
+    let mut mismatches = HashMap::new();
+    let mut types = HashMap::new();
+    let mut adjustments = HashMap::<_, Vec<_>>::new();
     for (file_id, annotations) in db.extract_annotations() {
         for (range, expected) in annotations {
-            let ty = type_at_range(&db, FileRange { file_id, range });
-            let actual = if display_source {
-                let module = db.module_for_file(file_id);
-                ty.display_source_code(&db, module).unwrap()
+            let file_range = FileRange { file_id, range };
+            if only_types {
+                types.insert(file_range, expected);
+            } else if expected.starts_with("type: ") {
+                types.insert(file_range, expected.trim_start_matches("type: ").to_string());
+            } else if expected.starts_with("expected") {
+                mismatches.insert(file_range, expected);
+            } else if expected.starts_with("adjustments: ") {
+                adjustments.insert(
+                    file_range,
+                    expected
+                        .trim_start_matches("adjustments: ")
+                        .split(',')
+                        .map(|it| it.trim().to_string())
+                        .filter(|it| !it.is_empty())
+                        .collect(),
+                );
             } else {
-                ty.display_test(&db).to_string()
+                panic!("unexpected annotation: {}", expected);
+            }
+            had_annotations = true;
+        }
+    }
+    assert!(had_annotations || allow_none, "no `//^` annotations found");
+
+    let mut defs: Vec<DefWithBodyId> = Vec::new();
+    for file_id in files {
+        let module = db.module_for_file_opt(file_id);
+        let module = match module {
+            Some(m) => m,
+            None => continue,
+        };
+        let def_map = module.def_map(&db);
+        visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
+    }
+    defs.sort_by_key(|def| match def {
+        DefWithBodyId::FunctionId(it) => {
+            let loc = it.lookup(&db);
+            loc.source(&db).value.syntax().text_range().start()
+        }
+        DefWithBodyId::ConstId(it) => {
+            let loc = it.lookup(&db);
+            loc.source(&db).value.syntax().text_range().start()
+        }
+        DefWithBodyId::StaticId(it) => {
+            let loc = it.lookup(&db);
+            loc.source(&db).value.syntax().text_range().start()
+        }
+    });
+    let mut unexpected_type_mismatches = String::new();
+    for def in defs {
+        let (_body, body_source_map) = db.body_with_source_map(def);
+        let inference_result = db.infer(def);
+
+        for (pat, ty) in inference_result.type_of_pat.iter() {
+            let node = match pat_node(&body_source_map, pat, &db) {
+                Some(value) => value,
+                None => continue,
             };
-            assert_eq!(expected, actual);
-            checked_one = true;
+            let range = node.as_ref().original_file_range(&db);
+            if let Some(expected) = types.remove(&range) {
+                let actual = if display_source {
+                    ty.display_source_code(&db, def.module(&db)).unwrap()
+                } else {
+                    ty.display_test(&db).to_string()
+                };
+                assert_eq!(actual, expected);
+            }
+        }
+
+        for (expr, ty) in inference_result.type_of_expr.iter() {
+            let node = match expr_node(&body_source_map, expr, &db) {
+                Some(value) => value,
+                None => continue,
+            };
+            let range = node.as_ref().original_file_range(&db);
+            if let Some(expected) = types.remove(&range) {
+                let actual = if display_source {
+                    ty.display_source_code(&db, def.module(&db)).unwrap()
+                } else {
+                    ty.display_test(&db).to_string()
+                };
+                assert_eq!(actual, expected);
+            }
+            if let Some(expected) = adjustments.remove(&range) {
+                if let Some(adjustments) = inference_result.expr_adjustments.get(&expr) {
+                    assert_eq!(
+                        expected,
+                        adjustments
+                            .iter()
+                            .map(|Adjustment { kind, .. }| format!("{:?}", kind))
+                            .collect::<Vec<_>>()
+                    );
+                } else {
+                    panic!("expected {:?} adjustments, found none", expected);
+                }
+            }
+        }
+
+        for (pat, mismatch) in inference_result.pat_type_mismatches() {
+            let node = match pat_node(&body_source_map, pat, &db) {
+                Some(value) => value,
+                None => continue,
+            };
+            let range = node.as_ref().original_file_range(&db);
+            let actual = format!(
+                "expected {}, got {}",
+                mismatch.expected.display_test(&db),
+                mismatch.actual.display_test(&db)
+            );
+            match mismatches.remove(&range) {
+                Some(annotation) => assert_eq!(actual, annotation),
+                None => format_to!(unexpected_type_mismatches, "{:?}: {}\n", range.range, actual),
+            }
+        }
+        for (expr, mismatch) in inference_result.expr_type_mismatches() {
+            let node = match body_source_map.expr_syntax(expr) {
+                Ok(sp) => {
+                    let root = db.parse_or_expand(sp.file_id).unwrap();
+                    sp.map(|ptr| ptr.to_node(&root).syntax().clone())
+                }
+                Err(SyntheticSyntax) => continue,
+            };
+            let range = node.as_ref().original_file_range(&db);
+            let actual = format!(
+                "expected {}, got {}",
+                mismatch.expected.display_test(&db),
+                mismatch.actual.display_test(&db)
+            );
+            match mismatches.remove(&range) {
+                Some(annotation) => assert_eq!(actual, annotation),
+                None => format_to!(unexpected_type_mismatches, "{:?}: {}\n", range.range, actual),
+            }
         }
     }
-    assert!(checked_one, "no `//^` annotations found");
-}
 
-fn type_at_range(db: &TestDB, pos: FileRange) -> Ty {
-    let file = db.parse(pos.file_id).ok().unwrap();
-    let expr = algo::find_node_at_range::<ast::Expr>(file.syntax(), pos.range).unwrap();
-    let fn_def = expr.syntax().ancestors().find_map(ast::Fn::cast).unwrap();
-    let module = db.module_for_file(pos.file_id);
-    let func = *module.child_by_source(db)[keys::FUNCTION]
-        .get(&InFile::new(pos.file_id.into(), fn_def))
-        .unwrap();
-
-    let (_body, source_map) = db.body_with_source_map(func.into());
-    if let Some(expr_id) = source_map.node_expr(InFile::new(pos.file_id.into(), &expr)) {
-        let infer = db.infer(func.into());
-        return infer[expr_id].clone();
+    let mut buf = String::new();
+    if !unexpected_type_mismatches.is_empty() {
+        format_to!(buf, "Unexpected type mismatches:\n{}", unexpected_type_mismatches);
+    }
+    if !mismatches.is_empty() {
+        format_to!(buf, "Unchecked mismatch annotations:\n");
+        for m in mismatches {
+            format_to!(buf, "{:?}: {}\n", m.0.range, m.1);
+        }
+    }
+    if !types.is_empty() {
+        format_to!(buf, "Unchecked type annotations:\n");
+        for t in types {
+            format_to!(buf, "{:?}: type {}\n", t.0.range, t.1);
+        }
     }
-    panic!("Can't find expression")
+    if !adjustments.is_empty() {
+        format_to!(buf, "Unchecked adjustments annotations:\n");
+        for t in adjustments {
+            format_to!(buf, "{:?}: type {:?}\n", t.0.range, t.1);
+        }
+    }
+    assert!(buf.is_empty(), "{}", buf);
+}
+
+fn expr_node(
+    body_source_map: &BodySourceMap,
+    expr: ExprId,
+    db: &TestDB,
+) -> Option<InFile<SyntaxNode>> {
+    Some(match body_source_map.expr_syntax(expr) {
+        Ok(sp) => {
+            let root = db.parse_or_expand(sp.file_id).unwrap();
+            sp.map(|ptr| ptr.to_node(&root).syntax().clone())
+        }
+        Err(SyntheticSyntax) => return None,
+    })
+}
+
+fn pat_node(
+    body_source_map: &BodySourceMap,
+    pat: PatId,
+    db: &TestDB,
+) -> Option<InFile<SyntaxNode>> {
+    Some(match body_source_map.pat_syntax(pat) {
+        Ok(sp) => {
+            let root = db.parse_or_expand(sp.file_id).unwrap();
+            sp.map(|ptr| {
+                ptr.either(
+                    |it| it.to_node(&root).syntax().clone(),
+                    |it| it.to_node(&root).syntax().clone(),
+                )
+            })
+        }
+        Err(SyntheticSyntax) => return None,
+    })
 }
 
 fn infer(ra_fixture: &str) -> String {
@@ -130,7 +309,10 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
                 }
                 Err(SyntheticSyntax) => continue,
             };
-            types.push((syntax_ptr, ty));
+            types.push((syntax_ptr.clone(), ty));
+            if let Some(mismatch) = inference_result.type_mismatch_for_pat(pat) {
+                mismatches.push((syntax_ptr, mismatch));
+            }
         }
 
         for (expr, ty) in inference_result.type_of_expr.iter() {
@@ -156,7 +338,7 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
             let (range, text) = if let Some(self_param) = ast::SelfParam::cast(node.value.clone()) {
                 (self_param.name().unwrap().syntax().text_range(), "self".to_string())
             } else {
-                (node.value.text_range(), node.value.text().to_string().replace("\n", " "))
+                (node.value.text_range(), node.value.text().to_string().replace('\n', " "))
             };
             let macro_prefix = if node.file_id != file_id.into() { "!" } else { "" };
             format_to!(
@@ -288,7 +470,7 @@ fn visit_scope(
     }
 
     fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
-        for def_map in body.block_scopes.iter().filter_map(|block| db.block_def_map(*block)) {
+        for (_, def_map) in body.blocks(db) {
             for (mod_id, _) in def_map.modules() {
                 visit_module(db, &def_map, mod_id, cb);
             }
@@ -314,50 +496,6 @@ fn ellipsize(mut text: String, max_len: usize) -> String {
     text
 }
 
-#[test]
-fn typing_whitespace_inside_a_function_should_not_invalidate_types() {
-    let (mut db, pos) = TestDB::with_position(
-        "
-        //- /lib.rs
-        fn foo() -> i32 {
-            $01 + 1
-        }
-    ",
-    );
-    {
-        let events = db.log_executed(|| {
-            let module = db.module_for_file(pos.file_id);
-            let crate_def_map = module.def_map(&db);
-            visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
-                db.infer(def);
-            });
-        });
-        assert!(format!("{:?}", events).contains("infer"))
-    }
-
-    let new_text = "
-        fn foo() -> i32 {
-            1
-            +
-            1
-        }
-    "
-    .to_string();
-
-    db.set_file_text(pos.file_id, Arc::new(new_text));
-
-    {
-        let events = db.log_executed(|| {
-            let module = db.module_for_file(pos.file_id);
-            let crate_def_map = module.def_map(&db);
-            visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
-                db.infer(def);
-            });
-        });
-        assert!(!format!("{:?}", events).contains("infer"), "{:#?}", events)
-    }
-}
-
 fn check_infer(ra_fixture: &str, expect: Expect) {
     let mut actual = infer(ra_fixture);
     actual.push('\n');