]> git.lizzy.rs Git - rust.git/commitdiff
refactor: use hir to test if a value is returned
authorCôme ALLART <come.allart@etu.emse.fr>
Sat, 11 Dec 2021 19:52:14 +0000 (20:52 +0100)
committerCôme ALLART <come.allart@etu.emse.fr>
Sat, 11 Dec 2021 19:52:14 +0000 (20:52 +0100)
crates/ide_assists/src/handlers/generate_documentation_template.rs

index a2fe3463b6faf5042f44a46358fe5819f87d2780..36b3321f91d33dfc8bf4ce7a49de52706a058cd0 100644 (file)
@@ -159,7 +159,7 @@ fn safety_builder(ast_func: &ast::Fn) -> Option<Vec<String>> {
 fn gen_ex_template(ast_func: &ast::Fn, ctx: &AssistContext) -> Option<Vec<String>> {
     let (mut lines, ex_helper) = gen_ex_start_helper(ast_func, ctx)?;
     // Call the function, check result
-    if returns_a_value(ast_func) {
+    if returns_a_value(ast_func, ctx) {
         if count_parameters(&ex_helper.param_list) < 3 {
             lines.push(format!("assert_eq!({}, );", ex_helper.function_call));
         } else {
@@ -183,7 +183,7 @@ fn gen_ex_template(ast_func: &ast::Fn, ctx: &AssistContext) -> Option<Vec<String
 /// `None` if the function has a `self` parameter but is not in an `impl`.
 fn gen_panic_ex_template(ast_func: &ast::Fn, ctx: &AssistContext) -> Option<Vec<String>> {
     let (mut lines, ex_helper) = gen_ex_start_helper(ast_func, ctx)?;
-    match returns_a_value(ast_func) {
+    match returns_a_value(ast_func, ctx) {
         true => lines.push(format!("let _ = {}; // panics", ex_helper.function_call)),
         false => lines.push(format!("{}; // panics", ex_helper.function_call)),
     }
@@ -424,11 +424,12 @@ fn return_type(ast_func: &ast::Fn) -> Option<ast::Type> {
 }
 
 /// Helper function to determine if the function returns some data
-fn returns_a_value(ast_func: &ast::Fn) -> bool {
-    match return_type(ast_func) {
-        Some(ret_type) => !["()", "!"].contains(&ret_type.to_string().as_str()),
-        None => false,
-    }
+fn returns_a_value(ast_func: &ast::Fn, ctx: &AssistContext) -> bool {
+    ctx.sema
+        .to_def(ast_func)
+        .map(|hir_func| hir_func.ret_type(ctx.db()))
+        .map(|ret_ty| !ret_ty.is_unit() && !ret_ty.is_never())
+        .unwrap_or(false)
 }
 
 #[cfg(test)]