Skip to content

Commit

Permalink
Factor out UnraisableCapture helper type and use it to check that dro…
Browse files Browse the repository at this point in the history
…pping unsendable elsewhere calls into sys.unraisablehook
  • Loading branch information
adamreichold committed May 24, 2023
1 parent 6261f79 commit 240e117
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 75 deletions.
45 changes: 45 additions & 0 deletions tests/common.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Some common macros for tests

use pyo3::prelude::*;

#[macro_export]
macro_rules! py_assert {
($py:expr, $($val:ident)+, $assertion:literal) => {
Expand Down Expand Up @@ -41,3 +43,46 @@ macro_rules! py_expect_exception {
err
}};
}

#[pyclass]
pub struct UnraisableCapture {
pub capture: Option<(PyErr, PyObject)>,
old_hook: Option<PyObject>,
}

impl UnraisableCapture {
pub fn install(py: Python<'_>) -> Py<Self> {
let sys = py.import("sys").unwrap();
let old_hook = sys.getattr("unraisablehook").unwrap().into();

let capture = Py::new(
py,
UnraisableCapture {
capture: None,
old_hook: Some(old_hook),
},
)
.unwrap();

sys.setattr("unraisablehook", capture.getattr(py, "hook").unwrap())
.unwrap();

capture
}

pub fn uninstall(&mut self, py: Python<'_>) {
let old_hook = self.old_hook.take().unwrap();

let sys = py.import("sys").unwrap();
sys.setattr("unraisablehook", old_hook).unwrap();
}
}

#[pymethods]
impl UnraisableCapture {
pub fn hook(&mut self, unraisable: &PyAny) {
let err = PyErr::from_value(unraisable.getattr("exc_value").unwrap());
let instance = unraisable.getattr("object").unwrap();
self.capture = Some((err, instance.into()));
}
}
24 changes: 3 additions & 21 deletions tests/test_buffer_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ fn test_buffer_referenced() {
#[test]
#[cfg(Py_3_8)] // sys.unraisablehook not available until Python 3.8
fn test_releasebuffer_unraisable_error() {
use common::UnraisableCapture;
use pyo3::exceptions::PyValueError;

#[pyclass]
Expand All @@ -117,27 +118,8 @@ fn test_releasebuffer_unraisable_error() {
}
}

#[pyclass]
struct UnraisableCapture {
capture: Option<(PyErr, PyObject)>,
}

#[pymethods]
impl UnraisableCapture {
fn hook(&mut self, unraisable: &PyAny) {
let err = PyErr::from_value(unraisable.getattr("exc_value").unwrap());
let instance = unraisable.getattr("object").unwrap();
self.capture = Some((err, instance.into()));
}
}

Python::with_gil(|py| {
let sys = py.import("sys").unwrap();
let old_hook = sys.getattr("unraisablehook").unwrap();
let capture = Py::new(py, UnraisableCapture { capture: None }).unwrap();

sys.setattr("unraisablehook", capture.getattr(py, "hook").unwrap())
.unwrap();
let capture = UnraisableCapture::install(py);

let instance = Py::new(py, ReleaseBufferError {}).unwrap();
let env = [("ob", instance.clone())].into_py_dict(py);
Expand All @@ -150,7 +132,7 @@ fn test_releasebuffer_unraisable_error() {
assert_eq!(err.to_string(), "ValueError: oh dear");
assert!(object.is(&instance));

sys.setattr("unraisablehook", old_hook).unwrap();
capture.borrow_mut(py).uninstall(py);
});
}

Expand Down
90 changes: 57 additions & 33 deletions tests/test_class_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,31 +228,40 @@ impl UnsendableChild {
}

fn test_unsendable<T: PyClass + 'static>() -> PyResult<()> {
let obj = std::thread::spawn(|| -> PyResult<_> {
let obj = Python::with_gil(|py| -> PyResult<_> {
let obj: Py<T> = PyType::new::<T>(py).call1((5,))?.extract()?;

// Accessing the value inside this thread should not panic
let caught_panic =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| -> PyResult<_> {
assert_eq!(obj.as_ref(py).getattr("value")?.extract::<usize>()?, 5);
Ok(())
}))
.is_err();

assert!(!caught_panic);
Ok(obj)
})?;

let keep_obj_here = obj.clone();

let caught_panic = std::thread::spawn(move || {
// This access must panic
Python::with_gil(|py| {
let obj: Py<T> = PyType::new::<T>(py).call1((5,))?.extract()?;

// Accessing the value inside this thread should not panic
let caught_panic =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| -> PyResult<_> {
assert_eq!(obj.as_ref(py).getattr("value")?.extract::<usize>()?, 5);
Ok(())
}))
.is_err();

assert!(!caught_panic);
Ok(obj)
})
obj.borrow(py);
});
})
.join()
.unwrap()?;
.join();

// This access must panic
Python::with_gil(|py| {
obj.borrow(py);
});
Python::with_gil(|_py| drop(keep_obj_here));

if let Err(err) = caught_panic {
if let Some(msg) = err.downcast_ref::<String>() {
panic!("{}", msg);
}
}

panic!("Borrowing unsendable from receiving thread did not panic.");
Ok(())
}

/// If a class is marked as `unsendable`, it panics when accessed by another thread.
Expand Down Expand Up @@ -529,6 +538,7 @@ fn access_frozen_class_without_gil() {

#[test]
fn drop_unsendable_elsewhere() {
use common::UnraisableCapture;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
Expand All @@ -546,21 +556,35 @@ fn drop_unsendable_elsewhere() {
}
}

let dropped = Arc::new(AtomicBool::new(false));
Python::with_gil(|py| {
let capture = UnraisableCapture::install(py);

let unsendable = Python::with_gil(|py| {
let dropped = dropped.clone();
let dropped = Arc::new(AtomicBool::new(false));

Py::new(py, Unsendable { dropped }).unwrap()
});
let unsendable = Py::new(
py,
Unsendable {
dropped: dropped.clone(),
},
)
.unwrap();

spawn(move || {
Python::with_gil(move |_py| {
drop(unsendable);
py.allow_threads(|| {
spawn(move || {
Python::with_gil(move |_py| {
drop(unsendable);
});
})
.join()
.unwrap();
});
})
.join()
.unwrap();

assert!(!dropped.load(Ordering::SeqCst));
assert!(!dropped.load(Ordering::SeqCst));

let (err, object) = capture.borrow_mut(py).capture.take().unwrap();
assert_eq!(err.to_string(), "RuntimeError: test_class_basics::drop_unsendable_elsewhere::Unsendable is unsendbale, but is dropped on another thread!");
assert!(object.is_none(py));

capture.borrow_mut(py).uninstall(py);
});
}
24 changes: 3 additions & 21 deletions tests/test_exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,29 +100,11 @@ fn test_exception_nosegfault() {
#[test]
#[cfg(Py_3_8)]
fn test_write_unraisable() {
use common::UnraisableCapture;
use pyo3::{exceptions::PyRuntimeError, ffi, AsPyPointer};

#[pyclass]
struct UnraisableCapture {
capture: Option<(PyErr, PyObject)>,
}

#[pymethods]
impl UnraisableCapture {
fn hook(&mut self, unraisable: &PyAny) {
let err = PyErr::from_value(unraisable.getattr("exc_value").unwrap());
let instance = unraisable.getattr("object").unwrap();
self.capture = Some((err, instance.into()));
}
}

Python::with_gil(|py| {
let sys = py.import("sys").unwrap();
let old_hook = sys.getattr("unraisablehook").unwrap();
let capture = Py::new(py, UnraisableCapture { capture: None }).unwrap();

sys.setattr("unraisablehook", capture.getattr(py, "hook").unwrap())
.unwrap();
let capture = UnraisableCapture::install(py);

assert!(capture.borrow(py).capture.is_none());

Expand All @@ -140,6 +122,6 @@ fn test_write_unraisable() {
assert_eq!(err.to_string(), "RuntimeError: bar");
assert!(object.as_ptr() == unsafe { ffi::Py_NotImplemented() });

sys.setattr("unraisablehook", old_hook).unwrap();
capture.borrow_mut(py).uninstall(py);
});
}

0 comments on commit 240e117

Please sign in to comment.