Skip to content

Commit

Permalink
Factor out UnraisableCapture helper type and use it to check that dro…
Browse files Browse the repository at this point in the history
…pping unsendable elsewhere calls into sys.unraisablehook
  • Loading branch information
adamreichold committed May 24, 2023
1 parent 6261f79 commit 074f9e1
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 75 deletions.
49 changes: 49 additions & 0 deletions tests/common.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
//! Some common macros for tests

#[cfg(feature = "macros")]
use pyo3::prelude::*;

#[macro_export]
macro_rules! py_assert {
($py:expr, $($val:ident)+, $assertion:literal) => {
Expand Down Expand Up @@ -41,3 +44,49 @@ macro_rules! py_expect_exception {
err
}};
}

#[cfg(feature = "macros")]
#[pyclass]
pub struct UnraisableCapture {
pub capture: Option<(PyErr, PyObject)>,
old_hook: Option<PyObject>,
}

#[cfg(feature = "macros")]
#[pymethods]
impl UnraisableCapture {
pub fn hook(&mut self, unraisable: &PyAny) {
let err = PyErr::from_value(unraisable.getattr("exc_value").unwrap());
let instance = unraisable.getattr("object").unwrap();
self.capture = Some((err, instance.into()));
}
}

#[cfg(feature = "macros")]
impl UnraisableCapture {
pub fn install(py: Python<'_>) -> Py<Self> {
let sys = py.import("sys").unwrap();
let old_hook = sys.getattr("unraisablehook").unwrap().into();

let capture = Py::new(
py,
UnraisableCapture {
capture: None,
old_hook: Some(old_hook),
},
)
.unwrap();

sys.setattr("unraisablehook", capture.getattr(py, "hook").unwrap())
.unwrap();

capture
}

pub fn uninstall(&mut self, py: Python<'_>) {
let old_hook = self.old_hook.take().unwrap();

let sys = py.import("sys").unwrap();
sys.setattr("unraisablehook", old_hook).unwrap();
}
}
24 changes: 3 additions & 21 deletions tests/test_buffer_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ fn test_buffer_referenced() {
#[test]
#[cfg(Py_3_8)] // sys.unraisablehook not available until Python 3.8
fn test_releasebuffer_unraisable_error() {
use common::UnraisableCapture;
use pyo3::exceptions::PyValueError;

#[pyclass]
Expand All @@ -117,27 +118,8 @@ fn test_releasebuffer_unraisable_error() {
}
}

#[pyclass]
struct UnraisableCapture {
capture: Option<(PyErr, PyObject)>,
}

#[pymethods]
impl UnraisableCapture {
fn hook(&mut self, unraisable: &PyAny) {
let err = PyErr::from_value(unraisable.getattr("exc_value").unwrap());
let instance = unraisable.getattr("object").unwrap();
self.capture = Some((err, instance.into()));
}
}

Python::with_gil(|py| {
let sys = py.import("sys").unwrap();
let old_hook = sys.getattr("unraisablehook").unwrap();
let capture = Py::new(py, UnraisableCapture { capture: None }).unwrap();

sys.setattr("unraisablehook", capture.getattr(py, "hook").unwrap())
.unwrap();
let capture = UnraisableCapture::install(py);

let instance = Py::new(py, ReleaseBufferError {}).unwrap();
let env = [("ob", instance.clone())].into_py_dict(py);
Expand All @@ -150,7 +132,7 @@ fn test_releasebuffer_unraisable_error() {
assert_eq!(err.to_string(), "ValueError: oh dear");
assert!(object.is(&instance));

sys.setattr("unraisablehook", old_hook).unwrap();
capture.borrow_mut(py).uninstall(py);
});
}

Expand Down
90 changes: 57 additions & 33 deletions tests/test_class_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,31 +228,40 @@ impl UnsendableChild {
}

fn test_unsendable<T: PyClass + 'static>() -> PyResult<()> {
let obj = std::thread::spawn(|| -> PyResult<_> {
let obj = Python::with_gil(|py| -> PyResult<_> {
let obj: Py<T> = PyType::new::<T>(py).call1((5,))?.extract()?;

// Accessing the value inside this thread should not panic
let caught_panic =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| -> PyResult<_> {
assert_eq!(obj.as_ref(py).getattr("value")?.extract::<usize>()?, 5);
Ok(())
}))
.is_err();

assert!(!caught_panic);
Ok(obj)
})?;

let keep_obj_here = obj.clone();

let caught_panic = std::thread::spawn(move || {
// This access must panic
Python::with_gil(|py| {
let obj: Py<T> = PyType::new::<T>(py).call1((5,))?.extract()?;

// Accessing the value inside this thread should not panic
let caught_panic =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| -> PyResult<_> {
assert_eq!(obj.as_ref(py).getattr("value")?.extract::<usize>()?, 5);
Ok(())
}))
.is_err();

assert!(!caught_panic);
Ok(obj)
})
obj.borrow(py);
});
})
.join()
.unwrap()?;
.join();

// This access must panic
Python::with_gil(|py| {
obj.borrow(py);
});
Python::with_gil(|_py| drop(keep_obj_here));

if let Err(err) = caught_panic {
if let Some(msg) = err.downcast_ref::<String>() {
panic!("{}", msg);
}
}

panic!("Borrowing unsendable from receiving thread did not panic.");
Ok(())
}

/// If a class is marked as `unsendable`, it panics when accessed by another thread.
Expand Down Expand Up @@ -529,6 +538,7 @@ fn access_frozen_class_without_gil() {

#[test]
fn drop_unsendable_elsewhere() {
use common::UnraisableCapture;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
Expand All @@ -546,21 +556,35 @@ fn drop_unsendable_elsewhere() {
}
}

let dropped = Arc::new(AtomicBool::new(false));
Python::with_gil(|py| {
let capture = UnraisableCapture::install(py);

let unsendable = Python::with_gil(|py| {
let dropped = dropped.clone();
let dropped = Arc::new(AtomicBool::new(false));

Py::new(py, Unsendable { dropped }).unwrap()
});
let unsendable = Py::new(
py,
Unsendable {
dropped: dropped.clone(),
},
)
.unwrap();

spawn(move || {
Python::with_gil(move |_py| {
drop(unsendable);
py.allow_threads(|| {
spawn(move || {
Python::with_gil(move |_py| {
drop(unsendable);
});
})
.join()
.unwrap();
});
})
.join()
.unwrap();

assert!(!dropped.load(Ordering::SeqCst));
assert!(!dropped.load(Ordering::SeqCst));

let (err, object) = capture.borrow_mut(py).capture.take().unwrap();
assert_eq!(err.to_string(), "RuntimeError: test_class_basics::drop_unsendable_elsewhere::Unsendable is unsendbale, but is dropped on another thread!");
assert!(object.is_none(py));

capture.borrow_mut(py).uninstall(py);
});
}
24 changes: 3 additions & 21 deletions tests/test_exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,29 +100,11 @@ fn test_exception_nosegfault() {
#[test]
#[cfg(Py_3_8)]
fn test_write_unraisable() {
use common::UnraisableCapture;
use pyo3::{exceptions::PyRuntimeError, ffi, AsPyPointer};

#[pyclass]
struct UnraisableCapture {
capture: Option<(PyErr, PyObject)>,
}

#[pymethods]
impl UnraisableCapture {
fn hook(&mut self, unraisable: &PyAny) {
let err = PyErr::from_value(unraisable.getattr("exc_value").unwrap());
let instance = unraisable.getattr("object").unwrap();
self.capture = Some((err, instance.into()));
}
}

Python::with_gil(|py| {
let sys = py.import("sys").unwrap();
let old_hook = sys.getattr("unraisablehook").unwrap();
let capture = Py::new(py, UnraisableCapture { capture: None }).unwrap();

sys.setattr("unraisablehook", capture.getattr(py, "hook").unwrap())
.unwrap();
let capture = UnraisableCapture::install(py);

assert!(capture.borrow(py).capture.is_none());

Expand All @@ -140,6 +122,6 @@ fn test_write_unraisable() {
assert_eq!(err.to_string(), "RuntimeError: bar");
assert!(object.as_ptr() == unsafe { ffi::Py_NotImplemented() });

sys.setattr("unraisablehook", old_hook).unwrap();
capture.borrow_mut(py).uninstall(py);
});
}

0 comments on commit 074f9e1

Please sign in to comment.