Skip to content

Commit

Permalink
sync: fix mark_changed when version overflows (#6017)
Browse files Browse the repository at this point in the history
  • Loading branch information
uklotzde committed Sep 19, 2023
1 parent 9d51b76 commit ad7f988
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions tokio/src/sync/watch.rs
Expand Up @@ -359,7 +359,10 @@ mod state {
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::SeqCst;

const CLOSED: usize = 1;
const CLOSED_BIT: usize = 1;

// Using 2 as the step size preserves the `CLOSED_BIT`.
const STEP_SIZE: usize = 2;

/// The version part of the state. The lowest bit is always zero.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
Expand All @@ -378,39 +381,34 @@ mod state {
pub(super) struct AtomicState(AtomicUsize);

impl Version {
/// Get the initial version when creating the channel.
pub(super) fn initial() -> Self {
// The initial version is 1 so that `mark_changed` can decrement by one.
// (The value is 2 due to the closed bit.)
Version(2)
}

/// Decrements the version.
pub(super) fn decrement(&mut self) {
// Decrement by two to avoid touching the CLOSED bit.
if self.0 >= 2 {
self.0 -= 2;
}
// Using a wrapping decrement here is required to ensure that the
// operation is consistent with `std::sync::atomic::AtomicUsize::fetch_add()`
// which wraps on overflow.
self.0 = self.0.wrapping_sub(STEP_SIZE);
}

pub(super) const INITIAL: Self = Version(0);
}

impl StateSnapshot {
/// Extract the version from the state.
pub(super) fn version(self) -> Version {
Version(self.0 & !CLOSED)
Version(self.0 & !CLOSED_BIT)
}

/// Is the closed bit set?
pub(super) fn is_closed(self) -> bool {
(self.0 & CLOSED) == CLOSED
(self.0 & CLOSED_BIT) == CLOSED_BIT
}
}

impl AtomicState {
/// Create a new `AtomicState` that is not closed and which has the
/// version set to `Version::initial()`.
pub(super) fn new() -> Self {
AtomicState(AtomicUsize::new(2))
AtomicState(AtomicUsize::new(Version::INITIAL.0))
}

/// Load the current value of the state.
Expand All @@ -420,13 +418,12 @@ mod state {

/// Increment the version counter.
pub(super) fn increment_version(&self) {
// Increment by two to avoid touching the CLOSED bit.
self.0.fetch_add(2, SeqCst);
self.0.fetch_add(STEP_SIZE, SeqCst);
}

/// Set the closed bit in the state.
pub(super) fn set_closed(&self) {
self.0.fetch_or(CLOSED, SeqCst);
self.0.fetch_or(CLOSED_BIT, SeqCst);
}
}
}
Expand Down Expand Up @@ -482,7 +479,7 @@ pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {

let rx = Receiver {
shared,
version: Version::initial(),
version: Version::INITIAL,
};

(tx, rx)
Expand Down

0 comments on commit ad7f988

Please sign in to comment.