Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unsafe impl<'a> BufMut for ReadBuf<'a> #5590

Merged
merged 19 commits into from Apr 4, 2023
Merged
27 changes: 27 additions & 0 deletions tokio/src/io/read_buf.rs
Expand Up @@ -270,6 +270,33 @@ impl<'a> ReadBuf<'a> {
}
}

#[cfg(feature = "io-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
unsafe impl<'a> bytes::BufMut for ReadBuf<'a> {
fn remaining_mut(&self) -> usize {
self.remaining()
}

// SAFETY: The caller guarantees that at least `cnt` unfilled bytes have been initialized.
unsafe fn advance_mut(&mut self, cnt: usize) {
self.assume_init(cnt);
self.advance(cnt);
}

fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
// SAFETY: No region of `unfilled` will be deinitialized because it is
// exposed as an `UninitSlice`, whose API guarantees that the memory is
// never deinitialized.
let unfilled = unsafe { self.unfilled_mut() };
let len = unfilled.len();
let ptr = unfilled.as_mut_ptr() as *mut u8;

// SAFETY: The pointer is valid for `len` bytes because it comes from a
// slice of that length.
unsafe { bytes::buf::UninitSlice::from_raw_parts_mut(ptr, len) }
}
}

impl fmt::Debug for ReadBuf<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ReadBuf")
Expand Down
58 changes: 58 additions & 0 deletions tokio/tests/io_read_buf.rs
Expand Up @@ -34,3 +34,61 @@ async fn read_buf() {
assert_eq!(n, 11);
assert_eq!(buf[..], b"hello world"[..]);
}

#[tokio::test]
#[cfg(feature = "io-util")]
async fn issue_5588() {
use bytes::BufMut;

// steps to zero
let mut buf = [0; 8];
let mut read_buf = ReadBuf::new(&mut buf);
assert_eq!(read_buf.remaining_mut(), 8);
assert_eq!(read_buf.chunk_mut().len(), 8);
unsafe {
read_buf.advance_mut(1);
}
assert_eq!(read_buf.remaining_mut(), 7);
assert_eq!(read_buf.chunk_mut().len(), 7);
unsafe {
read_buf.advance_mut(5);
}
assert_eq!(read_buf.remaining_mut(), 2);
assert_eq!(read_buf.chunk_mut().len(), 2);
unsafe {
read_buf.advance_mut(2);
}
assert_eq!(read_buf.remaining_mut(), 0);
assert_eq!(read_buf.chunk_mut().len(), 0);

// directly to zero
let mut buf = [0; 8];
let mut read_buf = ReadBuf::new(&mut buf);
assert_eq!(read_buf.remaining_mut(), 8);
assert_eq!(read_buf.chunk_mut().len(), 8);
unsafe {
read_buf.advance_mut(8);
}
assert_eq!(read_buf.remaining_mut(), 0);
assert_eq!(read_buf.chunk_mut().len(), 0);

// uninit
let mut buf = [std::mem::MaybeUninit::new(1); 8];
let mut uninit = ReadBuf::uninit(&mut buf);
assert_eq!(uninit.remaining_mut(), 8);
assert_eq!(uninit.chunk_mut().len(), 8);

let mut buf = [std::mem::MaybeUninit::uninit(); 8];
let mut uninit = ReadBuf::uninit(&mut buf);
unsafe {
uninit.advance_mut(4);
}
assert_eq!(uninit.remaining_mut(), 4);
assert_eq!(uninit.chunk_mut().len(), 4);
uninit.put_u8(1);
assert_eq!(uninit.remaining_mut(), 3);
assert_eq!(uninit.chunk_mut().len(), 3);
uninit.put_slice(&[1, 2, 3]);
assert_eq!(uninit.remaining_mut(), 0);
assert_eq!(uninit.chunk_mut().len(), 0);
}