rt: do not trace tasks while locking OwnedTasks #6036

Merged (9 commits, Oct 6, 2023)
tokio/src/runtime/handle.rs (8 additions, 0 deletions)

@@ -543,6 +543,14 @@ cfg_taskdump! {
                scheduler::Handle::MultiThreadAlt(_) => panic!("task dump not implemented for this runtime flavor"),
            }
        }
+
+        /// Produces `true` if the current task is being traced for a dump;
+        /// otherwise false. This function is only public for integration
+        /// testing purposes. Do not rely on it.
+        #[doc(hidden)]
+        pub fn is_tracing() -> bool {
+            super::task::trace::Context::is_tracing()
+        }
    }

cfg_rt_multi_thread! {
tokio/src/runtime/task/mod.rs (10 additions, 0 deletions)

@@ -444,6 +444,16 @@ impl<S: Schedule> UnownedTask<S> {
    }
}

+impl<S: 'static> Clone for Task<S> {
+    fn clone(&self) -> Task<S> {
+        // SAFETY: We increment the ref count.
+        unsafe {
+            self.raw.ref_inc();
+            Task::new(self.raw)
+        }
+    }
+}

impl<S: 'static> Drop for Task<S> {
    fn drop(&mut self) {
        // Decrement the ref count
tokio/src/runtime/task/trace/mod.rs (41 additions, 31 deletions)

@@ -100,6 +100,16 @@ impl Context {
            Self::try_with_current(|context| f(&context.collector)).expect(FAIL_NO_THREAD_LOCAL)
        }
    }
+
+    /// Produces `true` if the current task is being traced; otherwise false.
+    pub(crate) fn is_tracing() -> bool {
+        Self::with_current_collector(|maybe_collector| {
+            let collector = maybe_collector.take();
+            let result = collector.is_some();
+            maybe_collector.set(collector);
+            result
+        })
+    }
}
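
The take-then-set dance above is how the collector slot can be inspected at all: `Cell` hands out no shared reference to its contents, so the value is moved out, tested, and put back. A standalone illustration of the same pattern, using plain std types rather than tokio internals (the name `slot_is_occupied` is illustrative only):

```rust
use std::cell::Cell;

/// Report whether a `Cell<Option<T>>` currently holds a value, without
/// requiring `T: Copy` or `T: Clone`: move the value out, test it, restore it.
fn slot_is_occupied<T>(slot: &Cell<Option<T>>) -> bool {
    let value = slot.take(); // leaves `None` in the cell
    let occupied = value.is_some();
    slot.set(value); // put the original contents back
    occupied
}

fn main() {
    let slot: Cell<Option<String>> = Cell::new(Some("collector".to_owned()));
    assert!(slot_is_occupied(&slot));
    slot.set(None);
    assert!(!slot_is_occupied(&slot));
}
```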

impl Trace {

@@ -268,22 +278,8 @@ pub(in crate::runtime) fn trace_current_thread(
        drop(task);
    }

-    // notify each task
-    let mut tasks = vec![];
-    owned.for_each(|task| {
-        // set the notified bit
-        task.as_raw().state().transition_to_notified_for_tracing();
-        // store the raw tasks into a vec
-        tasks.push(task.as_raw());
-    });
-
-    tasks
-        .into_iter()
-        .map(|task| {
-            let ((), trace) = Trace::capture(|| task.poll());
-            trace
-        })
-        .collect()
+    // precondition: We have drained the tasks from the injection queue.
+    trace_owned(owned)
}

cfg_rt_multi_thread! {
@@ -316,21 +312,35 @@ cfg_rt_multi_thread! {

        drop(synced);

-        // notify each task
-        let mut traces = vec![];
-        owned.for_each(|task| {
-            // set the notified bit
-            task.as_raw().state().transition_to_notified_for_tracing();
-
-            // trace the task
-            let ((), trace) = Trace::capture(|| task.as_raw().poll());
-            traces.push(trace);
-
-            // reschedule the task
-            let _ = task.as_raw().state().transition_to_notified_by_ref();
-            task.as_raw().schedule();
-        });
-
-        traces
+        // precondition: we have drained the tasks from the local and injection
+        // queues.
+        trace_owned(owned)
    }
}

+/// Trace the `OwnedTasks`.
+///
+/// # Preconditions
+///
+/// This helper presumes exclusive access to each task. The tasks must not exist
+/// in any other queue.
+fn trace_owned<S>(owned: &OwnedTasks<Arc<S>>) -> Vec<Trace> {
+    // notify each task
+    let mut tasks = vec![];
+    owned.for_each(|task| {
+        // set the notified bit
+        task.as_raw().state().transition_to_notified_for_tracing();
+        // store the tasks into a vec
+        tasks.push(task.clone());
Contributor (review comment on the line above):
Okay, so, I've figured out how I think we should handle this refcount business:

  • First, we wrap transition_to_notified_for_tracing with a method on Task that returns a Notified. This Notified will own the refcount that we created inside of transition_to_notified_for_tracing.
  • Next, when you want to poll the task, instead of calling poll directly, you instead use OwnedTasks::assert_owner to convert the Notified into a LocalNotified.
  • To poll the task, you call LocalNotified::run.

This way, we only increment the refcount once (so there's no call to Task::clone), and all refcounts we create are owned by the Notified object, so their destructor will clear it if it gets dropped before it is polled, e.g. due to a panic.
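
A self-contained sketch of that ownership idea, using stand-in types rather than tokio's actual `Notified`/`LocalNotified` (the names `Header`, `NotifiedGuard`, and `notify_for_tracing` are illustrative only, not the merged code): the guard created by the notify step owns the extra reference count, so dropping it without polling, e.g. while unwinding from a panic, still releases the count.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Stand-in for a task header carrying an intrusive reference count.
struct Header {
    ref_count: AtomicUsize,
}

/// Stand-in for `Notified`: created together with one extra count, which it
/// owns until the task is either run or the guard is dropped.
struct NotifiedGuard(Arc<Header>);

impl NotifiedGuard {
    /// Analogue of a `transition_to_notified_for_tracing` wrapper: bump the
    /// count and hand ownership of that count to the returned guard.
    fn notify_for_tracing(header: &Arc<Header>) -> NotifiedGuard {
        header.ref_count.fetch_add(1, Ordering::Relaxed);
        NotifiedGuard(Arc::clone(header))
    }

    /// Analogue of `LocalNotified::run`: poll the task, then let `Drop`
    /// release the count the guard owns.
    fn run(self) {
        // ... poll the task here ...
    }
}

impl Drop for NotifiedGuard {
    fn drop(&mut self) {
        // Runs whether the task was polled or the guard was dropped early
        // (for example, because a panic unwound past it), so the count
        // cannot leak.
        self.0.ref_count.fetch_sub(1, Ordering::Release);
    }
}

fn main() {
    let header = Arc::new(Header {
        ref_count: AtomicUsize::new(1),
    });
    let guard = NotifiedGuard::notify_for_tracing(&header);
    guard.run(); // an early drop(guard) would release the count just the same
    assert_eq!(header.ref_count.load(Ordering::Relaxed), 1);
}
```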

Contributor Author (reply):
Done!

+        // do NOT poll `task` here, since we hold a lock on `owned` and the task
+        // may complete and need to remove itself from `owned`.
+    });
+
+    tasks
+        .into_iter()
+        .map(|task| {
+            let ((), trace) = Trace::capture(|| task.as_raw().poll());
+            trace
+        })
+        .collect()
+}
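
Reduced to a self-contained two-phase sketch (the `Registry`, `TaskStub`, and `trace_all` names are stand-ins, not tokio's types): snapshot the tasks while the lock is held, release the lock, then do the per-task work, so a task that completes mid-trace and needs to remove itself from the registry is never blocked on a lock the tracer still holds.

```rust
use std::sync::{Arc, Mutex};

struct TaskStub {
    name: &'static str,
}

struct Registry {
    // Analogue of `OwnedTasks`: completing tasks must take this lock
    // to remove themselves.
    tasks: Mutex<Vec<Arc<TaskStub>>>,
}

impl Registry {
    fn trace_all(&self) -> Vec<String> {
        // Phase 1: while holding the lock, only collect. Polling here could
        // make a task complete and re-lock `tasks` to deregister itself,
        // which is the deadlock this change avoids.
        let snapshot: Vec<Arc<TaskStub>> = self.tasks.lock().unwrap().clone();

        // Phase 2: the guard (and thus the lock) is gone; now trace each task.
        snapshot
            .into_iter()
            .map(|task| format!("traced {}", task.name))
            .collect()
    }
}

fn main() {
    let registry = Registry {
        tasks: Mutex::new(vec![
            Arc::new(TaskStub { name: "worker-1" }),
            Arc::new(TaskStub { name: "worker-2" }),
        ]),
    };
    assert_eq!(registry.trace_all().len(), 2);
}
```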
tokio/tests/dump.rs (57 additions, 0 deletions)

@@ -97,3 +97,60 @@ fn multi_thread() {
        );
    });
}
+
+/// Regression tests for #6035.
+///
+/// These tests ensure that dumping will not deadlock if a future completes
+/// during a trace.
+mod future_completes_during_trace {
+    use super::*;
+
+    use core::future::{poll_fn, Future};
+
+    /// A future that completes only during a trace.
+    fn complete_during_trace() -> impl Future<Output = ()> + Send {
+        use std::task::Poll;
+        poll_fn(|cx| {
+            if Handle::is_tracing() {
+                Poll::Ready(())
+            } else {
+                cx.waker().wake_by_ref();
+                Poll::Pending
+            }
+        })
+    }
+
+    #[test]
+    fn current_thread() {
+        let rt = runtime::Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .unwrap();
+
+        async fn dump() {
+            let handle = Handle::current();
+            let _dump = handle.dump().await;
+        }
+
+        rt.block_on(async {
+            let _ = tokio::join!(tokio::spawn(complete_during_trace()), dump());
+        });
+    }
+
+    #[test]
+    fn multi_thread() {
+        let rt = runtime::Builder::new_multi_thread()
+            .enable_all()
+            .build()
+            .unwrap();
+
+        async fn dump() {
+            let handle = Handle::current();
+            let _dump = handle.dump().await;
+        }
+
+        rt.block_on(async {
+            let _ = tokio::join!(tokio::spawn(complete_during_trace()), dump());
+        });
+    }
+}