Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Introduce ThreadPoolBuilder::use_current_thread. #1063

Merged
merged 6 commits into from Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions rayon-core/Cargo.toml
Expand Up @@ -53,3 +53,7 @@ path = "tests/simple_panic.rs"
[[test]]
name = "scoped_threadpool"
path = "tests/scoped_threadpool.rs"

[[test]]
name = "use_current_thread"
path = "tests/use_current_thread.rs"
33 changes: 32 additions & 1 deletion rayon-core/src/lib.rs
Expand Up @@ -147,6 +147,7 @@ pub struct ThreadPoolBuildError {
#[derive(Debug)]
enum ErrorKind {
GlobalPoolAlreadyInitialized,
CurrentThreadAlreadyInPool,
IOError(io::Error),
}

Expand Down Expand Up @@ -174,6 +175,9 @@ pub struct ThreadPoolBuilder<S = DefaultSpawn> {
/// If RAYON_NUM_THREADS is invalid or zero will use the default.
num_threads: usize,

/// The thread we're building *from* will also be part of the pool.
use_current_thread: bool,

/// Custom closure, if any, to handle a panic that we cannot propagate
/// anywhere else.
panic_handler: Option<Box<PanicHandler>>,
Expand Down Expand Up @@ -227,6 +231,7 @@ impl Default for ThreadPoolBuilder {
fn default() -> Self {
ThreadPoolBuilder {
num_threads: 0,
use_current_thread: false,
panic_handler: None,
get_thread_name: None,
stack_size: None,
Expand Down Expand Up @@ -437,6 +442,7 @@ impl<S> ThreadPoolBuilder<S> {
spawn_handler: CustomSpawn::new(spawn),
// ..self
num_threads: self.num_threads,
use_current_thread: self.use_current_thread,
panic_handler: self.panic_handler,
get_thread_name: self.get_thread_name,
stack_size: self.stack_size,
Expand Down Expand Up @@ -529,6 +535,24 @@ impl<S> ThreadPoolBuilder<S> {
self
}

/// Use the current thread as one of the threads in the pool.
///
/// The current thread is guaranteed to be at index 0, and since the thread is not managed by
/// rayon, the spawn and exit handlers do not run for that thread.
///
/// Note that the current thread won't run the main work-stealing loop, so jobs spawned into
/// the thread-pool will generally not be picked up automatically by this thread unless you
/// yield to rayon in some way, like via [`yield_now()`], [`yield_local()`], or [`scope()`].
///
/// # Local thread-pools
///
/// Using this in a local thread-pool means the registry will be leaked. In future versions
/// there might be a way of cleaning up the current-thread state.
pub fn use_current_thread(mut self) -> Self {
self.use_current_thread = true;
self
}

/// Returns a copy of the current panic handler.
fn take_panic_handler(&mut self) -> Option<Box<PanicHandler>> {
self.panic_handler.take()
Expand Down Expand Up @@ -731,18 +755,22 @@ impl ThreadPoolBuildError {
const GLOBAL_POOL_ALREADY_INITIALIZED: &str =
"The global thread pool has already been initialized.";

const CURRENT_THREAD_ALREADY_IN_POOL: &str =
"The current thread is already part of another thread pool.";

impl Error for ThreadPoolBuildError {
#[allow(deprecated)]
fn description(&self) -> &str {
match self.kind {
ErrorKind::GlobalPoolAlreadyInitialized => GLOBAL_POOL_ALREADY_INITIALIZED,
ErrorKind::CurrentThreadAlreadyInPool => CURRENT_THREAD_ALREADY_IN_POOL,
ErrorKind::IOError(ref e) => e.description(),
}
}

fn source(&self) -> Option<&(dyn Error + 'static)> {
match &self.kind {
ErrorKind::GlobalPoolAlreadyInitialized => None,
ErrorKind::GlobalPoolAlreadyInitialized | ErrorKind::CurrentThreadAlreadyInPool => None,
ErrorKind::IOError(e) => Some(e),
}
}
Expand All @@ -751,6 +779,7 @@ impl Error for ThreadPoolBuildError {
impl fmt::Display for ThreadPoolBuildError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.kind {
ErrorKind::CurrentThreadAlreadyInPool => CURRENT_THREAD_ALREADY_IN_POOL.fmt(f),
ErrorKind::GlobalPoolAlreadyInitialized => GLOBAL_POOL_ALREADY_INITIALIZED.fmt(f),
ErrorKind::IOError(e) => e.fmt(f),
}
Expand All @@ -768,6 +797,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let ThreadPoolBuilder {
ref num_threads,
ref use_current_thread,
ref get_thread_name,
ref panic_handler,
ref stack_size,
Expand All @@ -792,6 +822,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {

f.debug_struct("ThreadPoolBuilder")
.field("num_threads", num_threads)
.field("use_current_thread", use_current_thread)
.field("get_thread_name", &get_thread_name)
.field("panic_handler", &panic_handler)
.field("stack_size", &stack_size)
Expand Down
40 changes: 20 additions & 20 deletions rayon-core/src/registry.rs
Expand Up @@ -207,26 +207,7 @@ fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
// is stubbed out, and we won't have to change anything if they do add real threading.
let unsupported = matches!(&result, Err(e) if e.is_unsupported());
if unsupported && WorkerThread::current().is_null() {
let builder = ThreadPoolBuilder::new()
.num_threads(1)
.spawn_handler(|thread| {
// Rather than starting a new thread, we're just taking over the current thread
// *without* running the main loop, so we can still return from here.
// The WorkerThread is leaked, but we never shutdown the global pool anyway.
let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));
let registry = &*worker_thread.registry;
let index = worker_thread.index;

unsafe {
WorkerThread::set_current(worker_thread);

// let registry know we are ready to do work
Latch::set(&registry.thread_infos[index].primed);
}

Ok(())
});

let builder = ThreadPoolBuilder::new().num_threads(1).use_current_thread();
let fallback_result = Registry::new(builder);
if fallback_result.is_ok() {
return fallback_result;
Expand Down Expand Up @@ -300,6 +281,25 @@ impl Registry {
stealer,
index,
};

if index == 0 && builder.use_current_thread {
if !WorkerThread::current().is_null() {
return Err(ThreadPoolBuildError::new(
ErrorKind::CurrentThreadAlreadyInPool,
));
}
// Rather than starting a new thread, we're just taking over the current thread
// *without* running the main loop, so we can still return from here.
// The WorkerThread is leaked, but we never shutdown the global pool anyway.
let worker_thread = Box::into_raw(Box::new(WorkerThread::from(thread)));

unsafe {
WorkerThread::set_current(worker_thread);
Latch::set(&registry.thread_infos[index].primed);
}
continue;
}

if let Err(e) = builder.get_spawn_handler().spawn(thread) {
return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
}
Expand Down
57 changes: 57 additions & 0 deletions rayon-core/tests/use_current_thread.rs
@@ -0,0 +1,57 @@
use rayon_core::ThreadPoolBuilder;
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};

#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn use_current_thread_basic() {
static JOIN_HANDLES: Mutex<Vec<JoinHandle<()>>> = Mutex::new(Vec::new());
let pool = ThreadPoolBuilder::new()
.num_threads(2)
.use_current_thread()
.spawn_handler(|builder| {
let handle = thread::Builder::new().spawn(|| builder.run())?;
JOIN_HANDLES.lock().unwrap().push(handle);
Ok(())
})
.build()
.unwrap();
assert_eq!(rayon_core::current_thread_index(), Some(0));
assert_eq!(
JOIN_HANDLES.lock().unwrap().len(),
1,
"Should only spawn one extra thread"
);

let another_pool = ThreadPoolBuilder::new()
.num_threads(2)
.use_current_thread()
.build();
assert!(
another_pool.is_err(),
"Should error if the thread is already part of a pool"
);

let pair = Arc::new((Mutex::new(false), Condvar::new()));
let pair2 = Arc::clone(&pair);
pool.spawn(move || {
assert_ne!(rayon_core::current_thread_index(), Some(0));
// This should execute even if the current thread is blocked, since we have two threads in
// the pool.
let &(ref started, ref condvar) = &*pair2;
*started.lock().unwrap() = true;
condvar.notify_one();
});

let _guard = pair
.1
.wait_while(pair.0.lock().unwrap(), |ran| !*ran)
.unwrap();
std::mem::drop(pool); // Drop the pool.

// Wait until all threads have actually exited. This is not really needed, other than to
// reduce noise of leak-checking tools.
for handle in std::mem::take(&mut *JOIN_HANDLES.lock().unwrap()) {
let _ = handle.join();
}
}