Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stream: add next_many and poll_next_many to StreamMap #6409

Merged
merged 10 commits into from
Mar 26, 2024
3 changes: 3 additions & 0 deletions tokio-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@
#[macro_use]
mod macros;

mod poll_fn;
pub(crate) use poll_fn::poll_fn;

pub mod wrappers;

mod stream_ext;
Expand Down
35 changes: 35 additions & 0 deletions tokio-stream/src/poll_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

pub(crate) struct PollFn<F> {
f: F,
}

pub(crate) fn poll_fn<T, F>(f: F) -> PollFn<F>
where
F: FnMut(&mut Context<'_>) -> Poll<T>,
{
PollFn { f }
}

impl<T, F> Future for PollFn<F>
where
F: FnMut(&mut Context<'_>) -> Poll<T>,
{
type Output = T;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
// Safety: We never construct a `Pin<&mut F>` anywhere, so accessing `f`
// mutably in an unpinned way is sound.
//
// This use of unsafe cannot be replaced with the pin-project macro
// because:
// * If we put `#[pin]` on the field, then it gives us a `Pin<&mut F>`,
// which we can't use to call the closure.
// * If we don't put `#[pin]` on the field, then it makes `PollFn` be
// unconditionally `Unpin`, which we also don't want.
let me = unsafe { Pin::into_inner_unchecked(self) };
(me.f)(cx)
}
}
106 changes: 105 additions & 1 deletion tokio-stream/src/stream_map.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Stream;
use crate::{poll_fn, Stream};

use std::borrow::Borrow;
use std::hash::Hash;
Expand Down Expand Up @@ -561,6 +561,110 @@ impl<K, V> Default for StreamMap<K, V> {
}
}

impl<K, V> StreamMap<K, V>
where
K: Clone + Unpin,
V: Stream + Unpin,
{
/// Receives multiple items on this [`StreamMap`], extending the provided `buffer`.
///
/// This method returns the number of items that is appended to the `buffer`.
///
/// Note that this method does not guarantee that exactly `limit` items
/// are received. Rather, if at least one item is available, it returns
/// as many items as it can up to the given limit. This method returns
/// zero only if the `StreamMap` is empty (or if `limit` is zero).
///
/// # Cancel safety
///
/// This method is cancel safe. If `next_many` is used as the event in a
/// [`tokio::select!`](tokio::select) statement and some other branch
/// completes first, it is guaranteed that no items were received on any of
/// the underlying streams.
pub async fn next_many(&mut self, buffer: &mut Vec<(K, V::Item)>, limit: usize) -> usize {
poll_fn(|cx| self.poll_next_many(cx, buffer, limit)).await
}

/// Polls to receive multiple items on this `StreamMap`, extending the provided `buffer`.
///
/// This method returns:
/// * `Poll::Pending` if no items are available but the `StreamMap` is not empty.
/// * `Poll::Ready(count)` where `count` is the number of items successfully received and
/// stored in `buffer`. This can be less than, or equal to, `limit`.
/// * `Poll::Ready(0)` if `limit` is set to zero or when the `StreamMap` is empty.
///
/// Note that this method does not guarantee that exactly `limit` items
/// are received. Rather, if at least one item is available, it returns
/// as many items as it can up to the given limit. This method returns
/// zero only if the `StreamMap` is empty (or if `limit` is zero).
pub fn poll_next_many(
&mut self,
cx: &mut Context<'_>,
buffer: &mut Vec<(K, V::Item)>,
limit: usize,
) -> Poll<usize> {
if limit == 0 || self.entries.is_empty() {
return Poll::Ready(0);
}

let mut added = 0;

let start = self::rand::thread_rng_n(self.entries.len() as u32) as usize;
let mut idx = start;

while added < limit {
// Indicates whether at least one stream returned a value when polled or not
let mut should_loop = false;

for _ in 0..self.entries.len() {
let (_, stream) = &mut self.entries[idx];

match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(val)) => {
added += 1;

let key = self.entries[idx].0.clone();
buffer.push((key, val));

should_loop = true;

idx = idx.wrapping_add(1) % self.entries.len();
}
Poll::Ready(None) => {
// Remove the entry
self.entries.swap_remove(idx);

// Check if this was the last entry, if so the cursor needs
// to wrap
if idx == self.entries.len() {
idx = 0;
} else if idx < start && start <= self.entries.len() {
// The stream being swapped into the current index has
// already been polled, so skip it.
idx = idx.wrapping_add(1) % self.entries.len();
}
}
Poll::Pending => {
idx = idx.wrapping_add(1) % self.entries.len();
}
}
}

if !should_loop {
break;
}
}

if added > 0 {
Poll::Ready(added)
} else if self.entries.is_empty() {
Poll::Ready(0)
} else {
Poll::Pending
}
}
}

impl<K, V> Stream for StreamMap<K, V>
where
K: Clone + Unpin,
Expand Down