implement write_vectored for DuplexStream #5985

Merged (1 commit, Sep 8, 2023)
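This PR implements AsyncWrite::poll_write_vectored (and is_write_vectored) for tokio::io::DuplexStream, so a vectored write copies from every provided buffer in a single call, up to the stream's remaining buffer capacity, rather than falling back to the trait's default implementation, which only forwards the first non-empty buffer to poll_write. Below is a minimal usage sketch of what the change enables, assuming a project that depends on tokio with the "full" feature; the main function and the literal buffer contents are illustrative, not part of the PR.

use std::io::IoSlice;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // In-memory duplex pipe with a 64-byte buffer per direction.
    let (mut client, mut server) = tokio::io::duplex(64);

    // With this change, both slices are copied in one call (capacity permitting),
    // so `n` covers "hello " and "world" together rather than just the first slice.
    let n = client
        .write_vectored(&[IoSlice::new(b"hello "), IoSlice::new(b"world")])
        .await?;
    assert_eq!(n, 11);

    // Close the client half so read_to_end can observe EOF.
    drop(client);

    let mut out = Vec::new();
    server.read_to_end(&mut out).await?;
    assert_eq!(out, b"hello world");
    Ok(())
}

The read side is unchanged: the vectored data lands in the same in-memory buffer that poll_write fills, so readers see one contiguous byte stream.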
75 changes: 75 additions & 0 deletions tokio/src/io/util/mem.rs
@@ -124,6 +124,18 @@ impl AsyncWrite for DuplexStream {
        Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        true
    }

    #[allow(unused_mut)]
    fn poll_flush(
        mut self: Pin<&mut Self>,
@@ -224,6 +236,37 @@ impl Pipe {
        }
        Poll::Ready(Ok(len))
    }

    fn poll_write_vectored_internal(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        if self.is_closed {
            return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
        }
        let avail = self.max_buf_size - self.buffer.len();
        if avail == 0 {
            // No capacity left: register the writer's waker and wait for a read.
            self.write_waker = Some(cx.waker().clone());
            return Poll::Pending;
        }

        // Copy from each buffer in order until the available capacity is exhausted.
        let mut rem = avail;
        for buf in bufs {
            if rem == 0 {
                break;
            }

            let len = buf.len().min(rem);
            self.buffer.extend_from_slice(&buf[..len]);
            rem -= len;
        }

        // Wake a pending reader, if any, now that new data is available.
        if let Some(waker) = self.read_waker.take() {
            waker.wake();
        }
        // Report the total number of bytes actually copied across all buffers.
        Poll::Ready(Ok(avail - rem))
    }
}

impl AsyncRead for Pipe {
@@ -285,6 +328,38 @@ impl AsyncWrite for Pipe {
        }
    }

    cfg_coop! {
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            bufs: &[std::io::IoSlice<'_>],
        ) -> Poll<Result<usize, std::io::Error>> {
            ready!(crate::trace::trace_leaf(cx));
            let coop = ready!(crate::runtime::coop::poll_proceed(cx));

            let ret = self.poll_write_vectored_internal(cx, bufs);
            if ret.is_ready() {
                coop.made_progress();
            }
            ret
        }
    }

    cfg_not_coop! {
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            bufs: &[std::io::IoSlice<'_>],
        ) -> Poll<Result<usize, std::io::Error>> {
            ready!(crate::trace::trace_leaf(cx));
            self.poll_write_vectored_internal(cx, bufs)
        }
    }

    fn is_write_vectored(&self) -> bool {
        true
    }

    fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
47 changes: 47 additions & 0 deletions tokio/tests/duplex_stream.rs
@@ -0,0 +1,47 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use std::io::IoSlice;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

const HELLO: &[u8] = b"hello world...";

#[tokio::test]
async fn write_vectored() {
    let (mut client, mut server) = tokio::io::duplex(64);

    let ret = client
        .write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)])
        .await
        .unwrap();
    assert_eq!(ret, HELLO.len() * 2);

    client.flush().await.unwrap();
    drop(client);

    let mut buf = Vec::with_capacity(HELLO.len() * 2);
    let bytes_read = server.read_to_end(&mut buf).await.unwrap();

    assert_eq!(bytes_read, HELLO.len() * 2);
    assert_eq!(buf, [HELLO, HELLO].concat());
}

#[tokio::test]
async fn write_vectored_and_shutdown() {
    let (mut client, mut server) = tokio::io::duplex(64);

    let ret = client
        .write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)])
        .await
        .unwrap();
    assert_eq!(ret, HELLO.len() * 2);

    client.shutdown().await.unwrap();
    drop(client);

    let mut buf = Vec::with_capacity(HELLO.len() * 2);
    let bytes_read = server.read_to_end(&mut buf).await.unwrap();

    assert_eq!(bytes_read, HELLO.len() * 2);
    assert_eq!(buf, [HELLO, HELLO].concat());
}