diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 9e7cd1a7966..61307fce47d 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -359,7 +359,10 @@ mod state { use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::atomic::Ordering::SeqCst; - const CLOSED: usize = 1; + const CLOSED_BIT: usize = 1; + + // Using 2 as the step size preserves the `CLOSED_BIT`. + const STEP_SIZE: usize = 2; /// The version part of the state. The lowest bit is always zero. #[derive(Copy, Clone, Debug, Eq, PartialEq)] @@ -378,31 +381,26 @@ mod state { pub(super) struct AtomicState(AtomicUsize); impl Version { - /// Get the initial version when creating the channel. - pub(super) fn initial() -> Self { - // The initial version is 1 so that `mark_changed` can decrement by one. - // (The value is 2 due to the closed bit.) - Version(2) - } - /// Decrements the version. pub(super) fn decrement(&mut self) { - // Decrement by two to avoid touching the CLOSED bit. - if self.0 >= 2 { - self.0 -= 2; - } + // Using a wrapping decrement here is required to ensure that the + // operation is consistent with `std::sync::atomic::AtomicUsize::fetch_add()` + // which wraps on overflow. + self.0 = self.0.wrapping_sub(STEP_SIZE); } + + pub(super) const INITIAL: Self = Version(0); } impl StateSnapshot { /// Extract the version from the state. pub(super) fn version(self) -> Version { - Version(self.0 & !CLOSED) + Version(self.0 & !CLOSED_BIT) } /// Is the closed bit set? pub(super) fn is_closed(self) -> bool { - (self.0 & CLOSED) == CLOSED + (self.0 & CLOSED_BIT) == CLOSED_BIT } } @@ -410,7 +408,7 @@ mod state { /// Create a new `AtomicState` that is not closed and which has the /// version set to `Version::initial()`. pub(super) fn new() -> Self { - AtomicState(AtomicUsize::new(2)) + AtomicState(AtomicUsize::new(Version::INITIAL.0)) } /// Load the current value of the state. @@ -420,13 +418,12 @@ mod state { /// Increment the version counter. pub(super) fn increment_version(&self) { - // Increment by two to avoid touching the CLOSED bit. - self.0.fetch_add(2, SeqCst); + self.0.fetch_add(STEP_SIZE, SeqCst); } /// Set the closed bit in the state. pub(super) fn set_closed(&self) { - self.0.fetch_or(CLOSED, SeqCst); + self.0.fetch_or(CLOSED_BIT, SeqCst); } } } @@ -482,7 +479,7 @@ pub fn channel(init: T) -> (Sender, Receiver) { let rx = Receiver { shared, - version: Version::initial(), + version: Version::INITIAL, }; (tx, rx)