]> git.lizzy.rs Git - rust.git/blob - src/librustc_mir/transform/deaggregator.rs
rustc: split off BodyOwnerKind from MirSource.
[rust.git] / src / librustc_mir / transform / deaggregator.rs
1 // Copyright 2016 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 use rustc::hir;
12 use rustc::ty::TyCtxt;
13 use rustc::mir::*;
14 use rustc_data_structures::indexed_vec::Idx;
15 use transform::{MirPass, MirSource};
16
17 pub struct Deaggregator;
18
19 impl MirPass for Deaggregator {
20     fn run_pass<'a, 'tcx>(&self,
21                           tcx: TyCtxt<'a, 'tcx, 'tcx>,
22                           source: MirSource,
23                           mir: &mut Mir<'tcx>) {
24         let node_path = tcx.item_path_str(source.def_id);
25         debug!("running on: {:?}", node_path);
26         // we only run when mir_opt_level > 2
27         if tcx.sess.opts.debugging_opts.mir_opt_level <= 2 {
28             return;
29         }
30
31         // Don't run on constant MIR, because trans might not be able to
32         // evaluate the modified MIR.
33         // FIXME(eddyb) Remove check after miri is merged.
34         let id = tcx.hir.as_local_node_id(source.def_id).unwrap();
35         match (tcx.hir.body_owner_kind(id), source.promoted) {
36             (hir::BodyOwnerKind::Fn, None) => {},
37             _ => return
38         }
39         // In fact, we might not want to trigger in other cases.
40         // Ex: when we could use SROA.  See issue #35259
41
42         for bb in mir.basic_blocks_mut() {
43             let mut curr: usize = 0;
44             while let Some(idx) = get_aggregate_statement_index(curr, &bb.statements) {
45                 // do the replacement
46                 debug!("removing statement {:?}", idx);
47                 let src_info = bb.statements[idx].source_info;
48                 let suffix_stmts = bb.statements.split_off(idx+1);
49                 let orig_stmt = bb.statements.pop().unwrap();
50                 let (lhs, rhs) = match orig_stmt.kind {
51                     StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs),
52                     _ => span_bug!(src_info.span, "expected assign, not {:?}", orig_stmt),
53                 };
54                 let (agg_kind, operands) = match rhs {
55                     &Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands),
56                     _ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs),
57                 };
58                 let (adt_def, variant, substs) = match **agg_kind {
59                     AggregateKind::Adt(adt_def, variant, substs, None)
60                         => (adt_def, variant, substs),
61                     _ => span_bug!(src_info.span, "expected struct, not {:?}", rhs),
62                 };
63                 let n = bb.statements.len();
64                 bb.statements.reserve(n + operands.len() + suffix_stmts.len());
65                 for (i, op) in operands.iter().enumerate() {
66                     let ref variant_def = adt_def.variants[variant];
67                     let ty = variant_def.fields[i].ty(tcx, substs);
68                     let rhs = Rvalue::Use(op.clone());
69
70                     let lhs_cast = if adt_def.variants.len() > 1 {
71                         Lvalue::Projection(Box::new(LvalueProjection {
72                             base: lhs.clone(),
73                             elem: ProjectionElem::Downcast(adt_def, variant),
74                         }))
75                     } else {
76                         lhs.clone()
77                     };
78
79                     let lhs_proj = Lvalue::Projection(Box::new(LvalueProjection {
80                         base: lhs_cast,
81                         elem: ProjectionElem::Field(Field::new(i), ty),
82                     }));
83                     let new_statement = Statement {
84                         source_info: src_info,
85                         kind: StatementKind::Assign(lhs_proj, rhs),
86                     };
87                     debug!("inserting: {:?} @ {:?}", new_statement, idx + i);
88                     bb.statements.push(new_statement);
89                 }
90
91                 // if the aggregate was an enum, we need to set the discriminant
92                 if adt_def.variants.len() > 1 {
93                     let set_discriminant = Statement {
94                         kind: StatementKind::SetDiscriminant {
95                             lvalue: lhs.clone(),
96                             variant_index: variant,
97                         },
98                         source_info: src_info,
99                     };
100                     bb.statements.push(set_discriminant);
101                 };
102
103                 curr = bb.statements.len();
104                 bb.statements.extend(suffix_stmts);
105             }
106         }
107     }
108 }
109
110 fn get_aggregate_statement_index<'a, 'tcx, 'b>(start: usize,
111                                          statements: &Vec<Statement<'tcx>>)
112                                          -> Option<usize> {
113     for i in start..statements.len() {
114         let ref statement = statements[i];
115         let rhs = match statement.kind {
116             StatementKind::Assign(_, ref rhs) => rhs,
117             _ => continue,
118         };
119         let (kind, operands) = match rhs {
120             &Rvalue::Aggregate(ref kind, ref operands) => (kind, operands),
121             _ => continue,
122         };
123         let (adt_def, variant) = match **kind {
124             AggregateKind::Adt(adt_def, variant, _, None) => (adt_def, variant),
125             _ => continue,
126         };
127         if operands.len() == 0 {
128             // don't deaggregate ()
129             continue;
130         }
131         debug!("getting variant {:?}", variant);
132         debug!("for adt_def {:?}", adt_def);
133         return Some(i);
134     };
135     None
136 }