]> git.lizzy.rs Git - rust.git/blob - src/libstd/sync/task_pool.rs
rollup merge of #20157: alexcrichton/issue-20068
[rust.git] / src / libstd / sync / task_pool.rs
1 // Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! Abstraction of a thread pool for basic parallelism.
12
13 use core::prelude::*;
14
15 use thread::Thread;
16 use comm::{channel, Sender, Receiver};
17 use sync::{Arc, Mutex};
18 use thunk::Thunk;
19
20 struct Sentinel<'a> {
21     jobs: &'a Arc<Mutex<Receiver<Thunk>>>,
22     active: bool
23 }
24
25 impl<'a> Sentinel<'a> {
26     fn new(jobs: &Arc<Mutex<Receiver<Thunk>>>) -> Sentinel {
27         Sentinel {
28             jobs: jobs,
29             active: true
30         }
31     }
32
33     // Cancel and destroy this sentinel.
34     fn cancel(mut self) {
35         self.active = false;
36     }
37 }
38
39 #[unsafe_destructor]
40 impl<'a> Drop for Sentinel<'a> {
41     fn drop(&mut self) {
42         if self.active {
43             spawn_in_pool(self.jobs.clone())
44         }
45     }
46 }
47
48 /// A thread pool used to execute functions in parallel.
49 ///
50 /// Spawns `n` worker threads and replenishes the pool if any worker threads
51 /// panic.
52 ///
53 /// # Example
54 ///
55 /// ```rust
56 /// use std::sync::TaskPool;
57 /// use std::iter::AdditiveIterator;
58 /// use std::comm::channel;
59 ///
60 /// let pool = TaskPool::new(4u);
61 ///
62 /// let (tx, rx) = channel();
63 /// for _ in range(0, 8u) {
64 ///     let tx = tx.clone();
65 ///     pool.execute(move|| {
66 ///         tx.send(1u);
67 ///     });
68 /// }
69 ///
70 /// assert_eq!(rx.iter().take(8u).sum(), 8u);
71 /// ```
72 pub struct TaskPool {
73     // How the threadpool communicates with subthreads.
74     //
75     // This is the only such Sender, so when it is dropped all subthreads will
76     // quit.
77     jobs: Sender<Thunk>
78 }
79
80 impl TaskPool {
81     /// Spawns a new thread pool with `threads` threads.
82     ///
83     /// # Panics
84     ///
85     /// This function will panic if `threads` is 0.
86     pub fn new(threads: uint) -> TaskPool {
87         assert!(threads >= 1);
88
89         let (tx, rx) = channel::<Thunk>();
90         let rx = Arc::new(Mutex::new(rx));
91
92         // Threadpool threads
93         for _ in range(0, threads) {
94             spawn_in_pool(rx.clone());
95         }
96
97         TaskPool { jobs: tx }
98     }
99
100     /// Executes the function `job` on a thread in the pool.
101     pub fn execute<F>(&self, job: F)
102         where F : FnOnce(), F : Send
103     {
104         self.jobs.send(Thunk::new(job));
105     }
106 }
107
108 fn spawn_in_pool(jobs: Arc<Mutex<Receiver<Thunk>>>) {
109     Thread::spawn(move |:| {
110         // Will spawn a new thread on panic unless it is cancelled.
111         let sentinel = Sentinel::new(&jobs);
112
113         loop {
114             let message = {
115                 // Only lock jobs for the time it takes
116                 // to get a job, not run it.
117                 let lock = jobs.lock().unwrap();
118                 lock.recv_opt()
119             };
120
121             match message {
122                 Ok(job) => job.invoke(()),
123
124                 // The Taskpool was dropped.
125                 Err(..) => break
126             }
127         }
128
129         sentinel.cancel();
130     }).detach();
131 }
132
133 #[cfg(test)]
134 mod test {
135     use prelude::v1::*;
136     use super::*;
137     use comm::channel;
138
139     const TEST_TASKS: uint = 4u;
140
141     #[test]
142     fn test_works() {
143         use iter::AdditiveIterator;
144
145         let pool = TaskPool::new(TEST_TASKS);
146
147         let (tx, rx) = channel();
148         for _ in range(0, TEST_TASKS) {
149             let tx = tx.clone();
150             pool.execute(move|| {
151                 tx.send(1u);
152             });
153         }
154
155         assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
156     }
157
158     #[test]
159     #[should_fail]
160     fn test_zero_tasks_panic() {
161         TaskPool::new(0);
162     }
163
164     #[test]
165     fn test_recovery_from_subtask_panic() {
166         use iter::AdditiveIterator;
167
168         let pool = TaskPool::new(TEST_TASKS);
169
170         // Panic all the existing threads.
171         for _ in range(0, TEST_TASKS) {
172             pool.execute(move|| -> () { panic!() });
173         }
174
175         // Ensure new threads were spawned to compensate.
176         let (tx, rx) = channel();
177         for _ in range(0, TEST_TASKS) {
178             let tx = tx.clone();
179             pool.execute(move|| {
180                 tx.send(1u);
181             });
182         }
183
184         assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
185     }
186
187     #[test]
188     fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
189         use sync::{Arc, Barrier};
190
191         let pool = TaskPool::new(TEST_TASKS);
192         let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
193
194         // Panic all the existing threads in a bit.
195         for _ in range(0, TEST_TASKS) {
196             let waiter = waiter.clone();
197             pool.execute(move|| {
198                 waiter.wait();
199                 panic!();
200             });
201         }
202
203         drop(pool);
204
205         // Kick off the failure.
206         waiter.wait();
207     }
208 }