Skip to content

Commit

Permalink
Merge pull request #44 from Amanieu/split-tls
Browse files Browse the repository at this point in the history
  • Loading branch information
Amanieu committed Feb 8, 2023
2 parents 3472eb1 + cedabb1 commit 1466993
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 25 deletions.
8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@ license = "MIT OR Apache-2.0"
repository = "https://github.com/Amanieu/thread_local-rs"
readme = "README.md"
keywords = ["thread_local", "concurrent", "thread"]
edition = "2018"
edition = "2021"

[features]
# this feature provides performance improvements using nightly features
nightly = []

[badges]
travis-ci = { repository = "Amanieu/thread_local-rs" }

[dependencies]
once_cell = "1.5.2"
# this is required to gate `nightly` related code paths
cfg-if = "1.0.0"

# This is actually a dev-dependency, see https://github.com/rust-lang/cargo/issues/1596
criterion = { version = "0.4.0", optional = true }
Expand Down
17 changes: 10 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@

#![warn(missing_docs)]
#![allow(clippy::mutex_atomic)]
#![cfg_attr(feature = "nightly", feature(thread_local))]

mod cached;
mod thread_id;
Expand Down Expand Up @@ -189,8 +190,7 @@ impl<T: Send> ThreadLocal<T> {

/// Returns the element for the current thread, if it exists.
pub fn get(&self) -> Option<&T> {
let thread = thread_id::get();
self.get_inner(thread)
thread_id::try_get().and_then(|thread| self.get_inner(thread))
}

/// Returns the element for the current thread, or creates it if it doesn't
Expand All @@ -212,11 +212,13 @@ impl<T: Send> ThreadLocal<T> {
where
F: FnOnce() -> Result<T, E>,
{
let thread = thread_id::get();
match self.get_inner(thread) {
Some(x) => Ok(x),
None => Ok(self.insert(thread, create()?)),
let thread = thread_id::try_get();
if let Some(thread) = thread {
if let Some(val) = self.get_inner(thread) {
return Ok(val);
}
}
Ok(self.insert(create()?))
}

fn get_inner(&self, thread: Thread) -> Option<&T> {
Expand All @@ -237,7 +239,8 @@ impl<T: Send> ThreadLocal<T> {
}

#[cold]
fn insert(&self, thread: Thread, data: T) -> &T {
fn insert(&self, data: T) -> &T {
let thread = thread_id::get();
let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);

Expand Down
104 changes: 87 additions & 17 deletions src/thread_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,95 @@ impl Thread {
}
}

/// Wrapper around `Thread` that allocates and deallocates the ID.
struct ThreadHolder(Thread);
impl ThreadHolder {
fn new() -> ThreadHolder {
ThreadHolder(Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc()))
}
}
impl Drop for ThreadHolder {
fn drop(&mut self) {
THREAD_ID_MANAGER.lock().unwrap().free(self.0.id);
}
}
cfg_if::cfg_if! {
if #[cfg(feature = "nightly")] {
// This is split into 2 thread-local variables so that we can check whether the
// thread is initialized without having to register a thread-local destructor.
//
// This makes the fast path smaller.
#[thread_local]
static mut THREAD: Option<Thread> = None;
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard }; }

// Guard to ensure the thread ID is released on thread exit.
struct ThreadGuard;

impl Drop for ThreadGuard {
fn drop(&mut self) {
// SAFETY: this is safe because we know that we (the current thread)
// are the only one who can be accessing our `THREAD` and thus
// it's safe for us to access and drop it.
if let Some(thread) = unsafe { THREAD.take() } {
THREAD_ID_MANAGER.lock().unwrap().free(thread.id);
}
}
}

/// Attempts to get the current thread if `get` has previously been
/// called.
#[inline]
pub(crate) fn try_get() -> Option<Thread> {
unsafe {
THREAD
}
}

thread_local!(static THREAD_HOLDER: ThreadHolder = ThreadHolder::new());
/// Returns a thread ID for the current thread, allocating one if needed.
#[inline]
pub(crate) fn get() -> Thread {
if let Some(thread) = unsafe { THREAD } {
thread
} else {
let new = Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc());
unsafe {
THREAD = Some(new);
}
THREAD_GUARD.with(|_| {});
new
}
}
} else {
use std::cell::Cell;

// This is split into 2 thread-local variables so that we can check whether the
// thread is initialized without having to register a thread-local destructor.
//
// This makes the fast path smaller.
thread_local! { static THREAD: Cell<Option<Thread>> = const { Cell::new(None) }; }
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard }; }

// Guard to ensure the thread ID is released on thread exit.
struct ThreadGuard;

impl Drop for ThreadGuard {
fn drop(&mut self) {
let thread = THREAD.with(|thread| thread.get()).unwrap();
THREAD_ID_MANAGER.lock().unwrap().free(thread.id);
}
}

/// Get the current thread.
#[inline]
pub(crate) fn get() -> Thread {
THREAD_HOLDER.with(|holder| holder.0)
/// Attempts to get the current thread if `get` has previously been
/// called.
#[inline]
pub(crate) fn try_get() -> Option<Thread> {
THREAD.with(|thread| thread.get())
}

/// Returns a thread ID for the current thread, allocating one if needed.
#[inline]
pub(crate) fn get() -> Thread {
THREAD.with(|thread| {
if let Some(thread) = thread.get() {
thread
} else {
let new = Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc());
thread.set(Some(new));
THREAD_GUARD.with(|_| {});
new
}
})
}
}
}

#[test]
Expand Down

0 comments on commit 1466993

Please sign in to comment.