]> git.lizzy.rs Git - rust.git/blob - src/libstd/sync/task_pool.rs
Fix misspelled comments.
[rust.git] / src / libstd / sync / task_pool.rs
1 // Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! Abstraction of a thread pool for basic parallelism.
12
13 #![unstable = "the semantics of a failing task and whether a thread is \
14                re-attached to a thread pool are somewhat unclear, and the \
15                utility of this type in `std::sync` is questionable with \
16                respect to the jobs of other primitives"]
17
18 use core::prelude::*;
19
20 use sync::{Arc, Mutex};
21 use sync::mpsc::{channel, Sender, Receiver};
22 use thread::Thread;
23 use thunk::Thunk;
24
25 struct Sentinel<'a> {
26     jobs: &'a Arc<Mutex<Receiver<Thunk>>>,
27     active: bool
28 }
29
30 impl<'a> Sentinel<'a> {
31     fn new(jobs: &Arc<Mutex<Receiver<Thunk>>>) -> Sentinel {
32         Sentinel {
33             jobs: jobs,
34             active: true
35         }
36     }
37
38     // Cancel and destroy this sentinel.
39     fn cancel(mut self) {
40         self.active = false;
41     }
42 }
43
44 #[unsafe_destructor]
45 impl<'a> Drop for Sentinel<'a> {
46     fn drop(&mut self) {
47         if self.active {
48             spawn_in_pool(self.jobs.clone())
49         }
50     }
51 }
52
53 /// A thread pool used to execute functions in parallel.
54 ///
55 /// Spawns `n` worker threads and replenishes the pool if any worker threads
56 /// panic.
57 ///
58 /// # Example
59 ///
60 /// ```rust
61 /// use std::sync::TaskPool;
62 /// use std::iter::AdditiveIterator;
63 /// use std::sync::mpsc::channel;
64 ///
65 /// let pool = TaskPool::new(4u);
66 ///
67 /// let (tx, rx) = channel();
68 /// for _ in range(0, 8u) {
69 ///     let tx = tx.clone();
70 ///     pool.execute(move|| {
71 ///         tx.send(1u).unwrap();
72 ///     });
73 /// }
74 ///
75 /// assert_eq!(rx.iter().take(8u).sum(), 8u);
76 /// ```
77 pub struct TaskPool {
78     // How the threadpool communicates with subthreads.
79     //
80     // This is the only such Sender, so when it is dropped all subthreads will
81     // quit.
82     jobs: Sender<Thunk>
83 }
84
85 impl TaskPool {
86     /// Spawns a new thread pool with `threads` threads.
87     ///
88     /// # Panics
89     ///
90     /// This function will panic if `threads` is 0.
91     pub fn new(threads: uint) -> TaskPool {
92         assert!(threads >= 1);
93
94         let (tx, rx) = channel::<Thunk>();
95         let rx = Arc::new(Mutex::new(rx));
96
97         // Threadpool threads
98         for _ in range(0, threads) {
99             spawn_in_pool(rx.clone());
100         }
101
102         TaskPool { jobs: tx }
103     }
104
105     /// Executes the function `job` on a thread in the pool.
106     pub fn execute<F>(&self, job: F)
107         where F : FnOnce(), F : Send
108     {
109         self.jobs.send(Thunk::new(job)).unwrap();
110     }
111 }
112
113 fn spawn_in_pool(jobs: Arc<Mutex<Receiver<Thunk>>>) {
114     Thread::spawn(move |:| {
115         // Will spawn a new thread on panic unless it is cancelled.
116         let sentinel = Sentinel::new(&jobs);
117
118         loop {
119             let message = {
120                 // Only lock jobs for the time it takes
121                 // to get a job, not run it.
122                 let lock = jobs.lock().unwrap();
123                 lock.recv()
124             };
125
126             match message {
127                 Ok(job) => job.invoke(()),
128
129                 // The Taskpool was dropped.
130                 Err(..) => break
131             }
132         }
133
134         sentinel.cancel();
135     }).detach();
136 }
137
138 #[cfg(test)]
139 mod test {
140     use prelude::v1::*;
141     use super::*;
142     use sync::mpsc::channel;
143
144     const TEST_TASKS: uint = 4u;
145
146     #[test]
147     fn test_works() {
148         use iter::AdditiveIterator;
149
150         let pool = TaskPool::new(TEST_TASKS);
151
152         let (tx, rx) = channel();
153         for _ in range(0, TEST_TASKS) {
154             let tx = tx.clone();
155             pool.execute(move|| {
156                 tx.send(1u).unwrap();
157             });
158         }
159
160         assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
161     }
162
163     #[test]
164     #[should_fail]
165     fn test_zero_tasks_panic() {
166         TaskPool::new(0);
167     }
168
169     #[test]
170     fn test_recovery_from_subtask_panic() {
171         use iter::AdditiveIterator;
172
173         let pool = TaskPool::new(TEST_TASKS);
174
175         // Panic all the existing threads.
176         for _ in range(0, TEST_TASKS) {
177             pool.execute(move|| -> () { panic!() });
178         }
179
180         // Ensure new threads were spawned to compensate.
181         let (tx, rx) = channel();
182         for _ in range(0, TEST_TASKS) {
183             let tx = tx.clone();
184             pool.execute(move|| {
185                 tx.send(1u).unwrap();
186             });
187         }
188
189         assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
190     }
191
192     #[test]
193     fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
194         use sync::{Arc, Barrier};
195
196         let pool = TaskPool::new(TEST_TASKS);
197         let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
198
199         // Panic all the existing threads in a bit.
200         for _ in range(0, TEST_TASKS) {
201             let waiter = waiter.clone();
202             pool.execute(move|| {
203                 waiter.wait();
204                 panic!();
205             });
206         }
207
208         drop(pool);
209
210         // Kick off the failure.
211         waiter.wait();
212     }
213 }