]> git.lizzy.rs Git - rust.git/commitdiff
Auto merge of #35348 - scottcarr:discriminant2, r=nikomatsakis
authorbors <bors@rust-lang.org>
Sat, 13 Aug 2016 08:20:46 +0000 (01:20 -0700)
committerGitHub <noreply@github.com>
Sat, 13 Aug 2016 08:20:46 +0000 (01:20 -0700)
[MIR] Add explicit SetDiscriminant StatementKind for deaggregating enums

cc #35186

To deaggregate enums, we need to be able to explicitly set the discriminant.  This PR implements a new StatementKind that does that.

I think some of the places that have `panics!` now could maybe do something smarter.

12 files changed:
src/librustc/mir/repr.rs
src/librustc/mir/visit.rs
src/librustc_borrowck/borrowck/mir/dataflow/impls.rs
src/librustc_borrowck/borrowck/mir/dataflow/sanity_check.rs
src/librustc_borrowck/borrowck/mir/gather_moves.rs
src/librustc_borrowck/borrowck/mir/mod.rs
src/librustc_mir/transform/deaggregator.rs
src/librustc_mir/transform/promote_consts.rs
src/librustc_mir/transform/type_check.rs
src/librustc_trans/mir/constant.rs
src/librustc_trans/mir/statement.rs
src/test/mir-opt/deaggregator_test_enum.rs [new file with mode: 0644]

index 93507246241de62bde905a68e9fde7acc47e7c47..08614ca253be51867b6c55205ce88aa04ba07b00 100644 (file)
@@ -689,13 +689,17 @@ pub struct Statement<'tcx> {
 #[derive(Clone, Debug, RustcEncodable, RustcDecodable)]
 pub enum StatementKind<'tcx> {
     Assign(Lvalue<'tcx>, Rvalue<'tcx>),
+    SetDiscriminant{ lvalue: Lvalue<'tcx>, variant_index: usize },
 }
 
 impl<'tcx> Debug for Statement<'tcx> {
     fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
         use self::StatementKind::*;
         match self.kind {
-            Assign(ref lv, ref rv) => write!(fmt, "{:?} = {:?}", lv, rv)
+            Assign(ref lv, ref rv) => write!(fmt, "{:?} = {:?}", lv, rv),
+            SetDiscriminant{lvalue: ref lv, variant_index: index} => {
+                write!(fmt, "discriminant({:?}) = {:?}", lv, index)
+            }
         }
     }
 }
index 3f714ff4d5152b3eef0a134fc49b92232e4b0d62..d44f00ed2cbe2abb4636a3b631b04fb27f8aa4ab 100644 (file)
@@ -323,6 +323,9 @@ fn super_statement(&mut self,
                                           ref $($mutability)* rvalue) => {
                         self.visit_assign(block, lvalue, rvalue);
                     }
+                    StatementKind::SetDiscriminant{ ref $($mutability)* lvalue, .. } => {
+                        self.visit_lvalue(lvalue, LvalueContext::Store);
+                    }
                 }
             }
 
index 932b748520170ac11c594920a74f047c7eca3345..57b335bd5eee41f5c1d51f8a5fd216366eda5b14 100644 (file)
@@ -442,6 +442,9 @@ fn statement_effect(&self,
         }
         let bits_per_block = self.bits_per_block(ctxt);
         match stmt.kind {
+            repr::StatementKind::SetDiscriminant { .. } => {
+                span_bug!(stmt.source_info.span, "SetDiscriminant should not exist in borrowck");
+            }
             repr::StatementKind::Assign(ref lvalue, _) => {
                 // assigning into this `lvalue` kills all
                 // MoveOuts from it, and *also* all MoveOuts
index d59bdf93f3225e4e33fb32689fd74b172e9aa238..ccde429a17113f8f5e798cb396385a31a141b0d2 100644 (file)
@@ -104,6 +104,9 @@ fn each_block<'a, 'tcx, O>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
             repr::StatementKind::Assign(ref lvalue, ref rvalue) => {
                 (lvalue, rvalue)
             }
+            repr::StatementKind::SetDiscriminant{ .. } =>
+                span_bug!(stmt.source_info.span,
+                          "sanity_check should run before Deaggregator inserts SetDiscriminant"),
         };
 
         if lvalue == peek_arg_lval {
index 05412216d487c9e7153b6b2ed9c05d1928cf9c53..e965dcc169c2de9abe6c55eb4ee05020b2c6d3f1 100644 (file)
@@ -616,6 +616,10 @@ fn gather_moves<'a, 'tcx>(mir: &Mir<'tcx>, tcx: TyCtxt<'a, 'tcx, 'tcx>) -> MoveD
                         Rvalue::InlineAsm { .. } => {}
                     }
                 }
+                StatementKind::SetDiscriminant{ .. } => {
+                    span_bug!(stmt.source_info.span,
+                              "SetDiscriminant should not exist during borrowck");
+                }
             }
         }
 
index 7c912e8bac6bbe9ede63801c133e23c07c19f91f..c563fdb8f44e6dd8a18db69f49bcb89914cbbbbe 100644 (file)
@@ -369,6 +369,9 @@ fn drop_flag_effects_for_location<'a, 'tcx, F>(
     let block = &mir[loc.block];
     match block.statements.get(loc.index) {
         Some(stmt) => match stmt.kind {
+            repr::StatementKind::SetDiscriminant{ .. } => {
+                span_bug!(stmt.source_info.span, "SetDiscrimant should not exist during borrowck");
+            }
             repr::StatementKind::Assign(ref lvalue, _) => {
                 debug!("drop_flag_effects: assignment {:?}", stmt);
                  on_all_children_bits(tcx, mir, move_data,
index fccd4a607fdcf1cd2e9a520b7ee0efe7b2d5cecb..cd6f0ed9cbac68a81313211464ccc802eb75c1c1 100644 (file)
@@ -39,7 +39,7 @@ fn run_pass<'a>(&mut self, tcx: TyCtxt<'a, 'tcx, 'tcx>,
 
         let mut curr: usize = 0;
         for bb in mir.basic_blocks_mut() {
-            let idx = match get_aggregate_statement(curr, &bb.statements) {
+            let idx = match get_aggregate_statement_index(curr, &bb.statements) {
                 Some(idx) => idx,
                 None => continue,
             };
@@ -48,7 +48,11 @@ fn run_pass<'a>(&mut self, tcx: TyCtxt<'a, 'tcx, 'tcx>,
             let src_info = bb.statements[idx].source_info;
             let suffix_stmts = bb.statements.split_off(idx+1);
             let orig_stmt = bb.statements.pop().unwrap();
-            let StatementKind::Assign(ref lhs, ref rhs) = orig_stmt.kind;
+            let (lhs, rhs) = match orig_stmt.kind {
+                StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs),
+                StatementKind::SetDiscriminant{ .. } =>
+                    span_bug!(src_info.span, "expected aggregate, not {:?}", orig_stmt.kind),
+            };
             let (agg_kind, operands) = match rhs {
                 &Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands),
                 _ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs),
@@ -64,10 +68,14 @@ fn run_pass<'a>(&mut self, tcx: TyCtxt<'a, 'tcx, 'tcx>,
                 let ty = variant_def.fields[i].ty(tcx, substs);
                 let rhs = Rvalue::Use(op.clone());
 
-                // since we don't handle enums, we don't need a cast
-                let lhs_cast = lhs.clone();
-
-                // FIXME we cannot deaggregate enums issue: #35186
+                let lhs_cast = if adt_def.variants.len() > 1 {
+                    Lvalue::Projection(Box::new(LvalueProjection {
+                        base: lhs.clone(),
+                        elem: ProjectionElem::Downcast(adt_def, variant),
+                    }))
+                } else {
+                    lhs.clone()
+                };
 
                 let lhs_proj = Lvalue::Projection(Box::new(LvalueProjection {
                     base: lhs_cast,
@@ -80,18 +88,34 @@ fn run_pass<'a>(&mut self, tcx: TyCtxt<'a, 'tcx, 'tcx>,
                 debug!("inserting: {:?} @ {:?}", new_statement, idx + i);
                 bb.statements.push(new_statement);
             }
+
+            // if the aggregate was an enum, we need to set the discriminant
+            if adt_def.variants.len() > 1 {
+                let set_discriminant = Statement {
+                    kind: StatementKind::SetDiscriminant {
+                        lvalue: lhs.clone(),
+                        variant_index: variant,
+                    },
+                    source_info: src_info,
+                };
+                bb.statements.push(set_discriminant);
+            };
+
             curr = bb.statements.len();
             bb.statements.extend(suffix_stmts);
         }
     }
 }
 
-fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
+fn get_aggregate_statement_index<'a, 'tcx, 'b>(start: usize,
                                          statements: &Vec<Statement<'tcx>>)
                                          -> Option<usize> {
-    for i in curr..statements.len() {
+    for i in start..statements.len() {
         let ref statement = statements[i];
-        let StatementKind::Assign(_, ref rhs) = statement.kind;
+        let rhs = match statement.kind {
+            StatementKind::Assign(_, ref rhs) => rhs,
+            StatementKind::SetDiscriminant{ .. } => continue,
+        };
         let (kind, operands) = match rhs {
             &Rvalue::Aggregate(ref kind, ref operands) => (kind, operands),
             _ => continue,
@@ -100,9 +124,8 @@ fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
             &AggregateKind::Adt(adt_def, variant, _) => (adt_def, variant),
             _ => continue,
         };
-        if operands.len() == 0 || adt_def.variants.len() > 1 {
+        if operands.len() == 0 {
             // don't deaggregate ()
-            // don't deaggregate enums ... for now
             continue;
         }
         debug!("getting variant {:?}", variant);
index fa3490cbcf3384f920426c867fc03584dbae7d31..eb0d8697f15d4c88e028a3d5b9cd3a68c71bc9e2 100644 (file)
@@ -219,7 +219,13 @@ fn promote_temp(&mut self, temp: Temp) -> Temp {
         let (mut rvalue, mut call) = (None, None);
         let source_info = if stmt_idx < no_stmts {
             let statement = &mut self.source[bb].statements[stmt_idx];
-            let StatementKind::Assign(_, ref mut rhs) = statement.kind;
+            let mut rhs = match statement.kind {
+                StatementKind::Assign(_, ref mut rhs) => rhs,
+                StatementKind::SetDiscriminant{ .. } =>
+                    span_bug!(statement.source_info.span,
+                              "cannot promote SetDiscriminant {:?}",
+                              statement),
+            };
             if self.keep_original {
                 rvalue = Some(rhs.clone());
             } else {
@@ -300,10 +306,16 @@ fn promote_candidate(mut self, candidate: Candidate) {
         });
         let mut rvalue = match candidate {
             Candidate::Ref(Location { block: bb, statement_index: stmt_idx }) => {
-                match self.source[bb].statements[stmt_idx].kind {
+                let ref mut statement = self.source[bb].statements[stmt_idx];
+                match statement.kind {
                     StatementKind::Assign(_, ref mut rvalue) => {
                         mem::replace(rvalue, Rvalue::Use(new_operand))
                     }
+                    StatementKind::SetDiscriminant{ .. } => {
+                        span_bug!(statement.source_info.span,
+                                  "cannot promote SetDiscriminant {:?}",
+                                  statement);
+                    }
                 }
             }
             Candidate::ShuffleIndices(bb) => {
@@ -340,7 +352,11 @@ pub fn promote_candidates<'a, 'tcx>(mir: &mut Mir<'tcx>,
         let (span, ty) = match candidate {
             Candidate::Ref(Location { block: bb, statement_index: stmt_idx }) => {
                 let statement = &mir[bb].statements[stmt_idx];
-                let StatementKind::Assign(ref dest, _) = statement.kind;
+                let dest = match statement.kind {
+                    StatementKind::Assign(ref dest, _) => dest,
+                    StatementKind::SetDiscriminant{ .. } =>
+                        panic!("cannot promote SetDiscriminant"),
+                };
                 if let Lvalue::Temp(index) = *dest {
                     if temps[index] == TempState::PromotedOut {
                         // Already promoted.
index 52f41741b08d69c6975ca1049ddafb3e91ecae6a..934357c9e1da2f4deb3d6019e30ae50bcdb3b972 100644 (file)
@@ -14,7 +14,7 @@
 use rustc::infer::{self, InferCtxt, InferOk};
 use rustc::traits::{self, Reveal};
 use rustc::ty::fold::TypeFoldable;
-use rustc::ty::{self, Ty, TyCtxt};
+use rustc::ty::{self, Ty, TyCtxt, TypeVariants};
 use rustc::mir::repr::*;
 use rustc::mir::tcx::LvalueTy;
 use rustc::mir::transform::{MirPass, MirSource, Pass};
@@ -360,10 +360,27 @@ fn check_stmt(&mut self, mir: &Mir<'tcx>, stmt: &Statement<'tcx>) {
                         span_mirbug!(self, stmt, "bad assignment ({:?} = {:?}): {:?}",
                                      lv_ty, rv_ty, terr);
                     }
-                }
-
                 // FIXME: rvalue with undeterminable type - e.g. inline
                 // asm.
+                }
+            }
+            StatementKind::SetDiscriminant{ ref lvalue, variant_index } => {
+                let lvalue_type = lvalue.ty(mir, tcx).to_ty(tcx);
+                let adt = match lvalue_type.sty {
+                    TypeVariants::TyEnum(adt, _) => adt,
+                    _ => {
+                        span_bug!(stmt.source_info.span,
+                                  "bad set discriminant ({:?} = {:?}): lhs is not an enum",
+                                  lvalue,
+                                  variant_index);
+                    }
+                };
+                if variant_index >= adt.variants.len() {
+                     span_bug!(stmt.source_info.span,
+                               "bad set discriminant ({:?} = {:?}): value of of range",
+                               lvalue,
+                               variant_index);
+                };
             }
         }
     }
index 35ded7042969f33023d8e3d97b5def3eb79c0ea6..7ca94b6356e40185f5ce73b31ec5693b9d6c45b1 100644 (file)
@@ -285,6 +285,9 @@ fn trans(&mut self) -> Result<Const<'tcx>, ConstEvalFailure> {
                             Err(err) => if failure.is_ok() { failure = Err(err); }
                         }
                     }
+                    mir::StatementKind::SetDiscriminant{ .. } => {
+                        span_bug!(span, "SetDiscriminant should not appear in constants?");
+                    }
                 }
             }
 
index 44d264c7e98f27cacd710ef00caed9622cba1c3b..7e3074f4cedf0740beed7f444454e9561a76a2e0 100644 (file)
@@ -14,6 +14,8 @@
 
 use super::MirContext;
 use super::LocalRef;
+use super::super::adt;
+use super::super::disr::Disr;
 
 impl<'bcx, 'tcx> MirContext<'bcx, 'tcx> {
     pub fn trans_statement(&mut self,
@@ -57,6 +59,18 @@ pub fn trans_statement(&mut self,
                     self.trans_rvalue(bcx, tr_dest, rvalue, debug_loc)
                 }
             }
+            mir::StatementKind::SetDiscriminant{ref lvalue, variant_index} => {
+                let ty = self.monomorphized_lvalue_ty(lvalue);
+                let repr = adt::represent_type(bcx.ccx(), ty);
+                let lvalue_transed = self.trans_lvalue(&bcx, lvalue);
+                bcx.with_block(|bcx|
+                    adt::trans_set_discr(bcx,
+                                         &repr,
+                                        lvalue_transed.llval,
+                                        Disr::from(variant_index))
+                );
+                bcx
+            }
         }
     }
 }
diff --git a/src/test/mir-opt/deaggregator_test_enum.rs b/src/test/mir-opt/deaggregator_test_enum.rs
new file mode 100644 (file)
index 0000000..ccfa760
--- /dev/null
@@ -0,0 +1,45 @@
+// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+enum Baz {
+    Empty,
+    Foo { x: usize },
+}
+
+fn bar(a: usize) -> Baz {
+    Baz::Foo { x: a }
+}
+
+fn main() {
+    let x = bar(10);
+    match x {
+        Baz::Empty => println!("empty"),
+        Baz::Foo { x } => println!("{}", x),
+    };
+}
+
+// END RUST SOURCE
+// START rustc.node10.Deaggregator.before.mir
+// bb0: {
+//     var0 = arg0;                     // scope 0 at main.rs:7:8: 7:9
+//     tmp0 = var0;                     // scope 1 at main.rs:8:19: 8:20
+//     return = Baz::Foo { x: tmp0 };   // scope 1 at main.rs:8:5: 8:21
+//     goto -> bb1;                     // scope 1 at main.rs:7:1: 9:2
+// }
+// END rustc.node10.Deaggregator.before.mir
+// START rustc.node10.Deaggregator.after.mir
+// bb0: {
+//     var0 = arg0;                     // scope 0 at main.rs:7:8: 7:9
+//     tmp0 = var0;                     // scope 1 at main.rs:8:19: 8:20
+//     ((return as Foo).0: usize) = tmp0; // scope 1 at main.rs:8:5: 8:21
+//     discriminant(return) = 1;         // scope 1 at main.rs:8:5: 8:21
+//     goto -> bb1;                     // scope 1 at main.rs:7:1: 9:2
+// }
+// END rustc.node10.Deaggregator.after.mir
\ No newline at end of file