Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rt: do not trace tasks while locking OwnedTasks #6036

Merged
merged 9 commits into from
Oct 6, 2023
8 changes: 8 additions & 0 deletions tokio/src/runtime/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,14 @@ cfg_taskdump! {
scheduler::Handle::MultiThreadAlt(_) => panic!("task dump not implemented for this runtime flavor"),
}
}

/// Produces `true` if the current task is being traced for a dump;
/// otherwise false. This function is only public for integration
/// testing purposes. Do not rely on it.
#[doc(hidden)]
pub fn is_tracing() -> bool {
super::task::trace::Context::is_tracing()
}
}

cfg_rt_multi_thread! {
Expand Down
17 changes: 17 additions & 0 deletions tokio/src/runtime/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ impl<S: 'static> Task<S> {
fn header_ptr(&self) -> NonNull<Header> {
self.raw.header_ptr()
}

cfg_taskdump! {
pub(super) fn notify_for_tracing(&self) -> Notified<S> {
self.as_raw().state().transition_to_notified_for_tracing();
Notified(self.clone())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since transition_to_notified_for_tracing already increments the refcount, you do not also need to clone it. You probably have a memory leak right now.

Suggested change
Notified(self.clone())
Notified(self)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there's a leak here. Remember, when we're calling notify_for_tracing, we're iterating over OwnedTasks, and need to get notified copies of those tasks into a separate vector. A clone has to occur somewhere. Either we can do it here, or we can modified this method to consume self and do the clone in trace_helper instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I can tell, you increment twice (in transition_to_notified_for_tracing and clone), but only decrement once (in run)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yep. I didn't catch that run mem::forgets the LocalNotified.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of removing the ref_inc from transition_to_notified_for_tracing, it makes more sense to me to remove the call to clone. This way, you only touch the atomic once instead of twice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 4264f61, but note that there's now a need for cloning in trace_owned, because we have to stash the Tasks (or Notified)s outside of OwnedTasks.

}
}
}

impl<S: 'static> Notified<S> {
Expand Down Expand Up @@ -444,6 +451,16 @@ impl<S: Schedule> UnownedTask<S> {
}
}

impl<S: 'static> Clone for Task<S> {
fn clone(&self) -> Task<S> {
// SAFETY: We increment the ref count.
unsafe {
self.raw.ref_inc();
Task::new(self.raw)
}
}
}
jswrenn marked this conversation as resolved.
Show resolved Hide resolved

impl<S: 'static> Drop for Task<S> {
fn drop(&mut self) {
// Decrement the ref count
Expand Down
3 changes: 2 additions & 1 deletion tokio/src/runtime/task/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ impl RawTask {

/// Increment the task's reference count.
///
/// Currently, this is used only when creating an `AbortHandle`.
/// Currently, this is used only when creating an `AbortHandle`,
/// and when cloning a `Task`.
jswrenn marked this conversation as resolved.
Show resolved Hide resolved
pub(super) fn ref_inc(self) {
self.header().state.ref_inc();
}
Expand Down
73 changes: 41 additions & 32 deletions tokio/src/runtime/task/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mod tree;
use symbol::Symbol;
use tree::Tree;

use super::{Notified, OwnedTasks};
use super::{Notified, OwnedTasks, Schedule};

type Backtrace = Vec<BacktraceFrame>;
type SymbolTrace = Vec<Symbol>;
Expand Down Expand Up @@ -100,6 +100,16 @@ impl Context {
Self::try_with_current(|context| f(&context.collector)).expect(FAIL_NO_THREAD_LOCAL)
}
}

/// Produces `true` if the current task is being traced; otherwise false.
pub(crate) fn is_tracing() -> bool {
Self::with_current_collector(|maybe_collector| {
let collector = maybe_collector.take();
let result = collector.is_some();
maybe_collector.set(collector);
result
})
}
}

impl Trace {
Expand Down Expand Up @@ -268,22 +278,8 @@ pub(in crate::runtime) fn trace_current_thread(
drop(task);
}

// notify each task
let mut tasks = vec![];
owned.for_each(|task| {
// set the notified bit
task.as_raw().state().transition_to_notified_for_tracing();
// store the raw tasks into a vec
tasks.push(task.as_raw());
});

tasks
.into_iter()
.map(|task| {
let ((), trace) = Trace::capture(|| task.poll());
trace
})
.collect()
// precondition: We have drained the tasks from the injection queue.
trace_owned(owned)
}

cfg_rt_multi_thread! {
Expand Down Expand Up @@ -316,21 +312,34 @@ cfg_rt_multi_thread! {

drop(synced);

// notify each task
let mut traces = vec![];
owned.for_each(|task| {
// set the notified bit
task.as_raw().state().transition_to_notified_for_tracing();

// trace the task
let ((), trace) = Trace::capture(|| task.as_raw().poll());
traces.push(trace);
// precondition: we have drained the tasks from the local and injection
// queues.
trace_owned(owned)
}
}

// reschedule the task
let _ = task.as_raw().state().transition_to_notified_by_ref();
task.as_raw().schedule();
});
/// Trace the `OwnedTasks`.
///
/// # Preconditions
///
/// This helper presumes exclusive access to each task. The tasks must not exist
/// in any other queue.
fn trace_owned<S: Schedule>(owned: &OwnedTasks<S>) -> Vec<Trace> {
// notify each task
let mut tasks = vec![];
owned.for_each(|task| {
// notify the task (and thus make it poll-able) and stash it
tasks.push(task.notify_for_tracing());
// we do not poll it here since we hold a lock on `owned` and the task
// may complete and need to remove itself from `owned`.
});

traces
}
tasks
.into_iter()
.map(|task| {
let local_notified = owned.assert_owner(task);
let ((), trace) = Trace::capture(|| local_notified.run());
trace
})
.collect()
}
57 changes: 57 additions & 0 deletions tokio/tests/dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,60 @@ fn multi_thread() {
);
});
}

/// Regression tests for #6035.
///
/// These tests ensure that dumping will not deadlock if a future completes
/// during a trace.
mod future_completes_during_trace {
use super::*;

use core::future::{poll_fn, Future};

/// A future that completes only during a trace.
fn complete_during_trace() -> impl Future<Output = ()> + Send {
use std::task::Poll;
poll_fn(|cx| {
if Handle::is_tracing() {
Poll::Ready(())
} else {
cx.waker().wake_by_ref();
Poll::Pending
}
})
}

#[test]
fn current_thread() {
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();

async fn dump() {
let handle = Handle::current();
let _dump = handle.dump().await;
}

rt.block_on(async {
let _ = tokio::join!(tokio::spawn(complete_during_trace()), dump());
});
}

#[test]
fn multi_thread() {
let rt = runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();

async fn dump() {
let handle = Handle::current();
let _dump = handle.dump().await;
}

rt.block_on(async {
let _ = tokio::join!(tokio::spawn(complete_during_trace()), dump());
});
}
}