Skip to content

Commit

Permalink
sync: avoid deadlocks in broadcast with custom wakers (#5578)
Browse files Browse the repository at this point in the history
  • Loading branch information
satakuma committed Apr 16, 2023
1 parent 1b22cbf commit 8497f37
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 18 deletions.
74 changes: 56 additions & 18 deletions tokio/src/sync/broadcast.rs
Expand Up @@ -118,8 +118,9 @@

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
use crate::util::linked_list::{self, LinkedList};
use crate::util::WakeList;

use std::fmt;
use std::future::Future;
Expand Down Expand Up @@ -569,12 +570,10 @@ impl<T> Sender<T> {
// Release the slot lock before notifying the receivers.
drop(slot);

tail.notify_rx();

// Release the mutex. This must happen after the slot lock is released,
// otherwise the writer lock bit could be cleared while another thread
// is in the critical section.
drop(tail);
// Notify and release the mutex. This must happen after the slot lock is
// released, otherwise the writer lock bit could be cleared while another
// thread is in the critical section.
self.shared.notify_rx(tail);

Ok(rem)
}
Expand Down Expand Up @@ -766,7 +765,7 @@ impl<T> Sender<T> {
let mut tail = self.shared.tail.lock();
tail.closed = true;

tail.notify_rx();
self.shared.notify_rx(tail);
}
}

Expand All @@ -787,18 +786,47 @@ fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {
Receiver { shared, next }
}

impl Tail {
fn notify_rx(&mut self) {
while let Some(mut waiter) = self.waiters.pop_back() {
// Safety: `waiters` lock is still held.
let waiter = unsafe { waiter.as_mut() };
impl<T> Shared<T> {
fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: MutexGuard<'a, Tail>) {
let mut wakers = WakeList::new();
'outer: loop {
while wakers.can_push() {
match tail.waiters.pop_back() {
Some(mut waiter) => {
// Safety: `tail` lock is still held.
let waiter = unsafe { waiter.as_mut() };

assert!(waiter.queued);
waiter.queued = false;

if let Some(waker) = waiter.waker.take() {
wakers.push(waker);
}
}
None => {
break 'outer;
}
}
}

// Release the lock before waking.
drop(tail);

// Before we acquire the lock again all sorts of things can happen:
// some waiters may remove themselves from the list and new waiters
// may be added. This is fine since at worst we will unnecessarily
// wake up waiters which will then queue themselves again.

assert!(waiter.queued);
waiter.queued = false;
wakers.wake_all();

let waker = waiter.waker.take().unwrap();
waker.wake();
// Acquire the lock again.
tail = self.tail.lock();
}

// Release the lock before waking.
drop(tail);

wakers.wake_all();
}
}

Expand Down Expand Up @@ -930,6 +958,8 @@ impl<T> Receiver<T> {
// the slot lock.
drop(slot);

let mut old_waker = None;

let mut tail = self.shared.tail.lock();

// Acquire slot lock again
Expand Down Expand Up @@ -962,7 +992,10 @@ impl<T> Receiver<T> {
match (*ptr).waker {
Some(ref w) if w.will_wake(waker) => {}
_ => {
(*ptr).waker = Some(waker.clone());
old_waker = std::mem::replace(
&mut (*ptr).waker,
Some(waker.clone()),
);
}
}

Expand All @@ -974,6 +1007,11 @@ impl<T> Receiver<T> {
}
}

// Drop the old waker after releasing the locks.
drop(slot);
drop(tail);
drop(old_waker);

return Err(TryRecvError::Empty);
}

Expand Down
54 changes: 54 additions & 0 deletions tokio/tests/sync_broadcast.rs
Expand Up @@ -587,3 +587,57 @@ fn sender_len_random() {
assert_eq!(tx.len(), expected_len);
}
}

#[test]
fn send_in_waker_drop() {
use futures::task::ArcWake;
use std::future::Future;
use std::task::Context;

struct SendOnDrop(broadcast::Sender<()>);

impl Drop for SendOnDrop {
fn drop(&mut self) {
let _ = self.0.send(());
}
}

impl ArcWake for SendOnDrop {
fn wake_by_ref(_arc_self: &Arc<Self>) {}
}

// Test if there is no deadlock when replacing the old waker.

let (tx, mut rx) = broadcast::channel(16);

let mut fut = Box::pin(async {
let _ = rx.recv().await;
});

// Store our special waker in the receiving future.
let waker = futures::task::waker(Arc::new(SendOnDrop(tx)));
let mut cx = Context::from_waker(&waker);
assert!(fut.as_mut().poll(&mut cx).is_pending());
drop(waker);

// Second poll shouldn't deadlock.
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
let _ = fut.as_mut().poll(&mut cx);

// Test if there is no deadlock when calling waker.wake().

let (tx, mut rx) = broadcast::channel(16);

let mut fut = Box::pin(async {
let _ = rx.recv().await;
});

// Store our special waker in the receiving future.
let waker = futures::task::waker(Arc::new(SendOnDrop(tx.clone())));
let mut cx = Context::from_waker(&waker);
assert!(fut.as_mut().poll(&mut cx).is_pending());
drop(waker);

// Shouldn't deadlock.
let _ = tx.send(());
}

0 comments on commit 8497f37

Please sign in to comment.