diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs
index 5ef2ce8935d..175a7edf6e0 100644
--- a/tokio/src/sync/broadcast.rs
+++ b/tokio/src/sync/broadcast.rs
@@ -118,8 +118,9 @@
 use crate::loom::cell::UnsafeCell;
 use crate::loom::sync::atomic::AtomicUsize;
-use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
+use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
 use crate::util::linked_list::{self, LinkedList};
+use crate::util::WakeList;
 
 use std::fmt;
 use std::future::Future;
 use std::marker::PhantomData;
@@ -569,12 +570,10 @@ impl<T> Sender<T> {
         // Release the slot lock before notifying the receivers.
         drop(slot);
 
-        tail.notify_rx();
-
-        // Release the mutex. This must happen after the slot lock is released,
-        // otherwise the writer lock bit could be cleared while another thread
-        // is in the critical section.
-        drop(tail);
+        // Notify and release the mutex. This must happen after the slot lock is
+        // released, otherwise the writer lock bit could be cleared while another
+        // thread is in the critical section.
+        self.shared.notify_rx(tail);
 
         Ok(rem)
     }
@@ -766,7 +765,7 @@ impl<T> Sender<T> {
         let mut tail = self.shared.tail.lock();
 
         tail.closed = true;
 
-        tail.notify_rx();
+        self.shared.notify_rx(tail);
     }
 }
@@ -787,18 +786,47 @@ fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {
     Receiver { shared, next }
 }
 
-impl<T> Tail<T> {
-    fn notify_rx(&mut self) {
-        while let Some(mut waiter) = self.waiters.pop_back() {
-            // Safety: `waiters` lock is still held.
-            let waiter = unsafe { waiter.as_mut() };
+impl<T> Shared<T> {
+    fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: MutexGuard<'a, Tail<T>>) {
+        let mut wakers = WakeList::new();
+        'outer: loop {
+            while wakers.can_push() {
+                match tail.waiters.pop_back() {
+                    Some(mut waiter) => {
+                        // Safety: `tail` lock is still held.
+                        let waiter = unsafe { waiter.as_mut() };
+
+                        assert!(waiter.queued);
+                        waiter.queued = false;
+
+                        if let Some(waker) = waiter.waker.take() {
+                            wakers.push(waker);
+                        }
+                    }
+                    None => {
+                        break 'outer;
+                    }
+                }
+            }
+
+            // Release the lock before waking.
+            drop(tail);
+
+            // Before we acquire the lock again all sorts of things can happen:
+            // some waiters may remove themselves from the list and new waiters
+            // may be added. This is fine since at worst we will unnecessarily
+            // wake up waiters which will then queue themselves again.
 
-            assert!(waiter.queued);
-            waiter.queued = false;
+            wakers.wake_all();
 
-            let waker = waiter.waker.take().unwrap();
-            waker.wake();
+            // Acquire the lock again.
+            tail = self.tail.lock();
         }
+
+        // Release the lock before waking.
+        drop(tail);
+
+        wakers.wake_all();
     }
 }
 
@@ -930,6 +958,8 @@ impl<T> Receiver<T> {
         // the slot lock.
         drop(slot);
 
+        let mut old_waker = None;
+
         let mut tail = self.shared.tail.lock();
 
         // Acquire slot lock again
@@ -962,7 +992,10 @@ impl<T> Receiver<T> {
                     match (*ptr).waker {
                         Some(ref w) if w.will_wake(waker) => {}
                         _ => {
-                            (*ptr).waker = Some(waker.clone());
+                            old_waker = std::mem::replace(
+                                &mut (*ptr).waker,
+                                Some(waker.clone()),
+                            );
                         }
                     }
 
@@ -974,6 +1007,11 @@ impl<T> Receiver<T> {
             }
         }
 
+        // Drop the old waker after releasing the locks.
+        drop(slot);
+        drop(tail);
+        drop(old_waker);
+
         return Err(TryRecvError::Empty);
     }
 
diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs
index cd6692448bb..feed03148af 100644
--- a/tokio/tests/sync_broadcast.rs
+++ b/tokio/tests/sync_broadcast.rs
@@ -587,3 +587,57 @@ fn sender_len_random() {
         assert_eq!(tx.len(), expected_len);
     }
 }
+
+#[test]
+fn send_in_waker_drop() {
+    use futures::task::ArcWake;
+    use std::future::Future;
+    use std::task::Context;
+
+    struct SendOnDrop(broadcast::Sender<()>);
+
+    impl Drop for SendOnDrop {
+        fn drop(&mut self) {
+            let _ = self.0.send(());
+        }
+    }
+
+    impl ArcWake for SendOnDrop {
+        fn wake_by_ref(_arc_self: &Arc<Self>) {}
+    }
+
+    // Test if there is no deadlock when replacing the old waker.
+
+    let (tx, mut rx) = broadcast::channel(16);
+
+    let mut fut = Box::pin(async {
+        let _ = rx.recv().await;
+    });
+
+    // Store our special waker in the receiving future.
+    let waker = futures::task::waker(Arc::new(SendOnDrop(tx)));
+    let mut cx = Context::from_waker(&waker);
+    assert!(fut.as_mut().poll(&mut cx).is_pending());
+    drop(waker);
+
+    // Second poll shouldn't deadlock.
+    let mut cx = Context::from_waker(futures::task::noop_waker_ref());
+    let _ = fut.as_mut().poll(&mut cx);
+
+    // Test if there is no deadlock when calling waker.wake().
+
+    let (tx, mut rx) = broadcast::channel(16);
+
+    let mut fut = Box::pin(async {
+        let _ = rx.recv().await;
+    });
+
+    // Store our special waker in the receiving future.
+    let waker = futures::task::waker(Arc::new(SendOnDrop(tx.clone())));
+    let mut cx = Context::from_waker(&waker);
+    assert!(fut.as_mut().poll(&mut cx).is_pending());
+    drop(waker);
+
+    // Shouldn't deadlock.
+    let _ = tx.send(());
+}