Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stream: add next_many and poll_next_many to StreamMap #6409

Merged
merged 10 commits into from
Mar 26, 2024
3 changes: 3 additions & 0 deletions tokio-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@
#[macro_use]
mod macros;

mod poll_fn;
pub(crate) use poll_fn::poll_fn;

pub mod wrappers;

mod stream_ext;
Expand Down
35 changes: 35 additions & 0 deletions tokio-stream/src/poll_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

pub(crate) struct PollFn<F> {
f: F,
}

pub(crate) fn poll_fn<T, F>(f: F) -> PollFn<F>
where
F: FnMut(&mut Context<'_>) -> Poll<T>,
{
PollFn { f }
}

impl<T, F> Future for PollFn<F>
where
F: FnMut(&mut Context<'_>) -> Poll<T>,
{
type Output = T;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
// Safety: We never construct a `Pin<&mut F>` anywhere, so accessing `f`
// mutably in an unpinned way is sound.
//
// This use of unsafe cannot be replaced with the pin-project macro
// because:
// * If we put `#[pin]` on the field, then it gives us a `Pin<&mut F>`,
// which we can't use to call the closure.
// * If we don't put `#[pin]` on the field, then it makes `PollFn` be
// unconditionally `Unpin`, which we also don't want.
let me = unsafe { Pin::into_inner_unchecked(self) };
(me.f)(cx)
}
}
93 changes: 89 additions & 4 deletions tokio-stream/src/stream_map.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Stream;
use crate::{poll_fn, Stream};

use std::borrow::Borrow;
use std::hash::Hash;
Expand Down Expand Up @@ -516,9 +516,7 @@ where
K: Unpin,
V: Stream + Unpin,
{
/// Polls the next value, includes the vec entry index
fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<Option<(usize, V::Item)>> {
let start = self::rand::thread_rng_n(self.entries.len() as u32) as usize;
fn poll_one(&mut self, cx: &mut Context<'_>, start: usize) -> Poll<Option<(usize, V::Item)>> {
let mut idx = start;

for _ in 0..self.entries.len() {
Expand Down Expand Up @@ -553,6 +551,13 @@ where
Poll::Pending
}
}

/// Polls the next value, includes the vec entry index
fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<Option<(usize, V::Item)>> {
let start = self::rand::thread_rng_n(self.entries.len() as u32) as usize;

self.poll_one(cx, start)
}
}

impl<K, V> Default for StreamMap<K, V> {
Expand All @@ -561,6 +566,86 @@ impl<K, V> Default for StreamMap<K, V> {
}
}

impl<K, V> StreamMap<K, V>
where
K: Clone + Unpin,
V: Stream + Unpin,
{
/// Receives multiple items on this [`StreamMap`], extending the provided `buffer`.
///
/// This method returns the number of items that is appended to the `buffer`.
///
/// Note that this method does not guarantee that exactly `limit` items
/// are received. Rather, if at least one item is available, it returns
/// as many items as it can up to the given limit. This method returns
/// zero only if the `StreamMap` is empty (or if `limit` is zero).
///
/// # Cancel safety
///
/// This method is cancel safe. If `next_many` is used as the event in a
/// [`tokio::select!`](tokio::select) statement and some other branch
/// completes first, it is guaranteed that no items were received on any of
/// the underlying streams.
pub async fn next_many(&mut self, buffer: &mut Vec<(K, V::Item)>, limit: usize) -> usize {
poll_fn(|cx| self.poll_next_many(cx, buffer, limit)).await
}

/// Polls to receive multiple items on this `StreamMap`, extending the provided `buffer`.
///
/// This method returns:
/// * `Poll::Pending` if no items are available but the `StreamMap` is not empty.
/// * `Poll::Ready(count)` where `count` is the number of items successfully received and
/// stored in `buffer`. This can be less than, or equal to, `limit`.
/// * `Poll::Ready(0)` if `limit` is set to zero or when the `StreamMap` is empty.
///
/// When the method returns `Poll::Pending`, the `Waker` in the provided
/// `Context` is scheduled to receive a wakeup when an item is sent on any of the
/// underlying stream. Note that on multiple calls to `poll_next_many`, only
/// the `Waker` from the `Context` passed to the most recent call is scheduled
/// to receive a wakeup.
///
maminrayej marked this conversation as resolved.
Show resolved Hide resolved
/// Note that this method does not guarantee that exactly `limit` items
/// are received. Rather, if at least one item is available, it returns
/// as many items as it can up to the given limit. This method returns
/// zero only if the `StreamMap` is empty (or if `limit` is zero).
pub fn poll_next_many(
&mut self,
cx: &mut Context<'_>,
buffer: &mut Vec<(K, V::Item)>,
limit: usize,
) -> Poll<usize> {
if limit == 0 || self.entries.is_empty() {
return Poll::Ready(0);
}

let mut remaining = limit;
let mut start = self::rand::thread_rng_n(self.entries.len() as u32) as usize;

while remaining > 0 {
match self.poll_one(cx, start) {
Poll::Ready(Some((idx, val))) => {
maminrayej marked this conversation as resolved.
Show resolved Hide resolved
remaining -= 1;
let key = self.entries[idx].0.clone();
buffer.push((key, val));

start = idx.wrapping_add(1) % self.entries.len();
}
Poll::Ready(None) | Poll::Pending => break,
}
}

let added = limit - remaining;

maminrayej marked this conversation as resolved.
Show resolved Hide resolved
if added > 0 {
Poll::Ready(added)
} else if self.entries.is_empty() {
Poll::Ready(0)
} else {
Poll::Pending
}
}
}

impl<K, V> Stream for StreamMap<K, V>
where
K: Clone + Unpin,
Expand Down
151 changes: 149 additions & 2 deletions tokio-stream/tests/stream_stream_map.rs
maminrayej marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use futures::stream::iter;
use tokio_stream::{self as stream, pending, Stream, StreamExt, StreamMap};
use tokio_test::{assert_ok, assert_pending, assert_ready, task};

use std::future::{poll_fn, Future};
use std::pin::{pin, Pin};
use std::task::Poll;

mod support {
pub(crate) mod mpsc;
}

use support::mpsc;

use std::pin::Pin;

macro_rules! assert_ready_some {
($($t:tt)*) => {
match assert_ready!($($t)*) {
Expand Down Expand Up @@ -328,3 +331,147 @@ fn one_ready_many_none() {
fn pin_box<T: Stream<Item = U> + 'static, U>(s: T) -> Pin<Box<dyn Stream<Item = U>>> {
Box::pin(s)
}

type UsizeStream = Pin<Box<dyn Stream<Item = usize> + Send>>;

#[tokio::test]
async fn poll_next_many_zero() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(pending()) as UsizeStream);

let n = poll_fn(|cx| stream_map.poll_next_many(cx, &mut vec![], 0)).await;

assert_eq!(n, 0);
}

#[tokio::test]
async fn poll_next_many_empty() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

let n = poll_fn(|cx| stream_map.poll_next_many(cx, &mut vec![], 1)).await;

assert_eq!(n, 0);
}

#[tokio::test]
async fn poll_next_many_pending() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(pending()) as UsizeStream);

let mut is_pending = false;
poll_fn(|cx| {
let poll = stream_map.poll_next_many(cx, &mut vec![], 1);

is_pending = poll.is_pending();

Poll::Ready(())
})
.await;

assert!(is_pending);
}

#[tokio::test]
async fn poll_next_many_not_enough() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(iter([0usize].into_iter())) as UsizeStream);
stream_map.insert(1, Box::pin(iter([1usize].into_iter())) as UsizeStream);

let mut buffer = vec![];
let n = poll_fn(|cx| stream_map.poll_next_many(cx, &mut buffer, 3)).await;

assert_eq!(n, 2);
assert_eq!(buffer.len(), 2);
assert!(buffer.contains(&(0, 0)));
assert!(buffer.contains(&(1, 1)));
}

#[tokio::test]
async fn poll_next_many_enough() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(iter([0usize].into_iter())) as UsizeStream);
stream_map.insert(1, Box::pin(iter([1usize].into_iter())) as UsizeStream);

let mut buffer = vec![];
let n = poll_fn(|cx| stream_map.poll_next_many(cx, &mut buffer, 2)).await;

assert_eq!(n, 2);
assert_eq!(buffer.len(), 2);
assert!(buffer.contains(&(0, 0)));
assert!(buffer.contains(&(1, 1)));
}

#[tokio::test]
async fn next_many_zero() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(pending()) as UsizeStream);

let n = poll_fn(|cx| pin!(stream_map.next_many(&mut vec![], 0)).poll(cx)).await;

assert_eq!(n, 0);
}

#[tokio::test]
async fn next_many_empty() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

let n = stream_map.next_many(&mut vec![], 1).await;

assert_eq!(n, 0);
}

#[tokio::test]
async fn next_many_pending() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(pending()) as UsizeStream);

let mut is_pending = false;
poll_fn(|cx| {
let poll = pin!(stream_map.next_many(&mut vec![], 1)).poll(cx);

is_pending = poll.is_pending();

Poll::Ready(())
})
.await;

assert!(is_pending);
}

#[tokio::test]
async fn next_many_not_enough() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(iter([0usize].into_iter())) as UsizeStream);
stream_map.insert(1, Box::pin(iter([1usize].into_iter())) as UsizeStream);

let mut buffer = vec![];
let n = poll_fn(|cx| pin!(stream_map.next_many(&mut buffer, 3)).poll(cx)).await;

assert_eq!(n, 2);
assert_eq!(buffer.len(), 2);
assert!(buffer.contains(&(0, 0)));
assert!(buffer.contains(&(1, 1)));
}

#[tokio::test]
async fn next_many_enough() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(iter([0usize].into_iter())) as UsizeStream);
stream_map.insert(1, Box::pin(iter([1usize].into_iter())) as UsizeStream);

let mut buffer = vec![];
let n = poll_fn(|cx| pin!(stream_map.next_many(&mut buffer, 2)).poll(cx)).await;

assert_eq!(n, 2);
assert_eq!(buffer.len(), 2);
assert!(buffer.contains(&(0, 0)));
assert!(buffer.contains(&(1, 1)));
}