Skip to content

Commit

Permalink
Merge #3184
Browse files Browse the repository at this point in the history
3184: Additional tests for #3168 r=adamreichold a=adamreichold

These were a part of tests `@lifthrasiir` was preparing for #3165, and I believe it's worthy to add them (any single of them fails in the current main branch).

Co-authored-by: Kang Seonghoon <public+git@mearie.org>
  • Loading branch information
bors[bot] and lifthrasiir committed May 25, 2023
2 parents 32c335e + e884327 commit 2ed1d70
Showing 1 changed file with 199 additions and 34 deletions.
233 changes: 199 additions & 34 deletions tests/test_gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::class::PyTraverseError;
use pyo3::class::PyVisit;
use pyo3::prelude::*;
use pyo3::{py_run, AsPyPointer, PyCell, PyTryInto};
use std::cell::Cell;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

Expand Down Expand Up @@ -248,22 +249,10 @@ impl TraversableClass {
}
}

unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse))
}

#[test]
fn gc_during_borrow() {
Python::with_gil(|py| {
unsafe {
// declare a dummy visitor function
extern "C" fn novisit(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
0
}

// get the traverse function
let ty = py.get_type::<TraversableClass>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();
Expand All @@ -290,18 +279,18 @@ fn gc_during_borrow() {
}

#[pyclass]
struct PanickyTraverse {
struct PartialTraverse {
member: PyObject,
}

impl PanickyTraverse {
impl PartialTraverse {
fn new(py: Python<'_>) -> Self {
Self { member: py.None() }
}
}

#[pymethods]
impl PanickyTraverse {
impl PartialTraverse {
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
visit.call(&self.member)?;
// In the test, we expect this to never be hit
Expand All @@ -310,29 +299,53 @@ impl PanickyTraverse {
}

#[test]
fn traverse_error() {
fn traverse_partial() {
Python::with_gil(|py| unsafe {
// declare a visitor function which errors (returns nonzero code)
extern "C" fn visit_error(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
-1
}

// get the traverse function
let ty = py.get_type::<PanickyTraverse>().as_type_ptr();
let ty = py.get_type::<PartialTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();

// confirm that traversing errors
let obj = Py::new(py, PanickyTraverse::new(py)).unwrap();
let obj = Py::new(py, PartialTraverse::new(py)).unwrap();
assert_eq!(
traverse(obj.as_ptr(), visit_error, std::ptr::null_mut()),
-1
);
})
}

#[pyclass]
struct PanickyTraverse {
member: PyObject,
}

impl PanickyTraverse {
fn new(py: Python<'_>) -> Self {
Self { member: py.None() }
}
}

#[pymethods]
impl PanickyTraverse {
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
visit.call(&self.member)?;
panic!("at the disco");
}
}

#[test]
fn traverse_panic() {
Python::with_gil(|py| unsafe {
// get the traverse function
let ty = py.get_type::<PanickyTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();

// confirm that traversing errors
let obj = Py::new(py, PanickyTraverse::new(py)).unwrap();
assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1);
})
}

#[pyclass]
struct TriesGILInTraverse {}

Expand All @@ -346,14 +359,6 @@ impl TriesGILInTraverse {
#[test]
fn tries_gil_in_traverse() {
Python::with_gil(|py| unsafe {
// declare a visitor function which errors (returns nonzero code)
extern "C" fn novisit(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
0
}

// get the traverse function
let ty = py.get_type::<TriesGILInTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();
Expand All @@ -363,3 +368,163 @@ fn tries_gil_in_traverse() {
assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1);
})
}

#[pyclass]
struct HijackedTraverse {
traversed: Cell<bool>,
hijacked: Cell<bool>,
}

impl HijackedTraverse {
fn new() -> Self {
Self {
traversed: Cell::new(false),
hijacked: Cell::new(false),
}
}

fn traversed_and_hijacked(&self) -> (bool, bool) {
(self.traversed.get(), self.hijacked.get())
}
}

#[pymethods]
impl HijackedTraverse {
#[allow(clippy::unnecessary_wraps)]
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.traversed.set(true);
Ok(())
}
}

trait Traversable {
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError>;
}

impl<'a> Traversable for PyRef<'a, HijackedTraverse> {
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.hijacked.set(true);
Ok(())
}
}

#[test]
fn traverse_cannot_be_hijacked() {
Python::with_gil(|py| unsafe {
// get the traverse function
let ty = py.get_type::<HijackedTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();

let cell = PyCell::new(py, HijackedTraverse::new()).unwrap();
let obj = cell.to_object(py);
assert_eq!(cell.borrow().traversed_and_hijacked(), (false, false));
traverse(obj.as_ptr(), novisit, std::ptr::null_mut());
assert_eq!(cell.borrow().traversed_and_hijacked(), (true, false));
})
}

#[allow(dead_code)]
#[pyclass]
struct DropDuringTraversal {
cycle: Cell<Option<Py<Self>>>,
dropped: TestDropCall,
}

#[pymethods]
impl DropDuringTraversal {
#[allow(clippy::unnecessary_wraps)]
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.cycle.take();
Ok(())
}

fn __clear__(&mut self) {
self.cycle.take();
}
}

#[test]
fn drop_during_traversal_with_gil() {
let drop_called = Arc::new(AtomicBool::new(false));

Python::with_gil(|py| {
let inst = Py::new(
py,
DropDuringTraversal {
cycle: Cell::new(None),
dropped: TestDropCall {
drop_called: Arc::clone(&drop_called),
},
},
)
.unwrap();

inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py)));

drop(inst);
});

// due to the internal GC mechanism, we may need multiple
// (but not too many) collections to get `inst` actually dropped.
for _ in 0..10 {
Python::with_gil(|py| {
py.run("import gc; gc.collect()", None, None).unwrap();
});
}
assert!(drop_called.load(Ordering::Relaxed));
}

#[test]
fn drop_during_traversal_without_gil() {
let drop_called = Arc::new(AtomicBool::new(false));

let inst = Python::with_gil(|py| {
let inst = Py::new(
py,
DropDuringTraversal {
cycle: Cell::new(None),
dropped: TestDropCall {
drop_called: Arc::clone(&drop_called),
},
},
)
.unwrap();

inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py)));

inst
});

drop(inst);

// due to the internal GC mechanism, we may need multiple
// (but not too many) collections to get `inst` actually dropped.
for _ in 0..10 {
Python::with_gil(|py| {
py.run("import gc; gc.collect()", None, None).unwrap();
});
}
assert!(drop_called.load(Ordering::Relaxed));
}

// Manual traversal utilities

unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse))
}

// a dummy visitor function
extern "C" fn novisit(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
0
}

// a visitor function which errors (returns nonzero code)
extern "C" fn visit_error(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
-1
}

0 comments on commit 2ed1d70

Please sign in to comment.