Skip to content

Commit

Permalink
Integrate improvements from Amanieu#44 and get rid of ThreadHolder
Browse files Browse the repository at this point in the history
  • Loading branch information
terrarier2111 committed Dec 14, 2022
1 parent 5cd62a0 commit 8549e06
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 71 deletions.
35 changes: 17 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ mod unreachable;
#[allow(deprecated)]
pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};

#[cfg(feature = "nightly")]
use crate::thread_id::ThreadHolder;
use std::cell::UnsafeCell;
use std::fmt;
use std::iter::FusedIterator;
Expand Down Expand Up @@ -193,17 +191,16 @@ impl<T: Send> ThreadLocal<T> {
/// Returns the element for the current thread, if it exists.
#[cfg(feature = "nightly")]
pub fn get(&self) -> Option<&T> {
match thread_id::try_get_thread_holder() {
match thread_id::try_get_thread() {
None => None,
Some(x) => self.get_inner(x.into_inner()),
Some(x) => self.get_inner(x),
}
}

/// Returns the element for the current thread, if it exists.
#[cfg(not(feature = "nightly"))]
pub fn get(&self) -> Option<&T> {
let thread = thread_id::get();
self.get_inner(thread)
thread_id::try_get().and_then(|thread| self.get_inner(thread))
}

/// Returns the element for the current thread, or creates it if it doesn't
Expand All @@ -226,8 +223,8 @@ impl<T: Send> ThreadLocal<T> {
where
F: FnOnce() -> Result<T, E>,
{
if let Some(thread) = thread_id::try_get_thread_holder() {
if let Some(inner) = self.get_inner(thread.into_inner()) {
if let Some(thread) = thread_id::try_get_thread() {
if let Some(inner) = self.get_inner(thread) {
return Ok(inner);
}
}
Expand All @@ -242,11 +239,13 @@ impl<T: Send> ThreadLocal<T> {
where
F: FnOnce() -> Result<T, E>,
{
let thread = thread_id::get();
match self.get_inner(thread) {
Some(x) => Ok(x),
None => Ok(self.insert(thread, create()?)),
let thread = thread_id::try_get();
if let Some(thread) = thread {
if let Some(val) = self.get_inner(thread) {
return Ok(val);
}
}
Ok(self.insert(create()?))
}

fn get_inner(&self, thread: Thread) -> Option<&T> {
Expand All @@ -269,20 +268,20 @@ impl<T: Send> ThreadLocal<T> {
#[cold]
#[cfg(feature = "nightly")]
fn insert(&self, data: T) -> &T {
let thread = if let Some(thread) = thread_id::try_get_thread_holder() {
thread.into_inner()
let thread = if let Some(thread) = thread_id::try_get_thread() {
thread
} else {
let thread = ThreadHolder::new();
thread_id::set_thread_holder(thread);
unsafe { thread_id::try_get_thread_holder().unwrap_unchecked() }.into_inner()
thread_id::set_thread();
unsafe { thread_id::try_get_thread().unwrap_unchecked() }
};

self.insert_inner(thread, data)
}

#[cold]
#[cfg(not(feature = "nightly"))]
fn insert(&self, thread: Thread, data: T) -> &T {
fn insert(&self, data: T) -> &T {
let thread = thread_id::get();
self.insert_inner(thread, data)
}

Expand Down
107 changes: 54 additions & 53 deletions src/thread_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use cfg_if::cfg_if;
use once_cell::sync::Lazy;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::mem::transmute;
use std::sync::Mutex;
use std::usize;

Expand Down Expand Up @@ -75,64 +74,31 @@ impl Thread {
}
}

/// Wrapper around `Thread` that allocates and deallocates the ID.
#[derive(Clone)]
pub(crate) struct ThreadHolder<const OWNER: bool>(Thread);

impl<const OWNER: bool> ThreadHolder<OWNER> {
#[inline(always)]
pub(crate) fn into_inner(self) -> Thread {
self.0
}
}

cfg_if! {
if #[cfg(feature = "nightly")] {
impl ThreadHolder<true> {
pub(crate) fn new() -> Self {
// we have to initialize `THREAD_HOLDER_GUARD` in order for it to protect
// `THREAD_HOLDER` when it gets initialized
THREAD_HOLDER_GUARD.with(|_| {});
ThreadHolder(Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc()))
}
}
} else {
impl ThreadHolder<true> {
fn new() -> Self {
ThreadHolder(Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc()))
}
}
}
}

impl<const OWNER: bool> Drop for ThreadHolder<OWNER> {
fn drop(&mut self) {
if OWNER {
THREAD_ID_MANAGER.lock().unwrap().free(self.0.id);
}
}
}
// Guard to ensure the thread ID is released on thread exit.
struct ThreadGuard;

cfg_if! {
if #[cfg(feature = "nightly")] {
#[thread_local]
static mut THREAD_HOLDER: Option<ThreadHolder<true>> = None;

thread_local! { static THREAD_HOLDER_GUARD: ThreadHolderGuard = const { ThreadHolderGuard }; }
static mut THREAD: Option<Thread> = None;

struct ThreadHolderGuard;
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard }; }

impl Drop for ThreadHolderGuard {
impl Drop for ThreadGuard {
fn drop(&mut self) {
// SAFETY: this is safe because we know that we (the current thread)
// are the only one who can be accessing our `THREAD_HOLDER` and thus
// are the only one who can be accessing our `THREAD` and thus
// it's safe for us to access and drop it.
unsafe { THREAD_HOLDER.take(); }
if let Some(thread) = unsafe { THREAD.take() } {
THREAD_ID_MANAGER.lock().unwrap().free(thread.id);
}
}
}

#[inline]
pub(crate) fn try_get_thread_holder() -> Option<ThreadHolder<false>> {
pub(crate) fn try_get_thread() -> Option<Thread> {
use std::mem::transmute;

// SAFETY: this is safe as the only two possibilities for updates
// are when this thread gets stopped or when the thread holder
// gets first set (which is no problem for this as it can't happen
Expand All @@ -141,27 +107,62 @@ cfg_if! {
// the transmute is safe because the only thing we are changing
// with it is the const generic parameter to a more restrictive
// one which is safe
unsafe { transmute(THREAD_HOLDER.clone()) }
unsafe { transmute(THREAD.clone()) }
}

#[inline]
pub(crate) fn set_thread_holder(thread_holder: ThreadHolder<true>) {
pub(crate) fn set_thread() {
// we have to initialize `THREAD_GUARD` in order for it to protect
// `THREAD_HOLDER` when it gets initialized
THREAD_GUARD.with(|_| {});
let thread = Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc());
// SAFETY: this is safe because we know that there are no references
// to `THREAD_HOLDER` alive when this function gets called
// to `THREAD` alive when this function gets called
// and thus we don't have to care about potential unsafety
// because of references, because there are none
// also the data is thread local which means that
// it's impossible for data races to occur
unsafe { THREAD_HOLDER = Some(thread_holder); }
unsafe { THREAD = Some(thread); }
}

} else {
thread_local!(static THREAD_HOLDER: ThreadHolder<true> = ThreadHolder::new());
use std::cell::Cell;

// This is split into 2 thread-local variables so that we can check whether the
// thread is initialized without having to register a thread-local destructor.
//
// This makes the fast path smaller.
thread_local! { static THREAD: Cell<Option<Thread>> = const { Cell::new(None) }; }
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard }; }

/// Get the current thread.
/// Returns a thread ID for the current thread, allocating one if needed.
#[inline]
pub(crate) fn get() -> Thread {
THREAD_HOLDER.with(|holder| holder.0)
THREAD.with(|thread| {
if let Some(thread) = thread.get() {
thread
} else {
debug_assert!(thread.get().is_none());
let new = Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc());
thread.set(Some(new));
THREAD_GUARD.with(|_| {});
new
}
})
}

/// Attempts to get the current thread if `get` has previously been
/// called.
#[inline]
pub(crate) fn try_get() -> Option<Thread> {
THREAD.with(|thread| thread.get())
}

impl Drop for ThreadGuard {
fn drop(&mut self) {
let thread = THREAD.with(|thread| thread.get()).unwrap();
THREAD_ID_MANAGER.lock().unwrap().free(thread.id);
}
}
}
}
Expand Down

0 comments on commit 8549e06

Please sign in to comment.