]> git.lizzy.rs Git - rust.git/commitdiff
Rewrite std::sync::TaskPool to be load balancing and panic-resistant
authorJonathan Reem <jonathan.reem@gmail.com>
Fri, 14 Nov 2014 02:04:28 +0000 (18:04 -0800)
committerJonathan Reem <jonathan.reem@gmail.com>
Fri, 14 Nov 2014 06:57:33 +0000 (22:57 -0800)
The previous implementation was very likely to cause panics during
unwinding through this process:

- child panics, drops its receiver
- taskpool comes back around and sends another job over to that child
- the child receiver has hung up, so the taskpool panics on send
- during unwinding, the taskpool attempts to send a quit message to
  the child, causing a panic during unwinding
- panic during unwinding causes a process abort

This meant that TaskPool upgraded any child panic to a full process
abort. This came up in Iron when it caused crashes in long-running
servers.

This implementation uses a single channel to communicate between
spawned tasks and the TaskPool, which significantly reduces the complexity
of the implementation and cuts down on allocation. The TaskPool uses
the channel as a single-producer-multiple-consumer queue.

Additionally, through the use of send_opt and recv_opt instead of
send and recv, this TaskPool is robust on the face of child panics,
both before, during, and after the TaskPool itself is dropped.

Due to the TaskPool no longer using an `init_fn_factory`, this is a

[breaking-change]

otherwise, the API has not changed.

If you used `init_fn_factory` in your code, and this change breaks for
you, you can instead use an `AtomicUint` counter and a channel to
move information into child tasks.

src/libstd/sync/task_pool.rs

index d4a60fb584457e90088b73fa9739837af669d3ef..2682582d708a871084e040584c75ab7b2ae7e475 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright 2012 The Rust Project Developers. See the COPYRIGHT
+// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
 // file at the top-level directory of this distribution and at
 // http://rust-lang.org/COPYRIGHT.
 //
 
 use core::prelude::*;
 
-use task;
 use task::spawn;
-use vec::Vec;
-use comm::{channel, Sender};
+use comm::{channel, Sender, Receiver};
+use sync::{Arc, Mutex};
 
-enum Msg<T> {
-    Execute(proc(&T):Send),
-    Quit
+struct Sentinel<'a> {
+    jobs: &'a Arc<Mutex<Receiver<proc(): Send>>>,
+    active: bool
 }
 
-/// A task pool used to execute functions in parallel.
-pub struct TaskPool<T> {
-    channels: Vec<Sender<Msg<T>>>,
-    next_index: uint,
+impl<'a> Sentinel<'a> {
+    fn new(jobs: &Arc<Mutex<Receiver<proc(): Send>>>) -> Sentinel {
+        Sentinel {
+            jobs: jobs,
+            active: true
+        }
+    }
+
+    // Cancel and destroy this sentinel.
+    fn cancel(mut self) {
+        self.active = false;
+    }
 }
 
 #[unsafe_destructor]
-impl<T> Drop for TaskPool<T> {
+impl<'a> Drop for Sentinel<'a> {
     fn drop(&mut self) {
-        for channel in self.channels.iter_mut() {
-            channel.send(Quit);
+        if self.active {
+            spawn_in_pool(self.jobs.clone())
         }
     }
 }
 
-impl<T> TaskPool<T> {
-    /// Spawns a new task pool with `n_tasks` tasks. The provided
-    /// `init_fn_factory` returns a function which, given the index of the
-    /// task, should return local data to be kept around in that task.
+/// A task pool used to execute functions in parallel.
+///
+/// Spawns `n` worker tasks and replenishes the pool if any worker tasks
+/// panic.
+///
+/// # Example
+///
+/// ```rust
+/// # use sync::TaskPool;
+/// # use iter::AdditiveIterator;
+///
+/// let pool = TaskPool::new(4u);
+///
+/// let (tx, rx) = channel();
+/// for _ in range(0, 8u) {
+///     let tx = tx.clone();
+///     pool.execute(proc() {
+///         tx.send(1u);
+///     });
+/// }
+///
+/// assert_eq!(rx.iter().take(8u).sum(), 8u);
+/// ```
+pub struct TaskPool {
+    // How the taskpool communicates with subtasks.
+    //
+    // This is the only such Sender, so when it is dropped all subtasks will
+    // quit.
+    jobs: Sender<proc(): Send>
+}
+
+impl TaskPool {
+    /// Spawns a new task pool with `tasks` tasks.
     ///
     /// # Panics
     ///
-    /// This function will panic if `n_tasks` is less than 1.
-    pub fn new(n_tasks: uint,
-               init_fn_factory: || -> proc(uint):Send -> T)
-               -> TaskPool<T> {
-        assert!(n_tasks >= 1);
-
-        let channels = Vec::from_fn(n_tasks, |i| {
-            let (tx, rx) = channel::<Msg<T>>();
-            let init_fn = init_fn_factory();
-
-            let task_body = proc() {
-                let local_data = init_fn(i);
-                loop {
-                    match rx.recv() {
-                        Execute(f) => f(&local_data),
-                        Quit => break
-                    }
-                }
-            };
+    /// This function will panic if `tasks` is 0.
+    pub fn new(tasks: uint) -> TaskPool {
+        assert!(tasks >= 1);
 
-            // Run on this scheduler.
-            task::spawn(task_body);
+        let (tx, rx) = channel::<proc(): Send>();
+        let rx = Arc::new(Mutex::new(rx));
 
-            tx
-        });
+        // Taskpool tasks.
+        for _ in range(0, tasks) {
+            spawn_in_pool(rx.clone());
+        }
 
-        return TaskPool {
-            channels: channels,
-            next_index: 0,
-        };
+        TaskPool { jobs: tx }
     }
 
-    /// Executes the function `f` on a task in the pool. The function
-    /// receives a reference to the local data returned by the `init_fn`.
-    pub fn execute(&mut self, f: proc(&T):Send) {
-        self.channels[self.next_index].send(Execute(f));
-        self.next_index += 1;
-        if self.next_index == self.channels.len() { self.next_index = 0; }
+    /// Executes the function `job` on a task in the pool.
+    pub fn execute(&self, job: proc():Send) {
+        self.jobs.send(job);
     }
 }
 
-#[test]
-fn test_task_pool() {
-    let f: || -> proc(uint):Send -> uint = || { proc(i) i };
-    let mut pool = TaskPool::new(4, f);
-    for _ in range(0u, 8) {
-        pool.execute(proc(i) println!("Hello from thread {}!", *i));
-    }
+fn spawn_in_pool(jobs: Arc<Mutex<Receiver<proc(): Send>>>) {
+    spawn(proc() {
+        // Will spawn a new task on panic unless it is cancelled.
+        let sentinel = Sentinel::new(&jobs);
+
+        loop {
+            let message = {
+                // Only lock jobs for the time it takes
+                // to get a job, not run it.
+                let lock = jobs.lock();
+                lock.recv_opt()
+            };
+
+            match message {
+                Ok(job) => job(),
+
+                // The Taskpool was dropped.
+                Err(..) => break
+            }
+        }
+
+        sentinel.cancel();
+    })
 }
 
-#[test]
-#[should_fail]
-fn test_zero_tasks_panic() {
-    let f: || -> proc(uint):Send -> uint = || { proc(i) i };
-    TaskPool::new(0, f);
+#[cfg(test)]
+mod test {
+    use core::prelude::*;
+    use super::*;
+    use comm::channel;
+    use iter::range;
+
+    const TEST_TASKS: uint = 4u;
+
+    #[test]
+    fn test_works() {
+        use iter::AdditiveIterator;
+
+        let pool = TaskPool::new(TEST_TASKS);
+
+        let (tx, rx) = channel();
+        for _ in range(0, TEST_TASKS) {
+            let tx = tx.clone();
+            pool.execute(proc() {
+                tx.send(1u);
+            });
+        }
+
+        assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+    }
+
+    #[test]
+    #[should_fail]
+    fn test_zero_tasks_panic() {
+        TaskPool::new(0);
+    }
+
+    #[test]
+    fn test_recovery_from_subtask_panic() {
+        use iter::AdditiveIterator;
+
+        let pool = TaskPool::new(TEST_TASKS);
+
+        // Panic all the existing tasks.
+        for _ in range(0, TEST_TASKS) {
+            pool.execute(proc() { panic!() });
+        }
+
+        // Ensure new tasks were spawned to compensate.
+        let (tx, rx) = channel();
+        for _ in range(0, TEST_TASKS) {
+            let tx = tx.clone();
+            pool.execute(proc() {
+                tx.send(1u);
+            });
+        }
+
+        assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+    }
+
+    #[test]
+    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
+        use sync::{Arc, Barrier};
+
+        let pool = TaskPool::new(TEST_TASKS);
+        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
+
+        // Panic all the existing tasks in a bit.
+        for _ in range(0, TEST_TASKS) {
+            let waiter = waiter.clone();
+            pool.execute(proc() {
+                waiter.wait();
+                panic!();
+            });
+        }
+
+        drop(pool);
+
+        // Kick off the failure.
+        waiter.wait();
+    }
 }
+