Skip to content

Commit

Permalink
Add ThreadPool::broadcast
Browse files Browse the repository at this point in the history
A broadcast runs the closure on every thread in the pool, then collects
the results.  It's scheduled somewhat like a very soft interrupt -- it
won't preempt a thread's local work, but will run before it goes to
steal from any other threads.

This can be used when you want to precisely split your work per-thread,
or to set or retrieve some thread-local data in the pool, e.g. #483.
  • Loading branch information
cuviper committed Oct 3, 2018
1 parent df86443 commit 4e05307
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 7 deletions.
1 change: 1 addition & 0 deletions rayon-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ mod test;
#[cfg(rayon_unstable)]
pub mod internal;
pub use thread_pool::ThreadPool;
pub use thread_pool::broadcast;
pub use thread_pool::current_thread_index;
pub use thread_pool::current_thread_has_pending_tasks;
pub use join::{join, join_context};
Expand Down
1 change: 1 addition & 0 deletions rayon-core/src/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub enum Event {
WaitUntil { worker: usize },
LatchSet { worker: usize },
InjectJobs { count: usize },
BroadcastJobs { count: usize },
Join { worker: usize },
PoppedJob { worker: usize },
PoppedRhs { worker: usize },
Expand Down
120 changes: 113 additions & 7 deletions rayon-core/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub struct Registry {

/// Queues owned by the registry and kept behind its `state` mutex.
struct RegistryState {
/// Worker half of the shared injection deque; the matching stealer half is
/// stored in the registry as `job_uninjector`.
job_injector: Deque<JobRef>,
/// One deque per pool thread, in thread-index order. `inject_all` pushes
/// exactly one job onto each, and every worker thread holds the stealer
/// half of its own entry.
direct_injectors: Vec<Deque<JobRef>>,
}

/// ////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -102,6 +103,12 @@ impl Registry {

let inj_worker = Deque::new();
let inj_stealer = inj_worker.stealer();

let dinj_workers: Vec<_> = (0..n_threads)
.map(|_| Deque::new())
.collect();
let dinj_stealers: Vec<_> = dinj_workers.iter().map(|d| d.stealer()).collect();

let workers: Vec<_> = (0..n_threads)
.map(|_| Deque::new())
.collect();
Expand All @@ -111,7 +118,7 @@ impl Registry {
thread_infos: stealers.into_iter()
.map(|s| ThreadInfo::new(s))
.collect(),
state: Mutex::new(RegistryState::new(inj_worker)),
state: Mutex::new(RegistryState::new(inj_worker, dinj_workers)),
sleep: Sleep::new(),
job_uninjector: inj_stealer,
terminate_latch: CountLatch::new(),
Expand All @@ -123,7 +130,7 @@ impl Registry {
// If we return early or panic, make sure to terminate existing threads.
let t1000 = Terminator(&registry);

for (index, worker) in workers.into_iter().enumerate() {
for (index, (worker, stealer)) in workers.into_iter().zip(dinj_stealers).enumerate() {
let registry = registry.clone();
let mut b = thread::Builder::new();
if let Some(name) = builder.get_thread_name(index) {
Expand All @@ -132,7 +139,7 @@ impl Registry {
if let Some(stack_size) = builder.get_stack_size() {
b = b.stack_size(stack_size);
}
if let Err(e) = b.spawn(move || unsafe { main_loop(worker, registry, index, breadth_first) }) {
if let Err(e) = b.spawn(move || unsafe { main_loop(worker, stealer, registry, index, breadth_first) }) {
return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)))
}
}
Expand Down Expand Up @@ -287,7 +294,7 @@ impl Registry {
}

/// Push a job into the "external jobs" queue; it will be taken by
/// whatever worker has nothing to do. Use this is you know that
/// whatever worker has nothing to do. Use this if you know that
/// you are not on a worker of this registry.
pub fn inject(&self, injected_jobs: &[JobRef]) {
log!(InjectJobs { count: injected_jobs.len() });
Expand Down Expand Up @@ -321,6 +328,89 @@ impl Registry {
}
}

/// Hand one job to every thread of the pool through that thread's private
/// "external jobs" queue.  Job `i` is only ever run by thread `i`, once that
/// thread has run out of local work and before it tries to steal from
/// anybody else.
///
/// **Panics** if not given exactly as many jobs as there are threads.
pub fn inject_all(&self, injected_jobs: &[JobRef]) {
    let job_count = injected_jobs.len();
    assert_eq!(self.num_threads(), job_count);
    log!(BroadcastJobs { count: job_count });
    {
        // The per-thread queues live behind the state mutex; the guard is
        // released at the end of this scope, before any thread is woken.
        let state = self.state.lock().unwrap();

        // The terminate latch should never be set at this point: it is
        // only set once the user has dropped their `ThreadPool`, and
        // without the pool there is nothing left to call `inject_all()`
        // on.
        assert!(!self.terminate_latch.probe(), "inject_all() sees state.terminate as true");

        let queues = &state.direct_injectors;
        assert_eq!(queues.len(), job_count);
        for (queue, job) in queues.iter().zip(injected_jobs) {
            queue.push(*job);
        }
    }
    self.sleep.tickle(usize::MAX);
}

/// Run `op` once on each thread of this pool and gather the return values.
/// Every thread picks up its copy of `op` when it has no local work left,
/// before it goes stealing from other threads.  The call blocks until all
/// threads have finished their `op`.
pub fn broadcast<OP, R>(&self, op: OP) -> Vec<R>
    where OP: Fn(&WorkerThread) -> R + Sync,
          R: Send
{
    unsafe {
        // `as_ref()` yields `None` for a null pointer, which is how a caller
        // that is not on any worker thread shows up here.
        match WorkerThread::current().as_ref() {
            // Called from outside any pool: block on a lock latch.
            None => self.broadcast_jobs(op, LockLatch::new, |latch| latch.wait()),

            // Called from one of our own workers: wait on a spin latch via
            // the worker's `wait_until`.
            Some(caller) if caller.registry().id() == self.id() => {
                self.broadcast_jobs(op,
                                    SpinLatch::new,
                                    |latch| caller.wait_until(latch))
            }

            // Called from a worker of a different pool: the latch also
            // tickles the caller's own registry when it is set.
            Some(caller) => {
                let caller_sleep = &caller.registry().sleep;
                self.broadcast_jobs(op,
                                    || TickleLatch::new(SpinLatch::new(), caller_sleep),
                                    |latch| caller.wait_until(latch))
            }
        }
    }
}

/// Shared machinery behind `broadcast`, generic over the kind of latch the
/// caller wants to wait on.
///
/// `latch` is called once per thread to create that job's latch, and `wait`
/// is called once per job to block until its latch has been set.  The
/// returned vector holds one result per thread, in thread-index order.
unsafe fn broadcast_jobs<OP, R, L, New, Wait>(&self, op: OP, latch: New, wait: Wait) -> Vec<R>
    where OP: Fn(&WorkerThread) -> R + Sync,
          R: Send,
          L: Latch + Sync,
          New: Fn() -> L,
          Wait: Fn(&L),
{
    // Body shared by every job: it must have been injected and must be
    // running on a worker thread, whose handle is passed on to `op`.
    let run_on_worker = |injected| {
        let current = WorkerThread::current();
        assert!(injected && !current.is_null());
        op(&*current)
    };

    let thread_count = self.thread_infos.len();
    let stack_jobs: Vec<_> = (0..thread_count)
        .map(|_| StackJob::new(&run_on_worker, latch()))
        .collect();
    let job_refs: Vec<_> = stack_jobs.iter().map(|job| job.as_job_ref()).collect();

    self.inject_all(&job_refs);

    // Block until every job's latch has been waited on; `stack_jobs` is not
    // moved or consumed before that.
    stack_jobs.iter().for_each(|job| wait(&job.latch));

    // Unpack each job's outcome; a job that panicked may propagate its
    // panic from here.
    let mut results = Vec::with_capacity(thread_count);
    for job in stack_jobs {
        results.push(job.into_result());
    }
    results
}

/// If already in a worker-thread of this registry, just execute `op`.
/// Otherwise, inject `op` in this thread-pool. Either way, block until `op`
/// completes and return its return value. If `op` panics, that panic will
Expand Down Expand Up @@ -417,9 +507,10 @@ pub struct RegistryId {
}

impl RegistryState {
/// Bundles the shared injector deque with the per-thread direct-injection
/// deques; both are stored unchanged in the returned state.
pub fn new(job_injector: Deque<JobRef>) -> RegistryState {
pub fn new(job_injector: Deque<JobRef>, direct_injectors: Vec<Deque<JobRef>>) -> RegistryState {
RegistryState {
job_injector: job_injector,
direct_injectors: direct_injectors,
}
}
}
Expand Down Expand Up @@ -455,6 +546,9 @@ pub struct WorkerThread {
/// the "worker" half of our local deque
worker: Deque<JobRef>,

/// the "stealer" half of the worker's direct injection deque
stealer: Stealer<JobRef>,

index: usize,

/// are these workers configured to steal breadth-first or not?
Expand Down Expand Up @@ -523,16 +617,26 @@ impl WorkerThread {
#[inline]
/// Takes the next job reserved for this worker: first from its own deque,
/// then from its direct-injection deque (`self.stealer`).
pub unsafe fn take_local_job(&self) -> Option<JobRef> {
// Not breadth-first: pop from the worker half of our own deque.
if !self.breadth_first {
self.worker.pop()
if let opt_job @ Some(_) = self.worker.pop() {
return opt_job;
}
} else {
// Breadth-first: take from our own deque with `steal()` instead, looping
// again whenever it reports `Steal::Retry`.
loop {
match self.worker.steal() {
Steal::Empty => return None,
Steal::Empty => break,
Steal::Data(d) => return Some(d),
Steal::Retry => {},
}
}
}

// Our own deque had nothing: fall back to the jobs injected directly for
// this thread, again retrying on `Steal::Retry`.
loop {
match self.stealer.steal() {
Steal::Empty => return None,
Steal::Data(d) => return Some(d),
Steal::Retry => {},
}
}
}

/// Wait until the latch is set. Try to keep busy by popping and
Expand Down Expand Up @@ -631,11 +735,13 @@ impl WorkerThread {
/// ////////////////////////////////////////////////////////////////////////

unsafe fn main_loop(worker: Deque<JobRef>,
stealer: Stealer<JobRef>,
registry: Arc<Registry>,
index: usize,
breadth_first: bool) {
let worker_thread = WorkerThread {
worker: worker,
stealer: stealer,
breadth_first: breadth_first,
index: index,
rng: XorShift64Star::new(),
Expand Down
57 changes: 57 additions & 0 deletions rayon-core/src/thread_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,48 @@ impl ThreadPool {
self.registry.in_worker(|_, _| op())
}

/// Runs `op` once on each thread of this threadpool and returns the
/// results, one per thread. Any use of `join`, `scope`, or parallel
/// iterators from inside `op` will operate within this threadpool.
///
/// # Warning: thread-local data
///
/// `op` runs on the Rayon worker threads, not on the calling thread,
/// so thread-local data belonging to the caller is not accessible
/// from it.
///
/// # Panics
///
/// If `op` panics on one or more threads, exactly one panic is
/// propagated to the caller, and only after every thread has
/// completed (or panicked in) its own `op`.
///
/// # Examples
///
/// ```
/// # use rayon_core as rayon;
/// use std::sync::atomic::{AtomicUsize, Ordering};
///
/// fn main() {
///     let pool = rayon::ThreadPoolBuilder::new().num_threads(5).build().unwrap();
///
///     // The argument is the index of each thread
///     let v: Vec<usize> = pool.broadcast(|i| i * i);
///     assert_eq!(v, &[0, 1, 4, 9, 16]);
///
///     // The closure can reference the local stack
///     let count = AtomicUsize::new(0);
///     pool.broadcast(|_| count.fetch_add(1, Ordering::Relaxed));
///     assert_eq!(count.into_inner(), 5);
/// }
/// ```
pub fn broadcast<OP, R>(&self, op: OP) -> Vec<R>
    where OP: Fn(usize) -> R + Sync,
          R: Send
{
    // The registry hands each job its worker thread; callers only get to
    // see that worker's index.
    self.registry.broadcast(|worker| {
        let index = worker.index();
        op(index)
    })
}

/// Returns the (current) number of threads in the thread pool.
///
/// # Future compatibility note
Expand Down Expand Up @@ -316,3 +358,18 @@ pub fn current_thread_has_pending_tasks() -> Option<bool> {
}
}
}

/// Runs `op` once on each thread of the current threadpool and returns the
/// results. When called from a non-Rayon thread, the global threadpool is
/// used instead. Any use of `join`, `scope`, or parallel iterators from
/// inside `op` will operate within that threadpool.
///
/// For more information, see the [`ThreadPool::broadcast()`][m] method.
///
/// [m]: struct.ThreadPool.html#method.broadcast
pub fn broadcast<OP, R>(op: OP) -> Vec<R>
    where OP: Fn(usize) -> R + Sync,
          R: Send
{
    let registry = Registry::current();
    // Only the worker's index is exposed to the user's closure.
    registry.broadcast(|worker| {
        let index = worker.index();
        op(index)
    })
}
85 changes: 85 additions & 0 deletions rayon-core/src/thread_pool/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,88 @@ fn check_thread_pool_new() {
let pool = ThreadPool::new(Configuration::new().num_threads(22)).unwrap();
assert_eq!(pool.current_num_threads(), 22);
}

/// Broadcasting on the global pool yields every thread index, in order.
#[test]
fn broadcast_global() {
    let indices = ::broadcast(|index| index);
    let expected = 0..::current_num_threads();
    assert!(indices.into_iter().eq(expected));
}

/// Broadcasting on an explicit 7-thread pool yields the indices 0 through 6.
#[test]
fn broadcast_pool() {
    let builder = ThreadPoolBuilder::new().num_threads(7);
    let pool = builder.build().unwrap();
    let indices = pool.broadcast(|index| index);
    assert!(indices.into_iter().eq(0..7));
}

/// The free `broadcast` function, called from inside a pool, targets that
/// same pool.
#[test]
fn broadcast_self() {
    let builder = ThreadPoolBuilder::new().num_threads(7);
    let pool = builder.build().unwrap();
    let indices = pool.install(|| ::broadcast(|index| index));
    assert!(indices.into_iter().eq(0..7));
}

/// Two pools broadcasting into each other: from inside `pool1`, every thread
/// of `pool2` broadcasts back into `pool1`, so the counter is bumped once
/// per (pool2 thread, pool1 thread) pair.
#[test]
fn broadcast_mutual() {
    let count = AtomicUsize::new(0);
    let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap();
    let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();

    let bump = |_: usize| {
        count.fetch_add(1, Ordering::Relaxed);
    };
    pool1.install(|| pool2.broadcast(|_| pool1.broadcast(&bump)));

    assert_eq!(count.into_inner(), 3 * 7);
}

/// Same cross-pool broadcast as `broadcast_mutual`, but with sleeps inserted
/// at every level before the work is done.
#[test]
fn broadcast_mutual_sleepy() {
    use std::{thread, time};

    let one_second = time::Duration::from_secs(1);
    let short_nap = time::Duration::from_millis(100);

    let count = AtomicUsize::new(0);
    let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap();
    let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();

    let nap_then_bump = |_: usize| {
        thread::sleep(short_nap);
        count.fetch_add(1, Ordering::Relaxed);
    };
    pool1.install(|| {
        thread::sleep(one_second);
        pool2.broadcast(|_| {
            thread::sleep(one_second);
            pool1.broadcast(&nap_then_bump)
        })
    });

    assert_eq!(count.into_inner(), 3 * 7);
}

/// A panic on a single thread (index 3) is propagated to the caller, but
/// only after all seven threads have run the closure.
#[test]
fn broadcast_panic_one() {
    let count = AtomicUsize::new(0);
    let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();

    let outcome = unwind::halt_unwinding(|| {
        pool.broadcast(|index| {
            count.fetch_add(1, Ordering::Relaxed);
            match index {
                3 => panic!("Hello, world!"),
                _ => {}
            }
        })
    });

    assert_eq!(count.into_inner(), 7);
    assert!(outcome.is_err(), "broadcast panic should propagate!");
}

/// Panics on several threads (every even index) still surface as an error
/// to the caller, and every thread still gets to run the closure.
#[test]
fn broadcast_panic_many() {
    let count = AtomicUsize::new(0);
    let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();

    let outcome = unwind::halt_unwinding(|| {
        pool.broadcast(|index| {
            count.fetch_add(1, Ordering::Relaxed);
            let is_even = index % 2 == 0;
            if is_even {
                panic!("Hello, world!");
            }
        })
    });

    assert_eq!(count.into_inner(), 7);
    assert!(outcome.is_err(), "broadcast panic should propagate!");
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,4 @@ pub use rayon_core::{join, join_context};
pub use rayon_core::FnContext;
pub use rayon_core::{scope, Scope};
pub use rayon_core::spawn;
pub use rayon_core::broadcast;

0 comments on commit 4e05307

Please sign in to comment.