]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_hir_analysis/src/variance/solve.rs
rustc_typeck to rustc_hir_analysis
[rust.git] / compiler / rustc_hir_analysis / src / variance / solve.rs
//! Constraint solving
//!
//! The final phase iterates over the constraints, refining the variance
//! for each inferred until a fixed point is reached. This will be the
//! optimal solution to the constraints. The final variances are then
//! collected into the `CrateVariancesMap` returned by `solve_constraints`.
7
8 use rustc_data_structures::fx::FxHashMap;
9 use rustc_hir::def_id::DefId;
10 use rustc_middle::ty;
11
12 use super::constraints::*;
13 use super::terms::VarianceTerm::*;
14 use super::terms::*;
15 use super::xform::*;
16
17 struct SolveContext<'a, 'tcx> {
18     terms_cx: TermsContext<'a, 'tcx>,
19     constraints: Vec<Constraint<'a>>,
20
21     // Maps from an InferredIndex to the inferred value for that variable.
22     solutions: Vec<ty::Variance>,
23 }
24
25 pub fn solve_constraints<'tcx>(
26     constraints_cx: ConstraintContext<'_, 'tcx>,
27 ) -> ty::CrateVariancesMap<'tcx> {
28     let ConstraintContext { terms_cx, constraints, .. } = constraints_cx;
29
30     let mut solutions = vec![ty::Bivariant; terms_cx.inferred_terms.len()];
31     for &(id, ref variances) in &terms_cx.lang_items {
32         let InferredIndex(start) = terms_cx.inferred_starts[&id];
33         for (i, &variance) in variances.iter().enumerate() {
34             solutions[start + i] = variance;
35         }
36     }
37
38     let mut solutions_cx = SolveContext { terms_cx, constraints, solutions };
39     solutions_cx.solve();
40     let variances = solutions_cx.create_map();
41
42     ty::CrateVariancesMap { variances }
43 }
44
45 impl<'a, 'tcx> SolveContext<'a, 'tcx> {
46     fn solve(&mut self) {
47         // Propagate constraints until a fixed point is reached.  Note
48         // that the maximum number of iterations is 2C where C is the
49         // number of constraints (each variable can change values at most
50         // twice). Since number of constraints is linear in size of the
51         // input, so is the inference process.
52         let mut changed = true;
53         while changed {
54             changed = false;
55
56             for constraint in &self.constraints {
57                 let Constraint { inferred, variance: term } = *constraint;
58                 let InferredIndex(inferred) = inferred;
59                 let variance = self.evaluate(term);
60                 let old_value = self.solutions[inferred];
61                 let new_value = glb(variance, old_value);
62                 if old_value != new_value {
63                     debug!(
64                         "updating inferred {} \
65                             from {:?} to {:?} due to {:?}",
66                         inferred, old_value, new_value, term
67                     );
68
69                     self.solutions[inferred] = new_value;
70                     changed = true;
71                 }
72             }
73         }
74     }
75
76     fn enforce_const_invariance(&self, generics: &ty::Generics, variances: &mut [ty::Variance]) {
77         let tcx = self.terms_cx.tcx;
78
79         // Make all const parameters invariant.
80         for param in generics.params.iter() {
81             if let ty::GenericParamDefKind::Const { .. } = param.kind {
82                 variances[param.index as usize] = ty::Invariant;
83             }
84         }
85
86         // Make all the const parameters in the parent invariant (recursively).
87         if let Some(def_id) = generics.parent {
88             self.enforce_const_invariance(tcx.generics_of(def_id), variances);
89         }
90     }
91
92     fn create_map(&self) -> FxHashMap<DefId, &'tcx [ty::Variance]> {
93         let tcx = self.terms_cx.tcx;
94
95         let solutions = &self.solutions;
96         self.terms_cx
97             .inferred_starts
98             .iter()
99             .map(|(&def_id, &InferredIndex(start))| {
100                 let generics = tcx.generics_of(def_id);
101                 let count = generics.count();
102
103                 let variances = tcx.arena.alloc_slice(&solutions[start..(start + count)]);
104
105                 // Const parameters are always invariant.
106                 self.enforce_const_invariance(generics, variances);
107
108                 // Functions are permitted to have unused generic parameters: make those invariant.
109                 if let ty::FnDef(..) = tcx.type_of(def_id).kind() {
110                     for variance in variances.iter_mut() {
111                         if *variance == ty::Bivariant {
112                             *variance = ty::Invariant;
113                         }
114                     }
115                 }
116
117                 (def_id.to_def_id(), &*variances)
118             })
119             .collect()
120     }
121
122     fn evaluate(&self, term: VarianceTermPtr<'a>) -> ty::Variance {
123         match *term {
124             ConstantTerm(v) => v,
125
126             TransformTerm(t1, t2) => {
127                 let v1 = self.evaluate(t1);
128                 let v2 = self.evaluate(t2);
129                 v1.xform(v2)
130             }
131
132             InferredTerm(InferredIndex(index)) => self.solutions[index],
133         }
134     }
135 }