]> git.lizzy.rs Git - rust.git/blob - src/librustc_data_structures/obligation_forest/mod.rs
Rollup merge of #51765 - jonas-schievink:patch-1, r=KodrAus
[rust.git] / src / librustc_data_structures / obligation_forest / mod.rs
1 // Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! The `ObligationForest` is a utility data structure used in trait
12 //! matching to track the set of outstanding obligations (those not
13 //! yet resolved to success or error). It also tracks the "backtrace"
14 //! of each pending obligation (why we are trying to figure this out
15 //! in the first place). See README.md for a general overview of how
16 //! to use this class.
17
18 use fx::{FxHashMap, FxHashSet};
19
20 use std::cell::Cell;
21 use std::collections::hash_map::Entry;
22 use std::fmt::Debug;
23 use std::hash;
24 use std::marker::PhantomData;
25
26 mod node_index;
27 use self::node_index::NodeIndex;
28
29 #[cfg(test)]
30 mod test;
31
32 pub trait ForestObligation : Clone + Debug {
33     type Predicate : Clone + hash::Hash + Eq + Debug;
34
35     fn as_predicate(&self) -> &Self::Predicate;
36 }
37
38 pub trait ObligationProcessor {
39     type Obligation : ForestObligation;
40     type Error : Debug;
41
42     fn process_obligation(&mut self,
43                           obligation: &mut Self::Obligation)
44                           -> ProcessResult<Self::Obligation, Self::Error>;
45
46     /// As we do the cycle check, we invoke this callback when we
47     /// encounter an actual cycle. `cycle` is an iterator that starts
48     /// at the start of the cycle in the stack and walks **toward the
49     /// top**.
50     ///
51     /// In other words, if we had O1 which required O2 which required
52     /// O3 which required O1, we would give an iterator yielding O1,
53     /// O2, O3 (O1 is not yielded twice).
54     fn process_backedge<'c, I>(&mut self,
55                                cycle: I,
56                                _marker: PhantomData<&'c Self::Obligation>)
57         where I: Clone + Iterator<Item=&'c Self::Obligation>;
58 }
59
60 /// The result type used by `process_obligation`.
61 #[derive(Debug)]
62 pub enum ProcessResult<O, E> {
63     Unchanged,
64     Changed(Vec<O>),
65     Error(E),
66 }
67
68 pub struct ObligationForest<O: ForestObligation> {
69     /// The list of obligations. In between calls to
70     /// `process_obligations`, this list only contains nodes in the
71     /// `Pending` or `Success` state (with a non-zero number of
72     /// incomplete children). During processing, some of those nodes
73     /// may be changed to the error state, or we may find that they
74     /// are completed (That is, `num_incomplete_children` drops to 0).
75     /// At the end of processing, those nodes will be removed by a
76     /// call to `compress`.
77     ///
78     /// At all times we maintain the invariant that every node appears
79     /// at a higher index than its parent. This is needed by the
80     /// backtrace iterator (which uses `split_at`).
81     nodes: Vec<Node<O>>,
82     /// A cache of predicates that have been successfully completed.
83     done_cache: FxHashSet<O::Predicate>,
84     /// An cache of the nodes in `nodes`, indexed by predicate.
85     waiting_cache: FxHashMap<O::Predicate, NodeIndex>,
86     scratch: Option<Vec<usize>>,
87 }
88
89 #[derive(Debug)]
90 struct Node<O> {
91     obligation: O,
92     state: Cell<NodeState>,
93
94     /// The parent of a node - the original obligation of
95     /// which it is a subobligation. Except for error reporting,
96     /// it is just like any member of `dependents`.
97     parent: Option<NodeIndex>,
98
99     /// Obligations that depend on this obligation for their
100     /// completion. They must all be in a non-pending state.
101     dependents: Vec<NodeIndex>,
102 }
103
104 /// The state of one node in some tree within the forest. This
105 /// represents the current state of processing for the obligation (of
106 /// type `O`) associated with this node.
107 ///
108 /// Outside of ObligationForest methods, nodes should be either Pending
109 /// or Waiting.
110 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
111 enum NodeState {
112     /// Obligations for which selection had not yet returned a
113     /// non-ambiguous result.
114     Pending,
115
116     /// This obligation was selected successfully, but may or
117     /// may not have subobligations.
118     Success,
119
120     /// This obligation was selected successfully, but it has
121     /// a pending subobligation.
122     Waiting,
123
124     /// This obligation, along with its subobligations, are complete,
125     /// and will be removed in the next collection.
126     Done,
127
128     /// This obligation was resolved to an error. Error nodes are
129     /// removed from the vector by the compression step.
130     Error,
131
132     /// This is a temporary state used in DFS loops to detect cycles,
133     /// it should not exist outside of these DFSes.
134     OnDfsStack,
135 }
136
137 #[derive(Debug)]
138 pub struct Outcome<O, E> {
139     /// Obligations that were completely evaluated, including all
140     /// (transitive) subobligations.
141     pub completed: Vec<O>,
142
143     /// Backtrace of obligations that were found to be in error.
144     pub errors: Vec<Error<O, E>>,
145
146     /// If true, then we saw no successful obligations, which means
147     /// there is no point in further iteration. This is based on the
148     /// assumption that when trait matching returns `Error` or
149     /// `Unchanged`, those results do not affect environmental
150     /// inference state. (Note that if we invoke `process_obligations`
151     /// with no pending obligations, stalled will be true.)
152     pub stalled: bool,
153 }
154
155 #[derive(Debug, PartialEq, Eq)]
156 pub struct Error<O, E> {
157     pub error: E,
158     pub backtrace: Vec<O>,
159 }
160
161 impl<O: ForestObligation> ObligationForest<O> {
162     pub fn new() -> ObligationForest<O> {
163         ObligationForest {
164             nodes: vec![],
165             done_cache: FxHashSet(),
166             waiting_cache: FxHashMap(),
167             scratch: Some(vec![]),
168         }
169     }
170
171     /// Return the total number of nodes in the forest that have not
172     /// yet been fully resolved.
173     pub fn len(&self) -> usize {
174         self.nodes.len()
175     }
176
177     /// Registers an obligation
178     ///
179     /// This CAN be done in a snapshot
180     pub fn register_obligation(&mut self, obligation: O) {
181         // Ignore errors here - there is no guarantee of success.
182         let _ = self.register_obligation_at(obligation, None);
183     }
184
185     // returns Err(()) if we already know this obligation failed.
186     fn register_obligation_at(&mut self, obligation: O, parent: Option<NodeIndex>)
187                               -> Result<(), ()>
188     {
189         if self.done_cache.contains(obligation.as_predicate()) {
190             return Ok(())
191         }
192
193         match self.waiting_cache.entry(obligation.as_predicate().clone()) {
194             Entry::Occupied(o) => {
195                 debug!("register_obligation_at({:?}, {:?}) - duplicate of {:?}!",
196                        obligation, parent, o.get());
197                 let node = &mut self.nodes[o.get().get()];
198                 if let Some(parent) = parent {
199                     // If the node is already in `waiting_cache`, it's already
200                     // been marked with a parent. (It's possible that parent
201                     // has been cleared by `apply_rewrites`, though.) So just
202                     // dump `parent` into `node.dependents`... unless it's
203                     // already in `node.dependents` or `node.parent`.
204                     if !node.dependents.contains(&parent) && Some(parent) != node.parent {
205                         node.dependents.push(parent);
206                     }
207                 }
208                 if let NodeState::Error = node.state.get() {
209                     Err(())
210                 } else {
211                     Ok(())
212                 }
213             }
214             Entry::Vacant(v) => {
215                 debug!("register_obligation_at({:?}, {:?}) - ok, new index is {}",
216                        obligation, parent, self.nodes.len());
217                 v.insert(NodeIndex::new(self.nodes.len()));
218                 self.nodes.push(Node::new(parent, obligation));
219                 Ok(())
220             }
221         }
222     }
223
224     /// Convert all remaining obligations to the given error.
225     ///
226     /// This cannot be done during a snapshot.
227     pub fn to_errors<E: Clone>(&mut self, error: E) -> Vec<Error<O, E>> {
228         let mut errors = vec![];
229         for index in 0..self.nodes.len() {
230             if let NodeState::Pending = self.nodes[index].state.get() {
231                 let backtrace = self.error_at(index);
232                 errors.push(Error {
233                     error: error.clone(),
234                     backtrace,
235                 });
236             }
237         }
238         let successful_obligations = self.compress();
239         assert!(successful_obligations.is_empty());
240         errors
241     }
242
243     /// Returns the set of obligations that are in a pending state.
244     pub fn map_pending_obligations<P, F>(&self, f: F) -> Vec<P>
245         where F: Fn(&O) -> P
246     {
247         self.nodes
248             .iter()
249             .filter(|n| n.state.get() == NodeState::Pending)
250             .map(|n| f(&n.obligation))
251             .collect()
252     }
253
254     /// Perform a pass through the obligation list. This must
255     /// be called in a loop until `outcome.stalled` is false.
256     ///
257     /// This CANNOT be unrolled (presently, at least).
258     pub fn process_obligations<P>(&mut self, processor: &mut P) -> Outcome<O, P::Error>
259         where P: ObligationProcessor<Obligation=O>
260     {
261         debug!("process_obligations(len={})", self.nodes.len());
262
263         let mut errors = vec![];
264         let mut stalled = true;
265
266         for index in 0..self.nodes.len() {
267             debug!("process_obligations: node {} == {:?}",
268                    index,
269                    self.nodes[index]);
270
271             let result = match self.nodes[index] {
272                 Node { state: ref _state, ref mut obligation, .. }
273                     if _state.get() == NodeState::Pending =>
274                 {
275                     processor.process_obligation(obligation)
276                 }
277                 _ => continue
278             };
279
280             debug!("process_obligations: node {} got result {:?}",
281                    index,
282                    result);
283
284             match result {
285                 ProcessResult::Unchanged => {
286                     // No change in state.
287                 }
288                 ProcessResult::Changed(children) => {
289                     // We are not (yet) stalled.
290                     stalled = false;
291                     self.nodes[index].state.set(NodeState::Success);
292
293                     for child in children {
294                         let st = self.register_obligation_at(
295                             child,
296                             Some(NodeIndex::new(index))
297                         );
298                         if let Err(()) = st {
299                             // error already reported - propagate it
300                             // to our node.
301                             self.error_at(index);
302                         }
303                     }
304                 }
305                 ProcessResult::Error(err) => {
306                     stalled = false;
307                     let backtrace = self.error_at(index);
308                     errors.push(Error {
309                         error: err,
310                         backtrace,
311                     });
312                 }
313             }
314         }
315
316         if stalled {
317             // There's no need to perform marking, cycle processing and compression when nothing
318             // changed.
319             return Outcome {
320                 completed: vec![],
321                 errors,
322                 stalled,
323             };
324         }
325
326         self.mark_as_waiting();
327         self.process_cycles(processor);
328
329         // Now we have to compress the result
330         let completed_obligations = self.compress();
331
332         debug!("process_obligations: complete");
333
334         Outcome {
335             completed: completed_obligations,
336             errors,
337             stalled,
338         }
339     }
340
341     /// Mark all NodeState::Success nodes as NodeState::Done and
342     /// report all cycles between them. This should be called
343     /// after `mark_as_waiting` marks all nodes with pending
344     /// subobligations as NodeState::Waiting.
345     fn process_cycles<P>(&mut self, processor: &mut P)
346         where P: ObligationProcessor<Obligation=O>
347     {
348         let mut stack = self.scratch.take().unwrap();
349         debug_assert!(stack.is_empty());
350
351         debug!("process_cycles()");
352
353         for index in 0..self.nodes.len() {
354             // For rustc-benchmarks/inflate-0.1.0 this state test is extremely
355             // hot and the state is almost always `Pending` or `Waiting`. It's
356             // a win to handle the no-op cases immediately to avoid the cost of
357             // the function call.
358             let state = self.nodes[index].state.get();
359             match state {
360                 NodeState::Waiting | NodeState::Pending | NodeState::Done | NodeState::Error => {},
361                 _ => self.find_cycles_from_node(&mut stack, processor, index),
362             }
363         }
364
365         debug!("process_cycles: complete");
366
367         debug_assert!(stack.is_empty());
368         self.scratch = Some(stack);
369     }
370
371     fn find_cycles_from_node<P>(&self, stack: &mut Vec<usize>,
372                                 processor: &mut P, index: usize)
373         where P: ObligationProcessor<Obligation=O>
374     {
375         let node = &self.nodes[index];
376         let state = node.state.get();
377         match state {
378             NodeState::OnDfsStack => {
379                 let index =
380                     stack.iter().rposition(|n| *n == index).unwrap();
381                 processor.process_backedge(stack[index..].iter().map(GetObligation(&self.nodes)),
382                                            PhantomData);
383             }
384             NodeState::Success => {
385                 node.state.set(NodeState::OnDfsStack);
386                 stack.push(index);
387                 for dependent in node.parent.iter().chain(node.dependents.iter()) {
388                     self.find_cycles_from_node(stack, processor, dependent.get());
389                 }
390                 stack.pop();
391                 node.state.set(NodeState::Done);
392             },
393             NodeState::Waiting | NodeState::Pending => {
394                 // this node is still reachable from some pending node. We
395                 // will get to it when they are all processed.
396             }
397             NodeState::Done | NodeState::Error => {
398                 // already processed that node
399             }
400         };
401     }
402
403     /// Returns a vector of obligations for `p` and all of its
404     /// ancestors, putting them into the error state in the process.
405     fn error_at(&mut self, p: usize) -> Vec<O> {
406         let mut error_stack = self.scratch.take().unwrap();
407         let mut trace = vec![];
408
409         let mut n = p;
410         loop {
411             self.nodes[n].state.set(NodeState::Error);
412             trace.push(self.nodes[n].obligation.clone());
413             error_stack.extend(self.nodes[n].dependents.iter().map(|x| x.get()));
414
415             // loop to the parent
416             match self.nodes[n].parent {
417                 Some(q) => n = q.get(),
418                 None => break
419             }
420         }
421
422         while let Some(i) = error_stack.pop() {
423             let node = &self.nodes[i];
424
425             match node.state.get() {
426                 NodeState::Error => continue,
427                 _ => node.state.set(NodeState::Error)
428             }
429
430             error_stack.extend(
431                 node.parent.iter().chain(node.dependents.iter()).map(|x| x.get())
432             );
433         }
434
435         self.scratch = Some(error_stack);
436         trace
437     }
438
439     #[inline]
440     fn mark_neighbors_as_waiting_from(&self, node: &Node<O>) {
441         for dependent in node.parent.iter().chain(node.dependents.iter()) {
442             self.mark_as_waiting_from(&self.nodes[dependent.get()]);
443         }
444     }
445
446     /// Marks all nodes that depend on a pending node as NodeState::Waiting.
447     fn mark_as_waiting(&self) {
448         for node in &self.nodes {
449             if node.state.get() == NodeState::Waiting {
450                 node.state.set(NodeState::Success);
451             }
452         }
453
454         for node in &self.nodes {
455             if node.state.get() == NodeState::Pending {
456                 self.mark_neighbors_as_waiting_from(node);
457             }
458         }
459     }
460
461     fn mark_as_waiting_from(&self, node: &Node<O>) {
462         match node.state.get() {
463             NodeState::Waiting | NodeState::Error | NodeState::OnDfsStack => return,
464             NodeState::Success => node.state.set(NodeState::Waiting),
465             NodeState::Pending | NodeState::Done => {},
466         }
467
468         self.mark_neighbors_as_waiting_from(node);
469     }
470
471     /// Compresses the vector, removing all popped nodes. This adjusts
472     /// the indices and hence invalidates any outstanding
473     /// indices. Cannot be used during a transaction.
474     ///
475     /// Beforehand, all nodes must be marked as `Done` and no cycles
476     /// on these nodes may be present. This is done by e.g. `process_cycles`.
477     #[inline(never)]
478     fn compress(&mut self) -> Vec<O> {
479         let nodes_len = self.nodes.len();
480         let mut node_rewrites: Vec<_> = self.scratch.take().unwrap();
481         node_rewrites.extend(0..nodes_len);
482         let mut dead_nodes = 0;
483
484         // Now move all popped nodes to the end. Try to keep the order.
485         //
486         // LOOP INVARIANT:
487         //     self.nodes[0..i - dead_nodes] are the first remaining nodes
488         //     self.nodes[i - dead_nodes..i] are all dead
489         //     self.nodes[i..] are unchanged
490         for i in 0..self.nodes.len() {
491             match self.nodes[i].state.get() {
492                 NodeState::Pending | NodeState::Waiting => {
493                     if dead_nodes > 0 {
494                         self.nodes.swap(i, i - dead_nodes);
495                         node_rewrites[i] -= dead_nodes;
496                     }
497                 }
498                 NodeState::Done => {
499                     self.waiting_cache.remove(self.nodes[i].obligation.as_predicate());
500                     // FIXME(HashMap): why can't I get my key back?
501                     self.done_cache.insert(self.nodes[i].obligation.as_predicate().clone());
502                     node_rewrites[i] = nodes_len;
503                     dead_nodes += 1;
504                 }
505                 NodeState::Error => {
506                     // We *intentionally* remove the node from the cache at this point. Otherwise
507                     // tests must come up with a different type on every type error they
508                     // check against.
509                     self.waiting_cache.remove(self.nodes[i].obligation.as_predicate());
510                     node_rewrites[i] = nodes_len;
511                     dead_nodes += 1;
512                 }
513                 NodeState::OnDfsStack | NodeState::Success => unreachable!()
514             }
515         }
516
517         // No compression needed.
518         if dead_nodes == 0 {
519             node_rewrites.truncate(0);
520             self.scratch = Some(node_rewrites);
521             return vec![];
522         }
523
524         // Pop off all the nodes we killed and extract the success
525         // stories.
526         let successful = (0..dead_nodes)
527                              .map(|_| self.nodes.pop().unwrap())
528                              .flat_map(|node| {
529                                  match node.state.get() {
530                                      NodeState::Error => None,
531                                      NodeState::Done => Some(node.obligation),
532                                      _ => unreachable!()
533                                  }
534                              })
535             .collect();
536         self.apply_rewrites(&node_rewrites);
537
538         node_rewrites.truncate(0);
539         self.scratch = Some(node_rewrites);
540
541         successful
542     }
543
544     fn apply_rewrites(&mut self, node_rewrites: &[usize]) {
545         let nodes_len = node_rewrites.len();
546
547         for node in &mut self.nodes {
548             if let Some(index) = node.parent {
549                 let new_index = node_rewrites[index.get()];
550                 if new_index >= nodes_len {
551                     // parent dead due to error
552                     node.parent = None;
553                 } else {
554                     node.parent = Some(NodeIndex::new(new_index));
555                 }
556             }
557
558             let mut i = 0;
559             while i < node.dependents.len() {
560                 let new_index = node_rewrites[node.dependents[i].get()];
561                 if new_index >= nodes_len {
562                     node.dependents.swap_remove(i);
563                 } else {
564                     node.dependents[i] = NodeIndex::new(new_index);
565                     i += 1;
566                 }
567             }
568         }
569
570         let mut kill_list = vec![];
571         for (predicate, index) in self.waiting_cache.iter_mut() {
572             let new_index = node_rewrites[index.get()];
573             if new_index >= nodes_len {
574                 kill_list.push(predicate.clone());
575             } else {
576                 *index = NodeIndex::new(new_index);
577             }
578         }
579
580         for predicate in kill_list { self.waiting_cache.remove(&predicate); }
581     }
582 }
583
584 impl<O> Node<O> {
585     fn new(parent: Option<NodeIndex>, obligation: O) -> Node<O> {
586         Node {
587             obligation,
588             state: Cell::new(NodeState::Pending),
589             parent,
590             dependents: vec![],
591         }
592     }
593 }
594
595 // I need a Clone closure
596 #[derive(Clone)]
597 struct GetObligation<'a, O: 'a>(&'a [Node<O>]);
598
599 impl<'a, 'b, O> FnOnce<(&'b usize,)> for GetObligation<'a, O> {
600     type Output = &'a O;
601     extern "rust-call" fn call_once(self, args: (&'b usize,)) -> &'a O {
602         &self.0[*args.0].obligation
603     }
604 }
605
606 impl<'a, 'b, O> FnMut<(&'b usize,)> for GetObligation<'a, O> {
607     extern "rust-call" fn call_mut(&mut self, args: (&'b usize,)) -> &'a O {
608         &self.0[*args.0].obligation
609     }
610 }