]> git.lizzy.rs Git - rust.git/commitdiff
fix hir for new block syntax
authorAleksey Kladov <aleksey.kladov@gmail.com>
Mon, 2 Sep 2019 18:23:19 +0000 (21:23 +0300)
committerAleksey Kladov <aleksey.kladov@gmail.com>
Mon, 2 Sep 2019 18:23:19 +0000 (21:23 +0300)
14 files changed:
crates/ra_assists/src/move_guard.rs
crates/ra_assists/src/replace_if_let_with_match.rs
crates/ra_fmt/src/lib.rs
crates/ra_hir/src/code_model/src.rs
crates/ra_hir/src/expr.rs
crates/ra_hir/src/expr/scope.rs
crates/ra_hir/src/expr/validation.rs
crates/ra_hir/src/source_binder.rs
crates/ra_hir/src/ty/tests.rs
crates/ra_ide_api/src/join_lines.rs
crates/ra_syntax/src/ast/expr_extensions.rs
crates/ra_syntax/src/ast/generated.rs
crates/ra_syntax/src/ast/traits.rs
crates/ra_syntax/src/grammar.ron

index 127c9e068273e35bcb7624521b9ae01a75718de3..699221e33567138df6c4734be95df2b57c3d53b4 100644 (file)
@@ -65,9 +65,9 @@ pub(crate) fn move_arm_cond_to_match_guard(mut ctx: AssistCtx<impl HirDatabase>)
         "move condition to match guard",
         |edit| {
             edit.target(if_expr.syntax().text_range());
-            let then_only_expr = then_block.statements().next().is_none();
+            let then_only_expr = then_block.block().and_then(|it| it.statements().next()).is_none();
 
-            match &then_block.expr() {
+            match &then_block.block().and_then(|it| it.expr()) {
                 Some(then_expr) if then_only_expr => {
                     edit.replace(if_expr.syntax().text_range(), then_expr.syntax().text())
                 }
index c0bf6d23512936a4bd1ee123f16e6bbd0135d606..401835c579f4a748ea795c76892b3346bbb8b033 100644 (file)
@@ -1,3 +1,4 @@
+use format_buf::format;
 use hir::db::HirDatabase;
 use ra_fmt::extract_trivial_expression;
 use ra_syntax::{ast, AstNode};
@@ -25,16 +26,21 @@ pub(crate) fn replace_if_let_with_match(mut ctx: AssistCtx<impl HirDatabase>) ->
     ctx.build()
 }
 
-fn build_match_expr(expr: ast::Expr, pat1: ast::Pat, arm1: ast::Block, arm2: ast::Block) -> String {
+fn build_match_expr(
+    expr: ast::Expr,
+    pat1: ast::Pat,
+    arm1: ast::BlockExpr,
+    arm2: ast::BlockExpr,
+) -> String {
     let mut buf = String::new();
-    buf.push_str(&format!("match {} {{\n", expr.syntax().text()));
-    buf.push_str(&format!("    {} => {}\n", pat1.syntax().text(), format_arm(&arm1)));
-    buf.push_str(&format!("    _ => {}\n", format_arm(&arm2)));
+    format!(buf, "match {} {{\n", expr.syntax().text());
+    format!(buf, "    {} => {}\n", pat1.syntax().text(), format_arm(&arm1));
+    format!(buf, "    _ => {}\n", format_arm(&arm2));
     buf.push_str("}");
     buf
 }
 
-fn format_arm(block: &ast::Block) -> String {
+fn format_arm(block: &ast::BlockExpr) -> String {
     match extract_trivial_expression(block) {
         None => block.syntax().text().to_string(),
         Some(e) => format!("{},", e.syntax().text()),
index b09478d7a371e3a783709d575531a4a6ff098f31..e22ac9753f036155317d3ac967efbe5289274939 100644 (file)
@@ -34,7 +34,8 @@ fn prev_tokens(token: SyntaxToken) -> impl Iterator<Item = SyntaxToken> {
     successors(token.prev_token(), |token| token.prev_token())
 }
 
-pub fn extract_trivial_expression(block: &ast::Block) -> Option<ast::Expr> {
+pub fn extract_trivial_expression(expr: &ast::BlockExpr) -> Option<ast::Expr> {
+    let block = expr.block()?;
     let expr = block.expr()?;
     if expr.syntax().text().contains_char('\n') {
         return None;
index e5bae16ab5b88a8c36f4a593460979044cd1552d..7c9454c0b584f4f92738880ce9bc58c3ebe1b8c6 100644 (file)
@@ -119,10 +119,10 @@ fn expr_source(
         expr_id: crate::expr::ExprId,
     ) -> Option<Source<ast::Expr>> {
         let source_map = self.body_source_map(db);
-        let expr_syntax = source_map.expr_syntax(expr_id)?;
+        let expr_syntax = source_map.expr_syntax(expr_id)?.a()?;
         let source = self.source(db);
-        let node = expr_syntax.to_node(&source.ast.syntax());
-        ast::Expr::cast(node).map(|ast| Source { file_id: source.file_id, ast })
+        let ast = expr_syntax.to_node(&source.ast.syntax());
+        Some(Source { file_id: source.file_id, ast })
     }
 }
 
index c7530849b299d2be2b93470163c04d563426d690..5c95bed40c13371207d2ca9c02c639d2592894c8 100644 (file)
@@ -9,7 +9,7 @@
         self, ArgListOwner, ArrayExprKind, LiteralKind, LoopBodyOwner, NameOwner,
         TypeAscriptionOwner,
     },
-    AstNode, AstPtr, SyntaxNodePtr,
+    AstNode, AstPtr,
 };
 use test_utils::tested_by;
 
@@ -56,13 +56,14 @@ pub struct Body {
 /// file, so that we don't recompute types whenever some whitespace is typed.
 #[derive(Default, Debug, Eq, PartialEq)]
 pub struct BodySourceMap {
-    expr_map: FxHashMap<SyntaxNodePtr, ExprId>,
-    expr_map_back: ArenaMap<ExprId, SyntaxNodePtr>,
+    expr_map: FxHashMap<ExprPtr, ExprId>,
+    expr_map_back: ArenaMap<ExprId, ExprPtr>,
     pat_map: FxHashMap<PatPtr, PatId>,
     pat_map_back: ArenaMap<PatId, PatPtr>,
     field_map: FxHashMap<(ExprId, usize), AstPtr<ast::RecordField>>,
 }
 
+type ExprPtr = Either<AstPtr<ast::Expr>, AstPtr<ast::RecordField>>;
 type PatPtr = Either<AstPtr<ast::Pat>, AstPtr<ast::SelfParam>>;
 
 impl Body {
@@ -128,16 +129,12 @@ fn index(&self, pat: PatId) -> &Pat {
 }
 
 impl BodySourceMap {
-    pub(crate) fn expr_syntax(&self, expr: ExprId) -> Option<SyntaxNodePtr> {
+    pub(crate) fn expr_syntax(&self, expr: ExprId) -> Option<ExprPtr> {
         self.expr_map_back.get(expr).cloned()
     }
 
-    pub(crate) fn syntax_expr(&self, ptr: SyntaxNodePtr) -> Option<ExprId> {
-        self.expr_map.get(&ptr).cloned()
-    }
-
     pub(crate) fn node_expr(&self, node: &ast::Expr) -> Option<ExprId> {
-        self.expr_map.get(&SyntaxNodePtr::new(node.syntax())).cloned()
+        self.expr_map.get(&Either::A(AstPtr::new(node))).cloned()
     }
 
     pub(crate) fn pat_syntax(&self, pat: PatId) -> Option<PatPtr> {
@@ -575,11 +572,12 @@ fn new(owner: DefWithBody, file_id: HirFileId, resolver: Resolver, db: &'a DB) -
             current_file_id: file_id,
         }
     }
-    fn alloc_expr(&mut self, expr: Expr, syntax_ptr: SyntaxNodePtr) -> ExprId {
+    fn alloc_expr(&mut self, expr: Expr, ptr: AstPtr<ast::Expr>) -> ExprId {
+        let ptr = Either::A(ptr);
         let id = self.exprs.alloc(expr);
         if self.current_file_id == self.original_file_id {
-            self.source_map.expr_map.insert(syntax_ptr, id);
-            self.source_map.expr_map_back.insert(id, syntax_ptr);
+            self.source_map.expr_map.insert(ptr, id);
+            self.source_map.expr_map_back.insert(id, ptr);
         }
         id
     }
@@ -601,7 +599,7 @@ fn empty_block(&mut self) -> ExprId {
     }
 
     fn collect_expr(&mut self, expr: ast::Expr) -> ExprId {
-        let syntax_ptr = SyntaxNodePtr::new(expr.syntax());
+        let syntax_ptr = AstPtr::new(&expr);
         match expr {
             ast::Expr::IfExpr(e) => {
                 let then_branch = self.collect_block_opt(e.then_branch());
@@ -640,10 +638,10 @@ fn collect_expr(&mut self, expr: ast::Expr) -> ExprId {
                 self.alloc_expr(Expr::If { condition, then_branch, else_branch }, syntax_ptr)
             }
             ast::Expr::TryBlockExpr(e) => {
-                let body = self.collect_block_opt(e.block());
+                let body = self.collect_block_opt(e.body());
                 self.alloc_expr(Expr::TryBlock { body }, syntax_ptr)
             }
-            ast::Expr::BlockExpr(e) => self.collect_block_opt(e.block()),
+            ast::Expr::BlockExpr(e) => self.collect_block(e),
             ast::Expr::LoopExpr(e) => {
                 let body = self.collect_block_opt(e.loop_body());
                 self.alloc_expr(Expr::Loop { body }, syntax_ptr)
@@ -739,7 +737,7 @@ fn collect_expr(&mut self, expr: ast::Expr) -> ExprId {
             ast::Expr::ParenExpr(e) => {
                 let inner = self.collect_expr_opt(e.expr());
                 // make the paren expr point to the inner expression as well
-                self.source_map.expr_map.insert(syntax_ptr, inner);
+                self.source_map.expr_map.insert(Either::A(syntax_ptr), inner);
                 inner
             }
             ast::Expr::ReturnExpr(e) => {
@@ -763,12 +761,9 @@ fn collect_expr(&mut self, expr: ast::Expr) -> ExprId {
                             } else if let Some(nr) = field.name_ref() {
                                 // field shorthand
                                 let id = self.exprs.alloc(Expr::Path(Path::from_name_ref(&nr)));
-                                self.source_map
-                                    .expr_map
-                                    .insert(SyntaxNodePtr::new(nr.syntax()), id);
-                                self.source_map
-                                    .expr_map_back
-                                    .insert(id, SyntaxNodePtr::new(nr.syntax()));
+                                let ptr = Either::B(AstPtr::new(&field));
+                                self.source_map.expr_map.insert(ptr, id);
+                                self.source_map.expr_map_back.insert(id, ptr);
                                 id
                             } else {
                                 self.exprs.alloc(Expr::Missing)
@@ -942,7 +937,12 @@ fn collect_expr_opt(&mut self, expr: Option<ast::Expr>) -> ExprId {
         }
     }
 
-    fn collect_block(&mut self, block: ast::Block) -> ExprId {
+    fn collect_block(&mut self, expr: ast::BlockExpr) -> ExprId {
+        let syntax_node_ptr = AstPtr::new(&expr.clone().into());
+        let block = match expr.block() {
+            Some(block) => block,
+            None => return self.alloc_expr(Expr::Missing, syntax_node_ptr),
+        };
         let statements = block
             .statements()
             .map(|s| match s {
@@ -956,11 +956,11 @@ fn collect_block(&mut self, block: ast::Block) -> ExprId {
             })
             .collect();
         let tail = block.expr().map(|e| self.collect_expr(e));
-        self.alloc_expr(Expr::Block { statements, tail }, SyntaxNodePtr::new(block.syntax()))
+        self.alloc_expr(Expr::Block { statements, tail }, syntax_node_ptr)
     }
 
-    fn collect_block_opt(&mut self, block: Option<ast::Block>) -> ExprId {
-        if let Some(block) = block {
+    fn collect_block_opt(&mut self, expr: Option<ast::BlockExpr>) -> ExprId {
+        if let Some(block) = expr {
             self.collect_block(block)
         } else {
             self.exprs.alloc(Expr::Missing)
index 79e1857f9365185f6bac5395e34cb42b67605e99..b6d7f3fc14baf87b34e2a6086ff236e56d18e40f 100644 (file)
@@ -172,7 +172,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope
 #[cfg(test)]
 mod tests {
     use ra_db::SourceDatabase;
-    use ra_syntax::{algo::find_node_at_offset, ast, AstNode, SyntaxNodePtr};
+    use ra_syntax::{algo::find_node_at_offset, ast, AstNode};
     use test_utils::{assert_eq_text, extract_offset};
 
     use crate::{mock::MockDatabase, source_binder::SourceAnalyzer};
@@ -194,8 +194,7 @@ fn do_check(code: &str, expected: &[&str]) {
         let analyzer = SourceAnalyzer::new(&db, file_id, marker.syntax(), None);
 
         let scopes = analyzer.scopes();
-        let expr_id =
-            analyzer.body_source_map().syntax_expr(SyntaxNodePtr::new(marker.syntax())).unwrap();
+        let expr_id = analyzer.body_source_map().node_expr(&marker.into()).unwrap();
         let scope = scopes.scope_for(expr_id);
 
         let actual = scopes
index c8ae198696e2b5d1481046e69956d3547a2c449f..6fdaf1fce3bedcd72d50fbda3bbccbfca0dea8fc 100644 (file)
@@ -1,7 +1,7 @@
-use rustc_hash::FxHashSet;
 use std::sync::Arc;
 
-use ra_syntax::ast::{AstNode, RecordLit};
+use ra_syntax::ast::{self, AstNode};
+use rustc_hash::FxHashSet;
 
 use super::{Expr, ExprId, RecordLitField};
 use crate::{
@@ -13,7 +13,6 @@
     ty::{ApplicationTy, InferenceResult, Ty, TypeCtor},
     Function, HasSource, HirDatabase, ModuleDef, Name, Path, PerNs, Resolution,
 };
-use ra_syntax::ast;
 
 pub(crate) struct ExprValidator<'a, 'b: 'a> {
     func: Function,
@@ -84,8 +83,12 @@ fn validate_record_literal(
         let source_file = parse.tree();
         if let Some(field_list_node) = source_map
             .expr_syntax(id)
+            .and_then(|ptr| ptr.a())
             .map(|ptr| ptr.to_node(source_file.syntax()))
-            .and_then(RecordLit::cast)
+            .and_then(|expr| match expr {
+                ast::Expr::RecordLit(it) => Some(it),
+                _ => None,
+            })
             .and_then(|lit| lit.record_field_list())
         {
             let field_list_ptr = AstPtr::new(&field_list_node);
@@ -135,7 +138,7 @@ fn validate_results_in_tail_expr(
             let source_map = self.func.body_source_map(db);
             let file_id = self.func.source(db).file_id;
 
-            if let Some(expr) = source_map.expr_syntax(id).and_then(|n| n.cast::<ast::Expr>()) {
+            if let Some(expr) = source_map.expr_syntax(id).and_then(|n| n.a()) {
                 self.sink.push(MissingOkInTailExpr { file: file_id, expr });
             }
         }
index 43aec201a7b28bd92df542c087fa1bb0c9eec46e..e5f4d11a6413a39770291f70794254f87b1875be 100644 (file)
@@ -462,8 +462,8 @@ fn scope_for(
     node: &SyntaxNode,
 ) -> Option<ScopeId> {
     node.ancestors()
-        .map(|it| SyntaxNodePtr::new(&it))
-        .filter_map(|ptr| source_map.syntax_expr(ptr))
+        .filter_map(ast::Expr::cast)
+        .filter_map(|it| source_map.node_expr(&it))
         .find_map(|it| scopes.scope_for(it))
 }
 
@@ -475,7 +475,10 @@ fn scope_for_offset(
     scopes
         .scope_by_expr()
         .iter()
-        .filter_map(|(id, scope)| Some((source_map.expr_syntax(*id)?, scope)))
+        .filter_map(|(id, scope)| {
+            let ast_ptr = source_map.expr_syntax(*id)?.a()?;
+            Some((ast_ptr.syntax_node_ptr(), scope))
+        })
         // find containing scope
         .min_by_key(|(ptr, _scope)| {
             (!(ptr.range().start() <= offset && offset <= ptr.range().end()), ptr.range().len())
@@ -495,7 +498,10 @@ fn adjust(
     let child_scopes = scopes
         .scope_by_expr()
         .iter()
-        .filter_map(|(id, scope)| Some((source_map.expr_syntax(*id)?, scope)))
+        .filter_map(|(id, scope)| {
+            let ast_ptr = source_map.expr_syntax(*id)?.a()?;
+            Some((ast_ptr.syntax_node_ptr(), scope))
+        })
         .map(|(ptr, scope)| (ptr.range(), scope))
         .filter(|(range, _)| range.start() <= offset && range.is_subrange(&r) && *range != r);
 
index b034fd59e9d68d22f24abd566cab12f2e9e0c406..d344ab12e78b714267fdf93e28c6b5005d407d9e 100644 (file)
@@ -3582,7 +3582,7 @@ fn infer(content: &str) -> String {
 
         for (expr, ty) in inference_result.type_of_expr.iter() {
             let syntax_ptr = match body_source_map.expr_syntax(expr) {
-                Some(sp) => sp,
+                Some(sp) => sp.either(|it| it.syntax_node_ptr(), |it| it.syntax_node_ptr()),
                 None => continue,
             };
             types.push((syntax_ptr, ty));
index a2e4b6f3cfdb2d8e28379f74dde468e0a302e7e3..a71e4ed7dcd926e8333dfc89e2525a08a9667846 100644 (file)
@@ -123,7 +123,7 @@ fn has_comma_after(node: &SyntaxNode) -> bool {
 fn join_single_expr_block(edit: &mut TextEditBuilder, token: &SyntaxToken) -> Option<()> {
     let block = ast::Block::cast(token.parent())?;
     let block_expr = ast::BlockExpr::cast(block.syntax().parent()?)?;
-    let expr = extract_trivial_expression(&block)?;
+    let expr = extract_trivial_expression(&block_expr)?;
 
     let block_range = block_expr.syntax().text_range();
     let mut buf = expr.syntax().text().to_string();
index d7ea4354df470d678005e206612e56bb0dab01a9..1324965cfbd6332e27af27fc26dc14d12d7560a4 100644 (file)
@@ -9,12 +9,12 @@
 
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum ElseBranch {
-    Block(ast::Block),
+    Block(ast::BlockExpr),
     IfExpr(ast::IfExpr),
 }
 
 impl ast::IfExpr {
-    pub fn then_branch(&self) -> Option<ast::Block> {
+    pub fn then_branch(&self) -> Option<ast::BlockExpr> {
         self.blocks().nth(0)
     }
     pub fn else_branch(&self) -> Option<ElseBranch> {
@@ -28,7 +28,7 @@ pub fn else_branch(&self) -> Option<ElseBranch> {
         Some(res)
     }
 
-    fn blocks(&self) -> AstChildren<ast::Block> {
+    fn blocks(&self) -> AstChildren<ast::BlockExpr> {
         children(self)
     }
 }
index fd85a323151ecf24d12fa00b390c75a1c9540f3f..e2a92ae604b1de24d83c795ca291141e4f0522e8 100644 (file)
@@ -3135,7 +3135,7 @@ fn syntax(&self) -> &SyntaxNode {
     }
 }
 impl TryBlockExpr {
-    pub fn block(&self) -> Option<Block> {
+    pub fn body(&self) -> Option<BlockExpr> {
         AstChildren::new(&self.syntax).next()
     }
 }
index 20c251fbad21ee80339a4bdf0828e82d92eae734..c3e676d4c25d7f8a73af697c732f7da64e5e7634 100644 (file)
@@ -28,7 +28,7 @@ fn visibility(&self) -> Option<ast::Visibility> {
 }
 
 pub trait LoopBodyOwner: AstNode {
-    fn loop_body(&self) -> Option<ast::Block> {
+    fn loop_body(&self) -> Option<ast::BlockExpr> {
         child_opt(self)
     }
 }
index 37166182f9205d4a40d284fd8379012f2c15e029..c14ee0e856c907a3a0db1a36fa8cb28b5b36eeca 100644 (file)
@@ -426,7 +426,7 @@ Grammar(
             traits: ["LoopBodyOwner"],
         ),
         "TryBlockExpr": (
-            options: ["Block"],
+            options: [["body", "BlockExpr"]],
         ),
         "ForExpr": (
             traits: ["LoopBodyOwner"],