From 76a0e78d4b9d5f48a0d9a863a0e5b1fc317bbe67 Mon Sep 17 00:00:00 2001 From: Victor Timofei Date: Tue, 29 Aug 2023 23:14:04 +0300 Subject: [PATCH 1/6] sync: imlement `watch::Receiver::mark_unseen()` Fixes: #5871 --- tokio/src/sync/watch.rs | 33 ++++++++++++++++++++++++++++----- tokio/tests/sync_watch.rs | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 5a46670eeeb..ad6c408fe00 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -115,7 +115,8 @@ use crate::sync::notify::Notify; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::atomic::Ordering::Relaxed; -use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; +use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; +use std::cell::Cell; use std::fmt; use std::mem; use std::ops; @@ -136,6 +137,9 @@ pub struct Receiver { /// Last observed version version: Version, + + /// Whether current version is marked as unseen + unseen: Mutex>, } /// Sends values to the associated [`Receiver`](struct@Receiver). @@ -473,6 +477,7 @@ pub fn channel(init: T) -> (Sender, Receiver) { let rx = Receiver { shared, version: Version::initial(), + unseen: Mutex::new(Cell::new(false)), }; (tx, rx) @@ -484,7 +489,7 @@ impl Receiver { // not memory access. shared.ref_count_rx.fetch_add(1, Relaxed); - Self { shared, version } + Self { shared, version, unseen: Mutex::new(Cell::new(false)) } } /// Returns a reference to the most recently sent value. @@ -537,7 +542,10 @@ impl Receiver { // After obtaining a read-lock no concurrent writes could occur // and the loaded version matches that of the borrowed reference. let new_version = self.shared.state.load().version(); - let has_changed = self.version != new_version; + let unseen = self.unseen.lock(); + let has_changed = self.version != new_version || unseen.get(); + + unseen.set(false); Ref { inner, has_changed } } @@ -631,7 +639,12 @@ impl Receiver { } let new_version = state.version(); - Ok(self.version != new_version) + Ok(self.version != new_version || self.unseen.lock().get()) + } + + /// Marks the state as unseen. + pub fn mark_unseen(&mut self) { + self.unseen.lock().set(true); } /// Waits for a change notification, then marks the newest value as seen. @@ -677,6 +690,14 @@ impl Receiver { /// } /// ``` pub async fn changed(&mut self) -> Result<(), error::RecvError> { + { + let unseen = self.unseen.lock(); + + if unseen.get() { + unseen.set(false); + return Ok(()); + } + } changed_impl(&self.shared, &mut self.version).await } @@ -750,10 +771,12 @@ impl Receiver { let inner = self.shared.value.read().unwrap(); let new_version = self.shared.state.load().version(); - let has_changed = self.version != new_version; + let unseen = self.unseen.lock(); + let has_changed = self.version != new_version || unseen.get(); self.version = new_version; if (!closed || has_changed) && f(&inner) { + unseen.set(false); return Ok(Ref { inner, has_changed }); } } diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index dab57aa5af6..30d95667deb 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -44,6 +44,41 @@ fn single_rx_recv() { assert_eq!(*rx.borrow(), "two"); } +#[test] +fn rx_mark_unseen() { + let (tx, mut rx) = watch::channel("one"); + + let mut rx2 = rx.clone(); + let mut rx3 = rx.clone(); + + { + rx.mark_unseen(); + assert!(rx.has_changed().unwrap()); + + let mut t = spawn(rx.changed()); + assert_ready_ok!(t.poll()); + } + + { + assert!(!rx2.has_changed().unwrap()); + + let mut t = spawn(rx2.changed()); + assert_pending!(t.poll()); + } + + { + rx3.mark_unseen(); + assert_eq!(*rx3.borrow(), "one"); + + assert!(!rx3.has_changed().unwrap()); + + let mut t = spawn(rx3.changed()); + assert_pending!(t.poll()); + } + + assert_eq!(*rx.borrow(), "one"); +} + #[test] fn multi_rx() { let (tx, mut rx1) = watch::channel("one"); From bda4df0bdcf1810e24f088720fb1fb150dee62e4 Mon Sep 17 00:00:00 2001 From: Victor Timofei Date: Tue, 29 Aug 2023 23:55:21 +0300 Subject: [PATCH 2/6] fix fmt and clippy --- tokio/src/sync/watch.rs | 16 ++++++++++------ tokio/tests/sync_watch.rs | 2 +- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index ad6c408fe00..7e0c2443d09 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -489,7 +489,11 @@ impl Receiver { // not memory access. shared.ref_count_rx.fetch_add(1, Relaxed); - Self { shared, version, unseen: Mutex::new(Cell::new(false)) } + Self { + shared, + version, + unseen: Mutex::new(Cell::new(false)), + } } /// Returns a reference to the most recently sent value. @@ -691,12 +695,12 @@ impl Receiver { /// ``` pub async fn changed(&mut self) -> Result<(), error::RecvError> { { - let unseen = self.unseen.lock(); + let unseen = self.unseen.lock(); - if unseen.get() { - unseen.set(false); - return Ok(()); - } + if unseen.get() { + unseen.set(false); + return Ok(()); + } } changed_impl(&self.shared, &mut self.version).await } diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index 30d95667deb..ab3efdcd623 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -46,7 +46,7 @@ fn single_rx_recv() { #[test] fn rx_mark_unseen() { - let (tx, mut rx) = watch::channel("one"); + let (_tx, mut rx) = watch::channel("one"); let mut rx2 = rx.clone(); let mut rx3 = rx.clone(); From 13216edf219fa98577fc5e6b2d74fc71ca249de3 Mon Sep 17 00:00:00 2001 From: Victor Timofei Date: Wed, 30 Aug 2023 23:47:37 +0300 Subject: [PATCH 3/6] change `mark_unseen` to use the version --- tokio/src/loom/std/atomic_isize.rs | 56 ++++++++++++++++++++++++++++++ tokio/src/loom/std/mod.rs | 2 ++ tokio/src/sync/watch.rs | 53 ++++++++++------------------ tokio/tests/sync_watch.rs | 4 +++ 4 files changed, 81 insertions(+), 34 deletions(-) create mode 100644 tokio/src/loom/std/atomic_isize.rs diff --git a/tokio/src/loom/std/atomic_isize.rs b/tokio/src/loom/std/atomic_isize.rs new file mode 100644 index 00000000000..dd94282b241 --- /dev/null +++ b/tokio/src/loom/std/atomic_isize.rs @@ -0,0 +1,56 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::ops; + +/// `AtomicIsize` providing an additional `unsync_load` function. +pub(crate) struct AtomicIsize { + inner: UnsafeCell, +} + +unsafe impl Send for AtomicIsize {} +unsafe impl Sync for AtomicIsize {} + +impl AtomicIsize { + pub(crate) const fn new(val: isize) -> AtomicIsize { + let inner = UnsafeCell::new(std::sync::atomic::AtomicIsize::new(val)); + AtomicIsize { inner } + } + + /// Performs an unsynchronized load. + /// + /// # Safety + /// + /// All mutations must have happened before the unsynchronized load. + /// Additionally, there must be no concurrent mutations. + pub(crate) unsafe fn unsync_load(&self) -> isize { + core::ptr::read(self.inner.get() as *const isize) + } + + pub(crate) fn with_mut(&mut self, f: impl FnOnce(&mut isize) -> R) -> R { + // safety: we have mutable access + f(unsafe { (*self.inner.get()).get_mut() }) + } +} + +impl ops::Deref for AtomicIsize { + type Target = std::sync::atomic::AtomicIsize; + + fn deref(&self) -> &Self::Target { + // safety: it is always safe to access `&self` fns on the inner value as + // we never perform unsafe mutations. + unsafe { &*self.inner.get() } + } +} + +impl ops::DerefMut for AtomicIsize { + fn deref_mut(&mut self) -> &mut Self::Target { + // safety: we hold `&mut self` + unsafe { &mut *self.inner.get() } + } +} + +impl fmt::Debug for AtomicIsize { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(fmt) + } +} diff --git a/tokio/src/loom/std/mod.rs b/tokio/src/loom/std/mod.rs index 0c611af162a..c2e91201460 100644 --- a/tokio/src/loom/std/mod.rs +++ b/tokio/src/loom/std/mod.rs @@ -1,5 +1,6 @@ #![cfg_attr(any(not(feature = "full"), loom), allow(unused_imports, dead_code))] +mod atomic_isize; mod atomic_u16; mod atomic_u32; mod atomic_u64; @@ -70,6 +71,7 @@ pub(crate) mod sync { pub(crate) use crate::loom::std::mutex::Mutex; pub(crate) mod atomic { + pub(crate) use crate::loom::std::atomic_isize::AtomicIsize; pub(crate) use crate::loom::std::atomic_u16::AtomicU16; pub(crate) use crate::loom::std::atomic_u32::AtomicU32; pub(crate) use crate::loom::std::atomic_u64::{AtomicU64, StaticAtomicU64}; diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 7e0c2443d09..265955d7b44 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -115,8 +115,7 @@ use crate::sync::notify::Notify; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::atomic::Ordering::Relaxed; -use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; -use std::cell::Cell; +use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; use std::fmt; use std::mem; use std::ops; @@ -137,9 +136,6 @@ pub struct Receiver { /// Last observed version version: Version, - - /// Whether current version is marked as unseen - unseen: Mutex>, } /// Sends values to the associated [`Receiver`](struct@Receiver). @@ -360,14 +356,14 @@ mod big_notify { use self::state::{AtomicState, Version}; mod state { - use crate::loom::sync::atomic::AtomicUsize; + use crate::loom::sync::atomic::AtomicIsize; use crate::loom::sync::atomic::Ordering::SeqCst; - const CLOSED: usize = 1; + const CLOSED: isize = 1; /// The version part of the state. The lowest bit is always zero. #[derive(Copy, Clone, Debug, Eq, PartialEq)] - pub(super) struct Version(usize); + pub(super) struct Version(isize); /// Snapshot of the state. The first bit is used as the CLOSED bit. /// The remaining bits are used as the version. @@ -375,17 +371,23 @@ mod state { /// The CLOSED bit tracks whether the Sender has been dropped. Dropping all /// receivers does not set it. #[derive(Copy, Clone, Debug)] - pub(super) struct StateSnapshot(usize); + pub(super) struct StateSnapshot(isize); /// The state stored in an atomic integer. #[derive(Debug)] - pub(super) struct AtomicState(AtomicUsize); + pub(super) struct AtomicState(AtomicIsize); impl Version { /// Get the initial version when creating the channel. pub(super) fn initial() -> Self { Version(0) } + + /// Decrements the version. + pub(super) fn decrement(&mut self) { + // Decrement by two to avoid touching the CLOSED bit. + self.0 -= 2; + } } impl StateSnapshot { @@ -404,7 +406,7 @@ mod state { /// Create a new `AtomicState` that is not closed and which has the /// version set to `Version::initial()`. pub(super) fn new() -> Self { - AtomicState(AtomicUsize::new(0)) + AtomicState(AtomicIsize::new(0)) } /// Load the current value of the state. @@ -477,7 +479,6 @@ pub fn channel(init: T) -> (Sender, Receiver) { let rx = Receiver { shared, version: Version::initial(), - unseen: Mutex::new(Cell::new(false)), }; (tx, rx) @@ -489,11 +490,7 @@ impl Receiver { // not memory access. shared.ref_count_rx.fetch_add(1, Relaxed); - Self { - shared, - version, - unseen: Mutex::new(Cell::new(false)), - } + Self { shared, version } } /// Returns a reference to the most recently sent value. @@ -546,10 +543,7 @@ impl Receiver { // After obtaining a read-lock no concurrent writes could occur // and the loaded version matches that of the borrowed reference. let new_version = self.shared.state.load().version(); - let unseen = self.unseen.lock(); - let has_changed = self.version != new_version || unseen.get(); - - unseen.set(false); + let has_changed = self.version != new_version; Ref { inner, has_changed } } @@ -643,12 +637,13 @@ impl Receiver { } let new_version = state.version(); - Ok(self.version != new_version || self.unseen.lock().get()) + println!("observed: {:?}, current: {new_version:?}", self.version); + Ok(self.version != new_version) } /// Marks the state as unseen. pub fn mark_unseen(&mut self) { - self.unseen.lock().set(true); + self.version.decrement(); } /// Waits for a change notification, then marks the newest value as seen. @@ -694,14 +689,6 @@ impl Receiver { /// } /// ``` pub async fn changed(&mut self) -> Result<(), error::RecvError> { - { - let unseen = self.unseen.lock(); - - if unseen.get() { - unseen.set(false); - return Ok(()); - } - } changed_impl(&self.shared, &mut self.version).await } @@ -775,12 +762,10 @@ impl Receiver { let inner = self.shared.value.read().unwrap(); let new_version = self.shared.state.load().version(); - let unseen = self.unseen.lock(); - let has_changed = self.version != new_version || unseen.get(); + let has_changed = self.version != new_version; self.version = new_version; if (!closed || has_changed) && f(&inner) { - unseen.set(false); return Ok(Ref { inner, has_changed }); } } diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index ab3efdcd623..5c0721587ca 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -70,6 +70,10 @@ fn rx_mark_unseen() { rx3.mark_unseen(); assert_eq!(*rx3.borrow(), "one"); + assert!(rx3.has_changed().unwrap()); + + assert_eq!(*rx3.borrow_and_update(), "one"); + assert!(!rx3.has_changed().unwrap()); let mut t = spawn(rx3.changed()); From cecb70167c88521488b519efd6b95b2bf1ba95b3 Mon Sep 17 00:00:00 2001 From: Victor Timofei Date: Thu, 31 Aug 2023 00:11:47 +0300 Subject: [PATCH 4/6] remove println --- tokio/src/sync/watch.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 265955d7b44..ce187de3554 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -637,7 +637,6 @@ impl Receiver { } let new_version = state.version(); - println!("observed: {:?}, current: {new_version:?}", self.version); Ok(self.version != new_version) } From 1a9ff9ca98ce59dba0ef70b8fb33ffd0e0187dbf Mon Sep 17 00:00:00 2001 From: Victor Timofei Date: Fri, 1 Sep 2023 20:39:46 +0300 Subject: [PATCH 5/6] use usize and start initial version at 2 --- tokio/src/loom/std/atomic_isize.rs | 56 ------------------------------ tokio/src/loom/std/mod.rs | 2 -- tokio/src/sync/watch.rs | 20 ++++++----- tokio/tests/sync_watch.rs | 9 +++++ 4 files changed, 21 insertions(+), 66 deletions(-) delete mode 100644 tokio/src/loom/std/atomic_isize.rs diff --git a/tokio/src/loom/std/atomic_isize.rs b/tokio/src/loom/std/atomic_isize.rs deleted file mode 100644 index dd94282b241..00000000000 --- a/tokio/src/loom/std/atomic_isize.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::cell::UnsafeCell; -use std::fmt; -use std::ops; - -/// `AtomicIsize` providing an additional `unsync_load` function. -pub(crate) struct AtomicIsize { - inner: UnsafeCell, -} - -unsafe impl Send for AtomicIsize {} -unsafe impl Sync for AtomicIsize {} - -impl AtomicIsize { - pub(crate) const fn new(val: isize) -> AtomicIsize { - let inner = UnsafeCell::new(std::sync::atomic::AtomicIsize::new(val)); - AtomicIsize { inner } - } - - /// Performs an unsynchronized load. - /// - /// # Safety - /// - /// All mutations must have happened before the unsynchronized load. - /// Additionally, there must be no concurrent mutations. - pub(crate) unsafe fn unsync_load(&self) -> isize { - core::ptr::read(self.inner.get() as *const isize) - } - - pub(crate) fn with_mut(&mut self, f: impl FnOnce(&mut isize) -> R) -> R { - // safety: we have mutable access - f(unsafe { (*self.inner.get()).get_mut() }) - } -} - -impl ops::Deref for AtomicIsize { - type Target = std::sync::atomic::AtomicIsize; - - fn deref(&self) -> &Self::Target { - // safety: it is always safe to access `&self` fns on the inner value as - // we never perform unsafe mutations. - unsafe { &*self.inner.get() } - } -} - -impl ops::DerefMut for AtomicIsize { - fn deref_mut(&mut self) -> &mut Self::Target { - // safety: we hold `&mut self` - unsafe { &mut *self.inner.get() } - } -} - -impl fmt::Debug for AtomicIsize { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - (**self).fmt(fmt) - } -} diff --git a/tokio/src/loom/std/mod.rs b/tokio/src/loom/std/mod.rs index c2e91201460..0c611af162a 100644 --- a/tokio/src/loom/std/mod.rs +++ b/tokio/src/loom/std/mod.rs @@ -1,6 +1,5 @@ #![cfg_attr(any(not(feature = "full"), loom), allow(unused_imports, dead_code))] -mod atomic_isize; mod atomic_u16; mod atomic_u32; mod atomic_u64; @@ -71,7 +70,6 @@ pub(crate) mod sync { pub(crate) use crate::loom::std::mutex::Mutex; pub(crate) mod atomic { - pub(crate) use crate::loom::std::atomic_isize::AtomicIsize; pub(crate) use crate::loom::std::atomic_u16::AtomicU16; pub(crate) use crate::loom::std::atomic_u32::AtomicU32; pub(crate) use crate::loom::std::atomic_u64::{AtomicU64, StaticAtomicU64}; diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index ce187de3554..7326f8831f8 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -356,14 +356,14 @@ mod big_notify { use self::state::{AtomicState, Version}; mod state { - use crate::loom::sync::atomic::AtomicIsize; + use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::atomic::Ordering::SeqCst; - const CLOSED: isize = 1; + const CLOSED: usize = 1; /// The version part of the state. The lowest bit is always zero. #[derive(Copy, Clone, Debug, Eq, PartialEq)] - pub(super) struct Version(isize); + pub(super) struct Version(usize); /// Snapshot of the state. The first bit is used as the CLOSED bit. /// The remaining bits are used as the version. @@ -371,22 +371,26 @@ mod state { /// The CLOSED bit tracks whether the Sender has been dropped. Dropping all /// receivers does not set it. #[derive(Copy, Clone, Debug)] - pub(super) struct StateSnapshot(isize); + pub(super) struct StateSnapshot(usize); /// The state stored in an atomic integer. #[derive(Debug)] - pub(super) struct AtomicState(AtomicIsize); + pub(super) struct AtomicState(AtomicUsize); impl Version { /// Get the initial version when creating the channel. pub(super) fn initial() -> Self { - Version(0) + // We start counting at 2 so that we can mark the + // Receiver as unread on creation. + Version(2) } /// Decrements the version. pub(super) fn decrement(&mut self) { // Decrement by two to avoid touching the CLOSED bit. - self.0 -= 2; + if self.0 >= 2 { + self.0 -= 2; + } } } @@ -406,7 +410,7 @@ mod state { /// Create a new `AtomicState` that is not closed and which has the /// version set to `Version::initial()`. pub(super) fn new() -> Self { - AtomicState(AtomicIsize::new(0)) + AtomicState(AtomicUsize::new(2)) } /// Load the current value of the state. diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index 5c0721587ca..2397915736d 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -44,6 +44,15 @@ fn single_rx_recv() { assert_eq!(*rx.borrow(), "two"); } +#[test] +fn rx_version_underflow() { + let (_tx, mut rx) = watch::channel("one"); + + // Version starts at 2, validate we do not underflow + rx.mark_unseen(); + rx.mark_unseen(); +} + #[test] fn rx_mark_unseen() { let (_tx, mut rx) = watch::channel("one"); From f2704f1100f474f93a89f4e3c467d722bf276b95 Mon Sep 17 00:00:00 2001 From: Victor Timofei Date: Sun, 3 Sep 2023 16:21:12 +0000 Subject: [PATCH 6/6] update comment and extend tests --- tokio/src/sync/watch.rs | 4 ++-- tokio/tests/sync_watch.rs | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 7326f8831f8..0452a81aa0a 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -380,8 +380,8 @@ mod state { impl Version { /// Get the initial version when creating the channel. pub(super) fn initial() -> Self { - // We start counting at 2 so that we can mark the - // Receiver as unread on creation. + // The initial version is 1 so that `mark_unseen` can decrement by one. + // (The value is 2 due to the closed bit.) Version(2) } diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index 2397915736d..5bcb7476888 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -55,11 +55,11 @@ fn rx_version_underflow() { #[test] fn rx_mark_unseen() { - let (_tx, mut rx) = watch::channel("one"); + let (tx, mut rx) = watch::channel("one"); let mut rx2 = rx.clone(); let mut rx3 = rx.clone(); - + let mut rx4 = rx.clone(); { rx.mark_unseen(); assert!(rx.has_changed().unwrap()); @@ -89,7 +89,17 @@ fn rx_mark_unseen() { assert_pending!(t.poll()); } - assert_eq!(*rx.borrow(), "one"); + { + tx.send("two").unwrap(); + assert!(rx4.has_changed().unwrap()); + assert_eq!(*rx4.borrow_and_update(), "two"); + + rx4.mark_unseen(); + assert!(rx4.has_changed().unwrap()); + assert_eq!(*rx4.borrow_and_update(), "two") + } + + assert_eq!(*rx.borrow(), "two"); } #[test]