diff --git a/tokio/src/io/read_buf.rs b/tokio/src/io/read_buf.rs index 0dc595a87dd..283d96e3095 100644 --- a/tokio/src/io/read_buf.rs +++ b/tokio/src/io/read_buf.rs @@ -270,6 +270,33 @@ impl<'a> ReadBuf<'a> { } } +#[cfg(feature = "io-util")] +#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] +unsafe impl<'a> bytes::BufMut for ReadBuf<'a> { + fn remaining_mut(&self) -> usize { + self.remaining() + } + + // SAFETY: The caller guarantees that at least `cnt` unfilled bytes have been initialized. + unsafe fn advance_mut(&mut self, cnt: usize) { + self.assume_init(cnt); + self.advance(cnt); + } + + fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { + // SAFETY: No region of `unfilled` will be deinitialized because it is + // exposed as an `UninitSlice`, whose API guarantees that the memory is + // never deinitialized. + let unfilled = unsafe { self.unfilled_mut() }; + let len = unfilled.len(); + let ptr = unfilled.as_mut_ptr() as *mut u8; + + // SAFETY: The pointer is valid for `len` bytes because it comes from a + // slice of that length. + unsafe { bytes::buf::UninitSlice::from_raw_parts_mut(ptr, len) } + } +} + impl fmt::Debug for ReadBuf<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ReadBuf") diff --git a/tokio/tests/io_read_buf.rs b/tokio/tests/io_read_buf.rs index 0328168d7ab..49a4f86f8ad 100644 --- a/tokio/tests/io_read_buf.rs +++ b/tokio/tests/io_read_buf.rs @@ -34,3 +34,61 @@ async fn read_buf() { assert_eq!(n, 11); assert_eq!(buf[..], b"hello world"[..]); } + +#[tokio::test] +#[cfg(feature = "io-util")] +async fn issue_5588() { + use bytes::BufMut; + + // steps to zero + let mut buf = [0; 8]; + let mut read_buf = ReadBuf::new(&mut buf); + assert_eq!(read_buf.remaining_mut(), 8); + assert_eq!(read_buf.chunk_mut().len(), 8); + unsafe { + read_buf.advance_mut(1); + } + assert_eq!(read_buf.remaining_mut(), 7); + assert_eq!(read_buf.chunk_mut().len(), 7); + unsafe { + read_buf.advance_mut(5); + } + assert_eq!(read_buf.remaining_mut(), 2); + assert_eq!(read_buf.chunk_mut().len(), 2); + unsafe { + read_buf.advance_mut(2); + } + assert_eq!(read_buf.remaining_mut(), 0); + assert_eq!(read_buf.chunk_mut().len(), 0); + + // directly to zero + let mut buf = [0; 8]; + let mut read_buf = ReadBuf::new(&mut buf); + assert_eq!(read_buf.remaining_mut(), 8); + assert_eq!(read_buf.chunk_mut().len(), 8); + unsafe { + read_buf.advance_mut(8); + } + assert_eq!(read_buf.remaining_mut(), 0); + assert_eq!(read_buf.chunk_mut().len(), 0); + + // uninit + let mut buf = [std::mem::MaybeUninit::new(1); 8]; + let mut uninit = ReadBuf::uninit(&mut buf); + assert_eq!(uninit.remaining_mut(), 8); + assert_eq!(uninit.chunk_mut().len(), 8); + + let mut buf = [std::mem::MaybeUninit::uninit(); 8]; + let mut uninit = ReadBuf::uninit(&mut buf); + unsafe { + uninit.advance_mut(4); + } + assert_eq!(uninit.remaining_mut(), 4); + assert_eq!(uninit.chunk_mut().len(), 4); + uninit.put_u8(1); + assert_eq!(uninit.remaining_mut(), 3); + assert_eq!(uninit.chunk_mut().len(), 3); + uninit.put_slice(&[1, 2, 3]); + assert_eq!(uninit.remaining_mut(), 0); + assert_eq!(uninit.chunk_mut().len(), 0); +}