Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stream: add next_many and poll_next_many to StreamMap #6409

Merged
merged 10 commits into from
Mar 26, 2024
2 changes: 1 addition & 1 deletion tokio-stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ futures-core = { version = "0.3.0" }
pin-project-lite = "0.2.11"
tokio = { version = "1.15.0", path = "../tokio", features = ["sync"] }
tokio-util = { version = "0.7.0", path = "../tokio-util", optional = true }
futures = { version = "0.3", default-features = false }

[dev-dependencies]
tokio = { version = "1.2.0", path = "../tokio", features = ["full", "test-util"] }
async-stream = "0.3"
parking_lot = "0.12.0"
tokio-test = { path = "../tokio-test" }
futures = { version = "0.3", default-features = false }

[package.metadata.docs.rs]
all-features = true
Expand Down
74 changes: 74 additions & 0 deletions tokio-stream/src/stream_map.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::Stream;

use futures::future::poll_fn;
maminrayej marked this conversation as resolved.
Show resolved Hide resolved

use std::borrow::Borrow;
use std::hash::Hash;
use std::pin::Pin;
Expand Down Expand Up @@ -561,6 +563,78 @@ impl<K, V> Default for StreamMap<K, V> {
}
}

impl<K, V> StreamMap<K, V>
where
K: Clone + Unpin,
V: Stream + Unpin,
{
/// Polls to receive multiple items on this `StreamMap`, extending the provided `buffer`.
///
/// This method returns:
/// * `Poll::Pending` if no items are available but the `StreamMap` is not empty.
/// * `Poll::Ready(count)` where `count` is the number of items successfully received and
/// stored in `buffer`. This can be less than, or equal to, `limit`.
/// * `Poll::Ready(0)` if `limit` is set to zero or when the `StreamMap` is empty.
///
/// When the method returns `Poll::Pending`, the `Waker` in the provided
/// `Context` is scheduled to receive a wakeup when an item is sent on any of the
/// underlying stream. Note that on multiple calls to `poll_recv_many`, only
/// the `Waker` from the `Context` passed to the most recent call is scheduled
/// to receive a wakeup.
///
/// Note that this method does not guarantee that exactly `limit` items
/// are received. Rather, if at least one item is available, it returns
/// as many items as it can up to the given limit. This method returns
/// zero only if the `StreamMap` is empty (or if `limit` is zero).
pub fn poll_recv_many(
&mut self,
cx: &mut Context<'_>,
buffer: &mut Vec<(K, V::Item)>,
limit: usize,
) -> Poll<usize> {
if limit == 0 {
return Poll::Ready(0);
}
maminrayej marked this conversation as resolved.
Show resolved Hide resolved

let mut remaining = limit;

while remaining > 0 {
match self.poll_next_entry(cx) {
maminrayej marked this conversation as resolved.
Show resolved Hide resolved
Poll::Ready(Some((idx, val))) => {
remaining -= 1;
let key = self.entries[idx].0.clone();
buffer.push((key, val));
}
Poll::Ready(None) | Poll::Pending => break,
}
}

let added = limit - remaining;
maminrayej marked this conversation as resolved.
Show resolved Hide resolved

if added > 0 {
Poll::Ready(added)
} else if self.entries.is_empty() {
Poll::Ready(0)
} else {
Poll::Pending
}
}

/// Receives multiple items on this [`StreamMap`], extending the provided `buffer`.
///
/// This method returns the number of items that is appended to the `buffer`.
///
/// # Cancel safety
///
/// This method is cancel safe. If `recv_many` is used as the event in a
/// [`tokio::select!`](tokio::select) statement and some other branch
/// completes first, it is guaranteed that no items were received on any of
/// the underlying streams.
pub async fn recv_many(&mut self, buffer: &mut Vec<(K, V::Item)>, limit: usize) -> usize {
maminrayej marked this conversation as resolved.
Show resolved Hide resolved
maminrayej marked this conversation as resolved.
Show resolved Hide resolved
poll_fn(|cx| self.poll_recv_many(cx, buffer, limit)).await
}
maminrayej marked this conversation as resolved.
Show resolved Hide resolved
}

impl<K, V> Stream for StreamMap<K, V>
where
K: Clone + Unpin,
Expand Down
153 changes: 151 additions & 2 deletions tokio-stream/tests/stream_stream_map.rs
maminrayej marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
use futures::stream::iter;
use tokio_stream::{self as stream, pending, Stream, StreamExt, StreamMap};
use tokio_test::{assert_ok, assert_pending, assert_ready, task};

use std::{
future::{poll_fn, Future},
pin::{pin, Pin},
task::Poll,
};
maminrayej marked this conversation as resolved.
Show resolved Hide resolved

mod support {
pub(crate) mod mpsc;
}

use support::mpsc;

use std::pin::Pin;

macro_rules! assert_ready_some {
($($t:tt)*) => {
match assert_ready!($($t)*) {
Expand Down Expand Up @@ -328,3 +333,147 @@ fn one_ready_many_none() {
fn pin_box<T: Stream<Item = U> + 'static, U>(s: T) -> Pin<Box<dyn Stream<Item = U>>> {
Box::pin(s)
}

type UsizeStream = Pin<Box<dyn Stream<Item = usize> + Send>>;

#[tokio::test]
async fn poll_recv_many_zero() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(pending()) as UsizeStream);

let n = poll_fn(|cx| stream_map.poll_recv_many(cx, &mut vec![], 0)).await;

assert_eq!(n, 0);
}

#[tokio::test]
async fn poll_recv_many_empty() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

let n = poll_fn(|cx| stream_map.poll_recv_many(cx, &mut vec![], 1)).await;

assert_eq!(n, 0);
}

#[tokio::test]
async fn poll_recv_many_pending() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(pending()) as UsizeStream);

let mut is_pending = false;
poll_fn(|cx| {
let poll = stream_map.poll_recv_many(cx, &mut vec![], 1);

is_pending = poll.is_pending();

Poll::Ready(())
})
.await;

assert!(is_pending);
}

#[tokio::test]
async fn poll_recv_many_not_enough() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(iter([0usize].into_iter())) as UsizeStream);
stream_map.insert(1, Box::pin(iter([1usize].into_iter())) as UsizeStream);

let mut buffer = vec![];
let n = poll_fn(|cx| stream_map.poll_recv_many(cx, &mut buffer, 3)).await;

assert_eq!(n, 2);
assert_eq!(buffer.len(), 2);
assert!(buffer.contains(&(0, 0)));
assert!(buffer.contains(&(1, 1)));
}

#[tokio::test]
async fn poll_recv_many_enough() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(iter([0usize].into_iter())) as UsizeStream);
stream_map.insert(1, Box::pin(iter([1usize].into_iter())) as UsizeStream);

let mut buffer = vec![];
let n = poll_fn(|cx| stream_map.poll_recv_many(cx, &mut buffer, 2)).await;

assert_eq!(n, 2);
assert_eq!(buffer.len(), 2);
assert!(buffer.contains(&(0, 0)));
assert!(buffer.contains(&(1, 1)));
}

#[tokio::test]
async fn recv_many_zero() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(pending()) as UsizeStream);

let n = poll_fn(|cx| pin!(stream_map.recv_many(&mut vec![], 0)).poll(cx)).await;

assert_eq!(n, 0);
}

#[tokio::test]
async fn recv_many_empty() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

let n = stream_map.recv_many(&mut vec![], 1).await;

assert_eq!(n, 0);
}

#[tokio::test]
async fn recv_many_pending() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(pending()) as UsizeStream);

let mut is_pending = false;
poll_fn(|cx| {
let poll = pin!(stream_map.recv_many(&mut vec![], 1)).poll(cx);

is_pending = poll.is_pending();

Poll::Ready(())
})
.await;

assert!(is_pending);
}

#[tokio::test]
async fn recv_many_not_enough() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(iter([0usize].into_iter())) as UsizeStream);
stream_map.insert(1, Box::pin(iter([1usize].into_iter())) as UsizeStream);

let mut buffer = vec![];
let n = poll_fn(|cx| pin!(stream_map.recv_many(&mut buffer, 3)).poll(cx)).await;

assert_eq!(n, 2);
assert_eq!(buffer.len(), 2);
assert!(buffer.contains(&(0, 0)));
assert!(buffer.contains(&(1, 1)));
}

#[tokio::test]
async fn recv_many_enough() {
let mut stream_map: StreamMap<usize, UsizeStream> = StreamMap::new();

stream_map.insert(0, Box::pin(iter([0usize].into_iter())) as UsizeStream);
stream_map.insert(1, Box::pin(iter([1usize].into_iter())) as UsizeStream);

let mut buffer = vec![];
let n = poll_fn(|cx| pin!(stream_map.recv_many(&mut buffer, 2)).poll(cx)).await;

assert_eq!(n, 2);
assert_eq!(buffer.len(), 2);
assert!(buffer.contains(&(0, 0)));
assert!(buffer.contains(&(1, 1)));
}