diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 5a46670eeeb..0452a81aa0a 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -380,7 +380,17 @@ mod state { impl Version { /// Get the initial version when creating the channel. pub(super) fn initial() -> Self { - Version(0) + // The initial version is 1 so that `mark_unseen` can decrement by one. + // (The value is 2 due to the closed bit.) + Version(2) + } + + /// Decrements the version. + pub(super) fn decrement(&mut self) { + // Decrement by two to avoid touching the CLOSED bit. + if self.0 >= 2 { + self.0 -= 2; + } } } @@ -400,7 +410,7 @@ mod state { /// Create a new `AtomicState` that is not closed and which has the /// version set to `Version::initial()`. pub(super) fn new() -> Self { - AtomicState(AtomicUsize::new(0)) + AtomicState(AtomicUsize::new(2)) } /// Load the current value of the state. @@ -634,6 +644,11 @@ impl Receiver { Ok(self.version != new_version) } + /// Marks the state as unseen. + pub fn mark_unseen(&mut self) { + self.version.decrement(); + } + /// Waits for a change notification, then marks the newest value as seen. /// /// If the newest value in the channel has not yet been marked seen when diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index dab57aa5af6..5bcb7476888 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -44,6 +44,64 @@ fn single_rx_recv() { assert_eq!(*rx.borrow(), "two"); } +#[test] +fn rx_version_underflow() { + let (_tx, mut rx) = watch::channel("one"); + + // Version starts at 2, validate we do not underflow + rx.mark_unseen(); + rx.mark_unseen(); +} + +#[test] +fn rx_mark_unseen() { + let (tx, mut rx) = watch::channel("one"); + + let mut rx2 = rx.clone(); + let mut rx3 = rx.clone(); + let mut rx4 = rx.clone(); + { + rx.mark_unseen(); + assert!(rx.has_changed().unwrap()); + + let mut t = spawn(rx.changed()); + assert_ready_ok!(t.poll()); + } + + { + assert!(!rx2.has_changed().unwrap()); + + let mut t = spawn(rx2.changed()); + assert_pending!(t.poll()); + } + + { + rx3.mark_unseen(); + assert_eq!(*rx3.borrow(), "one"); + + assert!(rx3.has_changed().unwrap()); + + assert_eq!(*rx3.borrow_and_update(), "one"); + + assert!(!rx3.has_changed().unwrap()); + + let mut t = spawn(rx3.changed()); + assert_pending!(t.poll()); + } + + { + tx.send("two").unwrap(); + assert!(rx4.has_changed().unwrap()); + assert_eq!(*rx4.borrow_and_update(), "two"); + + rx4.mark_unseen(); + assert!(rx4.has_changed().unwrap()); + assert_eq!(*rx4.borrow_and_update(), "two") + } + + assert_eq!(*rx.borrow(), "two"); +} + #[test] fn multi_rx() { let (tx, mut rx1) = watch::channel("one");