Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

task: Add TaskLocalFuture::take_value method to the task_local #6340

Merged
merged 9 commits into from
Feb 21, 2024
70 changes: 70 additions & 0 deletions tokio/src/task/task_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,76 @@ pin_project! {
}
}

impl<T, F> TaskLocalFuture<T, F>
where
T: 'static,
{
/// Takes the task local value `T` owned by the `TaskLocalFuture`. If the
/// task local value exists, then returns `Some(T)` and the task local value
/// inside the `TaskLocalFuture` becomes unset. If it does not exist,
/// it returns `None`.
///
/// # Examples
///
/// ```
/// # async fn dox() {
/// tokio::task_local! {
/// static KEY: u32;
/// }
///
/// let fut = KEY.scope(42, async {
/// // Do some async work
/// });
///
/// let mut pinned = Box::pin(fut);
///
/// // Complete the TaskLocalFuture
/// let _ = (&mut pinned).as_mut().await;
mox692 marked this conversation as resolved.
Show resolved Hide resolved
///
/// // And here, we can take task local value
/// let value = pinned.as_mut().take_value();
///
/// assert_eq!(value, Some(42));
/// # }
/// ```
///
/// # Note
mox692 marked this conversation as resolved.
Show resolved Hide resolved
///
/// Note that this function attempts to take the task local value regardless of
/// whether the `TaskLocalFuture` is completed or not. This means that if you
/// call this function before the `TaskLocalFuture` is completed, you need to
/// make sure that the access to the task local value in the `TaskLocalFuture` is safe.
mox692 marked this conversation as resolved.
Show resolved Hide resolved
///
/// For example, the following code returns `Err` for accessing the `KEY` variable
/// in the async block of the `scope` function.
///
/// ```
/// # async fn dox() {
/// tokio::task_local! {
/// static KEY: u32;
/// }
///
/// let fut = KEY.scope(42, async {
/// // Since `take_value()` has already been called at this point,
/// // `try_with` here will fail.
/// assert!(KEY.try_with(|_| {}).is_err())
/// });
///
/// let mut pinned = Box::pin(fut);
///
/// // With this call, the task local value of fut is unset.
/// assert_eq!(pinned.as_mut().take_value(), Some(42));
///
/// // Poll **after** invoking `take_value()`
/// let _ = (&mut pinned).as_mut().await;
/// # }
/// ```
pub fn take_value(self: Pin<&mut Self>) -> Option<T> {
mox692 marked this conversation as resolved.
Show resolved Hide resolved
let this = self.project();
this.slot.take()
}
}

impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
type Output = F::Output;

Expand Down
28 changes: 28 additions & 0 deletions tokio/tests/task_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,31 @@ async fn task_local_available_on_completion_drop() {
assert_eq!(rx.await.unwrap(), 42);
h.await.unwrap();
}

#[tokio::test]
async fn take_value() {
tokio::task_local! {
static KEY: u32
}
let fut = KEY.scope(1, async {});
let mut pinned = Box::pin(fut);
assert_eq!(pinned.as_mut().take_value(), Some(1));
assert_eq!(pinned.as_mut().take_value(), None);
}

#[tokio::test]
async fn poll_after_take_value_should_fail() {
tokio::task_local! {
static KEY: u32
}
let fut = KEY.scope(1, async {
let result = KEY.try_with(|_| {});
// The task local value no longer exists.
assert!(result.is_err());
});
let mut fut = Box::pin(fut);
fut.as_mut().take_value();

// Poll the future after `take_value` has been called
fut.await;
}