-// Copyright 2012 The Rust Project Developers. See the COPYRIGHT
+// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
use core::prelude::*;
-use task;
use task::spawn;
-use vec::Vec;
-use comm::{channel, Sender};
+use comm::{channel, Sender, Receiver};
+use sync::{Arc, Mutex};
-enum Msg<T> {
- Execute(proc(&T):Send),
- Quit
+struct Sentinel<'a> {
+ jobs: &'a Arc<Mutex<Receiver<proc(): Send>>>,
+ active: bool
}
-/// A task pool used to execute functions in parallel.
-pub struct TaskPool<T> {
- channels: Vec<Sender<Msg<T>>>,
- next_index: uint,
+impl<'a> Sentinel<'a> {
+ fn new(jobs: &Arc<Mutex<Receiver<proc(): Send>>>) -> Sentinel {
+ Sentinel {
+ jobs: jobs,
+ active: true
+ }
+ }
+
+ // Cancel and destroy this sentinel.
+ fn cancel(mut self) {
+ self.active = false;
+ }
}
#[unsafe_destructor]
-impl<T> Drop for TaskPool<T> {
+impl<'a> Drop for Sentinel<'a> {
fn drop(&mut self) {
- for channel in self.channels.iter_mut() {
- channel.send(Quit);
+ if self.active {
+ spawn_in_pool(self.jobs.clone())
}
}
}
-impl<T> TaskPool<T> {
- /// Spawns a new task pool with `n_tasks` tasks. The provided
- /// `init_fn_factory` returns a function which, given the index of the
- /// task, should return local data to be kept around in that task.
+/// A task pool used to execute functions in parallel.
+///
+/// Spawns `n` worker tasks and replenishes the pool if any worker tasks
+/// panic.
+///
+/// # Example
+///
+/// ```rust
+/// # use sync::TaskPool;
+/// # use iter::AdditiveIterator;
+///
+/// let pool = TaskPool::new(4u);
+///
+/// let (tx, rx) = channel();
+/// for _ in range(0, 8u) {
+/// let tx = tx.clone();
+/// pool.execute(proc() {
+/// tx.send(1u);
+/// });
+/// }
+///
+/// assert_eq!(rx.iter().take(8u).sum(), 8u);
+/// ```
+pub struct TaskPool {
+ // How the taskpool communicates with subtasks.
+ //
+ // This is the only such Sender, so when it is dropped all subtasks will
+ // quit.
+ jobs: Sender<proc(): Send>
+}
+
+impl TaskPool {
+ /// Spawns a new task pool with `tasks` tasks.
///
/// # Panics
///
- /// This function will panic if `n_tasks` is less than 1.
- pub fn new(n_tasks: uint,
- init_fn_factory: || -> proc(uint):Send -> T)
- -> TaskPool<T> {
- assert!(n_tasks >= 1);
-
- let channels = Vec::from_fn(n_tasks, |i| {
- let (tx, rx) = channel::<Msg<T>>();
- let init_fn = init_fn_factory();
-
- let task_body = proc() {
- let local_data = init_fn(i);
- loop {
- match rx.recv() {
- Execute(f) => f(&local_data),
- Quit => break
- }
- }
- };
+ /// This function will panic if `tasks` is 0.
+ pub fn new(tasks: uint) -> TaskPool {
+ assert!(tasks >= 1);
- // Run on this scheduler.
- task::spawn(task_body);
+ let (tx, rx) = channel::<proc(): Send>();
+ let rx = Arc::new(Mutex::new(rx));
- tx
- });
+ // Taskpool tasks.
+ for _ in range(0, tasks) {
+ spawn_in_pool(rx.clone());
+ }
- return TaskPool {
- channels: channels,
- next_index: 0,
- };
+ TaskPool { jobs: tx }
}
- /// Executes the function `f` on a task in the pool. The function
- /// receives a reference to the local data returned by the `init_fn`.
- pub fn execute(&mut self, f: proc(&T):Send) {
- self.channels[self.next_index].send(Execute(f));
- self.next_index += 1;
- if self.next_index == self.channels.len() { self.next_index = 0; }
+ /// Executes the function `job` on a task in the pool.
+ pub fn execute(&self, job: proc():Send) {
+ self.jobs.send(job);
}
}
-#[test]
-fn test_task_pool() {
- let f: || -> proc(uint):Send -> uint = || { proc(i) i };
- let mut pool = TaskPool::new(4, f);
- for _ in range(0u, 8) {
- pool.execute(proc(i) println!("Hello from thread {}!", *i));
- }
+fn spawn_in_pool(jobs: Arc<Mutex<Receiver<proc(): Send>>>) {
+ spawn(proc() {
+ // Will spawn a new task on panic unless it is cancelled.
+ let sentinel = Sentinel::new(&jobs);
+
+ loop {
+ let message = {
+ // Only lock jobs for the time it takes
+ // to get a job, not run it.
+ let lock = jobs.lock();
+ lock.recv_opt()
+ };
+
+ match message {
+ Ok(job) => job(),
+
+ // The Taskpool was dropped.
+ Err(..) => break
+ }
+ }
+
+ sentinel.cancel();
+ })
}
-#[test]
-#[should_fail]
-fn test_zero_tasks_panic() {
- let f: || -> proc(uint):Send -> uint = || { proc(i) i };
- TaskPool::new(0, f);
+#[cfg(test)]
+mod test {
+ use core::prelude::*;
+ use super::*;
+ use comm::channel;
+ use iter::range;
+
+ const TEST_TASKS: uint = 4u;
+
+ #[test]
+ fn test_works() {
+ use iter::AdditiveIterator;
+
+ let pool = TaskPool::new(TEST_TASKS);
+
+ let (tx, rx) = channel();
+ for _ in range(0, TEST_TASKS) {
+ let tx = tx.clone();
+ pool.execute(proc() {
+ tx.send(1u);
+ });
+ }
+
+ assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+ }
+
+ #[test]
+ #[should_fail]
+ fn test_zero_tasks_panic() {
+ TaskPool::new(0);
+ }
+
+ #[test]
+ fn test_recovery_from_subtask_panic() {
+ use iter::AdditiveIterator;
+
+ let pool = TaskPool::new(TEST_TASKS);
+
+ // Panic all the existing tasks.
+ for _ in range(0, TEST_TASKS) {
+ pool.execute(proc() { panic!() });
+ }
+
+ // Ensure new tasks were spawned to compensate.
+ let (tx, rx) = channel();
+ for _ in range(0, TEST_TASKS) {
+ let tx = tx.clone();
+ pool.execute(proc() {
+ tx.send(1u);
+ });
+ }
+
+ assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+ }
+
+ #[test]
+ fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
+ use sync::{Arc, Barrier};
+
+ let pool = TaskPool::new(TEST_TASKS);
+ let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
+
+ // Panic all the existing tasks in a bit.
+ for _ in range(0, TEST_TASKS) {
+ let waiter = waiter.clone();
+ pool.execute(proc() {
+ waiter.wait();
+ panic!();
+ });
+ }
+
+ drop(pool);
+
+ // Kick off the failure.
+ waiter.wait();
+ }
}
+