From b7eea73f6747610d49f4a19547c225655aa873da Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 27 Feb 2023 15:50:09 -0800 Subject: [PATCH] net: add `UdpSocket::peek_sender()` closes #5491 --- tokio/src/net/udp.rs | 61 ++++++++++++++++++++++++++++++++++++++++++++ tokio/tests/udp.rs | 51 ++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/tokio/src/net/udp.rs b/tokio/src/net/udp.rs index 213d9149dad..c8fed793e58 100644 --- a/tokio/src/net/udp.rs +++ b/tokio/src/net/udp.rs @@ -1331,6 +1331,11 @@ impl UdpSocket { /// Make sure to always use a sufficiently large buffer to hold the /// maximum UDP packet size, which can be up to 65536 bytes in size. /// + /// MacOS will return an error if you pass a zero-sized buffer. + /// + /// If you're merely interested in learning the sender of the data at the head of the queue, + /// try [`peek_sender`]. + /// /// # Examples /// /// ```no_run @@ -1349,6 +1354,8 @@ impl UdpSocket { /// Ok(()) /// } /// ``` + /// + /// [`peek_sender`]: method@Self::peek_sender pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { self.io .registration() @@ -1371,6 +1378,11 @@ impl UdpSocket { /// Make sure to always use a sufficiently large buffer to hold the /// maximum UDP packet size, which can be up to 65536 bytes in size. /// + /// MacOS will return an error if you pass a zero-sized buffer. + /// + /// If you're merely interested in learning the sender of the data at the head of the queue, + /// try [`poll_peek_sender`]. + /// /// # Return value /// /// The function returns: @@ -1382,6 +1394,8 @@ impl UdpSocket { /// # Errors /// /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`poll_peek_sender`]: method@Self::poll_peek_sender pub fn poll_peek_from( &self, cx: &mut Context<'_>, @@ -1404,6 +1418,53 @@ impl UdpSocket { Poll::Ready(Ok(addr)) } + /// Retrieve the sender of the data at the head of the input queue, waiting if empty. + /// + /// This is equivalent to calling [`peek_from`] with a zero-sized buffer, + /// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS. + /// + /// [`peek_from`]: method@Self::peek_from + pub async fn peek_sender(&self) -> io::Result { + self.io + .registration() + .async_io(Interest::READABLE, || { + self + .as_socket() + .peek_sender()? + .as_socket() + // Not clear what conditions could cause this, + // but we probably ought not to panic. + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "sender not available")) + }) + .await + } + + /// Retrieve the sender of the data at the head of the input queue. + /// + /// This is equivalent to calling [`poll_peek_from`] with a zero-sized buffer, + /// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS. + /// + /// # Notes + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// [`poll_peek_from`]: method@Self::poll_peek_from + pub fn poll_peek_sender(&self, cx: &mut Context<'_>) -> Poll> { + self.io + .registration() + .poll_read_io(cx, || { + self + .as_socket() + .peek_sender()? + .as_socket() + // Not clear what conditions could cause this, + // but we probably ought not to panic. + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "sender not available")) + }) + } + /// Gets the value of the `SO_BROADCAST` option for this socket. /// /// For more information about this option, see [`set_broadcast`]. diff --git a/tokio/tests/udp.rs b/tokio/tests/udp.rs index 2b6ab4d2ad2..b8a5074e969 100644 --- a/tokio/tests/udp.rs +++ b/tokio/tests/udp.rs @@ -134,6 +134,57 @@ async fn send_to_peek_from_poll() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn peek_sender() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let sender_addr = sender.local_addr()?; + let receiver_addr = receiver.local_addr()?; + + let msg = b"Hello, world!"; + sender.send_to(msg, receiver_addr).await?; + + let peeked_sender = receiver.peek_sender().await?; + assert_eq!(peeked_sender, sender_addr); + + // Assert that `peek_sender()` returns the right sender but + // doesn't remove from the receive queue. + let mut recv_buf = [0u8; 32]; + let (read, received_sender) = receiver.recv_from(&mut recv_buf).await?; + + assert_eq!(&recv_buf[..read], msg); + assert_eq!(received_sender, peeked_sender); + + Ok(()) +} + +#[tokio::test] +async fn poll_peek_sender() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let sender_addr = sender.local_addr()?; + let receiver_addr = receiver.local_addr()?; + + let msg = b"Hello, world!"; + poll_fn(|cx| sender.poll_send_to(cx, msg, receiver_addr)).await?; + + let peeked_sender = poll_fn(|cx| receiver.poll_peek_sender(cx)).await?; + assert_eq!(peeked_sender, sender_addr); + + // Assert that `poll_peek_sender()` returns the right sender but + // doesn't remove from the receive queue. + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + let received_sender = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?; + + assert_eq!(read.filled(), msg); + assert_eq!(received_sender, peeked_sender); + + Ok(()) +} + #[tokio::test] async fn split() -> std::io::Result<()> { let socket = UdpSocket::bind("127.0.0.1:0").await?;