From 9f1723b15bf2e475a58816a33b8596283a8adb9c Mon Sep 17 00:00:00 2001 From: Uwe Klotz Date: Tue, 19 Sep 2023 22:03:28 +0200 Subject: [PATCH] sync::watch: Use Acquire/Release memory ordering instead of SeqCst --- tokio/src/sync/watch.rs | 54 ++++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 61307fce47d..edd39ad5f74 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -114,7 +114,7 @@ use crate::sync::notify::Notify; use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::atomic::Ordering::Relaxed; +use crate::loom::sync::atomic::Ordering; use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; use std::fmt; use std::mem; @@ -247,7 +247,8 @@ struct Shared { impl fmt::Debug for Shared { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let state = self.state.load(); + // Using `Relaxed` ordering is sufficient for this purpose. + let state = self.state.load(Ordering::Relaxed); f.debug_struct("Shared") .field("value", &self.value) .field("version", &state.version()) @@ -341,7 +342,7 @@ mod big_notify { /// This function implements the case where randomness is not available. #[cfg(not(all(not(loom), feature = "sync", any(feature = "rt", feature = "macros"))))] pub(super) fn notified(&self) -> Notified<'_> { - let i = self.next.fetch_add(1, Relaxed) % 8; + let i = self.next.fetch_add(1, Ordering::Relaxed) % 8; self.inner[i].notified() } @@ -357,7 +358,7 @@ mod big_notify { use self::state::{AtomicState, Version}; mod state { use crate::loom::sync::atomic::AtomicUsize; - use crate::loom::sync::atomic::Ordering::SeqCst; + use crate::loom::sync::atomic::Ordering; const CLOSED_BIT: usize = 1; @@ -377,6 +378,11 @@ mod state { pub(super) struct StateSnapshot(usize); /// The state stored in an atomic integer. + /// + /// The `Sender` uses `Release` ordering for storing a new state + /// and the `Receiver`s use `Acquire` ordering for loading the + /// current state. This ensures that written values are seen by + /// the `Receiver`s for a proper handover. #[derive(Debug)] pub(super) struct AtomicState(AtomicUsize); @@ -412,18 +418,32 @@ mod state { } /// Load the current value of the state. - pub(super) fn load(&self) -> StateSnapshot { - StateSnapshot(self.0.load(SeqCst)) + pub(super) fn load(&self, ordering: Ordering) -> StateSnapshot { + StateSnapshot(self.0.load(ordering)) + } + + /// Load the current value of the state. + /// + /// The receiver side (read-only) uses `Acquire` ordering for a proper handover + /// with the sender side (single writer). + pub(super) fn load_receiver(&self) -> StateSnapshot { + StateSnapshot(self.0.load(Ordering::Acquire)) } /// Increment the version counter. pub(super) fn increment_version(&self) { - self.0.fetch_add(STEP_SIZE, SeqCst); + // Use `Release` ordering to ensure that storing the version + // state is seen by the receiver side that uses `Acquire` for + // loading the state. + self.0.fetch_add(STEP_SIZE, Ordering::Release); } /// Set the closed bit in the state. pub(super) fn set_closed(&self) { - self.0.fetch_or(CLOSED_BIT, SeqCst); + // Use `Release` ordering to ensure that storing the version + // state is seen by the receiver side that uses `Acquire` for + // loading the state. + self.0.fetch_or(CLOSED_BIT, Ordering::Release); } } } @@ -489,7 +509,7 @@ impl Receiver { fn from_shared(version: Version, shared: Arc>) -> Self { // No synchronization necessary as this is only used as a counter and // not memory access. - shared.ref_count_rx.fetch_add(1, Relaxed); + shared.ref_count_rx.fetch_add(1, Ordering::Relaxed); Self { shared, version } } @@ -543,7 +563,7 @@ impl Receiver { // After obtaining a read-lock no concurrent writes could occur // and the loaded version matches that of the borrowed reference. - let new_version = self.shared.state.load().version(); + let new_version = self.shared.state.load_receiver().version(); let has_changed = self.version != new_version; Ref { inner, has_changed } @@ -590,7 +610,7 @@ impl Receiver { // After obtaining a read-lock no concurrent writes could occur // and the loaded version matches that of the borrowed reference. - let new_version = self.shared.state.load().version(); + let new_version = self.shared.state.load_receiver().version(); let has_changed = self.version != new_version; // Mark the shared value as seen by updating the version @@ -631,7 +651,7 @@ impl Receiver { /// ``` pub fn has_changed(&self) -> Result { // Load the version from the state - let state = self.shared.state.load(); + let state = self.shared.state.load_receiver(); if state.is_closed() { // The sender has dropped. return Err(error::RecvError(())); @@ -768,7 +788,7 @@ impl Receiver { { let inner = self.shared.value.read().unwrap(); - let new_version = self.shared.state.load().version(); + let new_version = self.shared.state.load_receiver().version(); let has_changed = self.version != new_version; self.version = new_version; @@ -814,7 +834,7 @@ fn maybe_changed( version: &mut Version, ) -> Option> { // Load the version from the state - let state = shared.state.load(); + let state = shared.state.load_receiver(); let new_version = state.version(); if *version != new_version { @@ -865,7 +885,7 @@ impl Drop for Receiver { fn drop(&mut self) { // No synchronization necessary as this is only used as a counter and // not memory access. - if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) { + if 1 == self.shared.ref_count_rx.fetch_sub(1, Ordering::Relaxed) { // This is the last `Receiver` handle, tasks waiting on `Sender::closed()` self.shared.notify_tx.notify_waiters(); } @@ -1228,7 +1248,7 @@ impl Sender { /// ``` pub fn subscribe(&self) -> Receiver { let shared = self.shared.clone(); - let version = shared.state.load().version(); + let version = shared.state.load_receiver().version(); // The CLOSED bit in the state tracks only whether the sender is // dropped, so we do not need to unset it if this reopens the channel. @@ -1254,7 +1274,7 @@ impl Sender { /// } /// ``` pub fn receiver_count(&self) -> usize { - self.shared.ref_count_rx.load(Relaxed) + self.shared.ref_count_rx.load(Ordering::Relaxed) } }