Skip to content

Commit

Permalink
Merge #3168
Browse files Browse the repository at this point in the history
3168: Do not apply deferred ref count updates and prevent the GIL from being acquired inside of __traverse__ implementations. r=davidhewitt a=adamreichold

Closes #2301
Closes #3165


Co-authored-by: Adam Reichold <adam.reichold@t-online.de>
  • Loading branch information
bors[bot] and adamreichold committed May 25, 2023
2 parents d71af73 + 89a6c43 commit efaf768
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 28 deletions.
4 changes: 4 additions & 0 deletions guide/src/class/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ impl ClassWithGCSupport {
}
```

Usually, an implementation of `__traverse__` should do nothing but calls to `visit.call`.
Most importantly, safe access to the GIL is prohibited inside implementations of `__traverse__`,
i.e. `Python::with_gil` will panic.

> Note: these methods are part of the C API, PyPy does not necessarily honor them. If you are building for PyPy you should measure memory consumption to make sure you do not have runaway memory growth. See [this issue on the PyPy bug tracker](https://foss.heptapod.net/pypy/pypy/-/issues/3899).
[`IterNextOutput`]: {{#PYO3_DOCS_URL}}/pyo3/pyclass/enum.IterNextOutput.html
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3168.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Safe access to the GIL, for example via `Python::with_gil`, is now locked inside of implementations of the `__traverse__` slot.
1 change: 1 addition & 0 deletions newsfragments/3168.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Do not apply deferred reference count updates when entering a `__traverse__` implementation is it cannot alter any reference counts while the garbage collector is running.
18 changes: 2 additions & 16 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,22 +406,8 @@ fn impl_traverse_slot(cls: &syn::Type, rust_fn_ident: &syn::Ident) -> MethodAndS
slf: *mut _pyo3::ffi::PyObject,
visit: _pyo3::ffi::visitproc,
arg: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int
{
let trap = _pyo3::impl_::panic::PanicTrap::new("uncaught panic inside __traverse__ handler");
let pool = _pyo3::GILPool::new();
let py = pool.python();
let slf = py.from_borrowed_ptr::<_pyo3::PyCell<#cls>>(slf);

let visit = _pyo3::class::gc::PyVisit::from_raw(visit, arg, py);
let borrow = slf.try_borrow();
let retval = if let ::std::result::Result::Ok(borrow) = borrow {
_pyo3::impl_::pymethods::unwrap_traverse_result(borrow.#rust_fn_ident(visit))
} else {
0
};
trap.disarm();
retval
) -> ::std::os::raw::c_int {
_pyo3::impl_::pymethods::call_traverse_impl::<#cls>(slf, #cls::#rust_fn_ident, visit, arg)
}
};
let slot_def = quote! {
Expand Down
51 changes: 48 additions & 3 deletions src/gil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,17 @@ thread_local_const_init! {
/// they are dropped.
///
/// As a result, if this thread has the GIL, GIL_COUNT is greater than zero.
static GIL_COUNT: Cell<usize> = const { Cell::new(0) };
///
/// Additionally, we sometimes need to prevent safe access to the GIL,
/// e.g. when implementing `__traverse__`, which is represented by a negative value.
static GIL_COUNT: Cell<isize> = const { Cell::new(0) };

/// Temporarily hold objects that will be released when the GILPool drops.
static OWNED_OBJECTS: RefCell<Vec<NonNull<ffi::PyObject>>> = const { RefCell::new(Vec::new()) };
}

const GIL_LOCKED_DURING_TRAVERSE: isize = -1;

/// Checks whether the GIL is acquired.
///
/// Note: This uses pyo3's internal count rather than PyGILState_Check for two reasons:
Expand Down Expand Up @@ -286,7 +291,7 @@ static POOL: ReferencePool = ReferencePool::new();

/// A guard which can be used to temporarily release the GIL and restore on `Drop`.
pub(crate) struct SuspendGIL {
count: usize,
count: isize,
tstate: *mut ffi::PyThreadState,
}

Expand All @@ -311,6 +316,40 @@ impl Drop for SuspendGIL {
}
}

/// Used to lock safe access to the GIL
pub(crate) struct LockGIL {
count: isize,
}

impl LockGIL {
/// Lock access to the GIL while an implementation of `__traverse__` is running
pub fn during_traverse() -> Self {
Self::new(GIL_LOCKED_DURING_TRAVERSE)
}

fn new(reason: isize) -> Self {
let count = GIL_COUNT.with(|c| c.replace(reason));

Self { count }
}

#[cold]
fn bail(current: isize) {
match current {
GIL_LOCKED_DURING_TRAVERSE => panic!(
"Access to the GIL is prohibited while a __traverse__ implmentation is running."
),
_ => panic!("Access to the GIL is currently prohibited."),
}
}
}

impl Drop for LockGIL {
fn drop(&mut self) {
GIL_COUNT.with(|c| c.set(self.count));
}
}

/// A RAII pool which PyO3 uses to store owned Python references.
///
/// See the [Memory Management] chapter of the guide for more information about how PyO3 uses
Expand Down Expand Up @@ -421,7 +460,13 @@ pub unsafe fn register_owned(_py: Python<'_>, obj: NonNull<ffi::PyObject>) {
#[inline(always)]
fn increment_gil_count() {
// Ignores the error in case this function called from `atexit`.
let _ = GIL_COUNT.try_with(|c| c.set(c.get() + 1));
let _ = GIL_COUNT.try_with(|c| {
let current = c.get();
if current < 0 {
LockGIL::bail(current);
}
c.set(current + 1);
});
}

/// Decrements pyo3's internal GIL count - to be called whenever GILPool or GILGuard is dropped.
Expand Down
56 changes: 47 additions & 9 deletions src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap;
use crate::internal_tricks::extract_c_string;
use crate::{ffi, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, PyTraverseError, Python};
use crate::{
ffi, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit,
Python,
};
use std::borrow::Cow;
use std::ffi::CStr;
use std::fmt;
use std::os::raw::c_int;
use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, AssertUnwindSafe};

/// Python 3.8 and up - __ipow__ has modulo argument correctly populated.
#[cfg(Py_3_8)]
Expand Down Expand Up @@ -239,14 +245,46 @@ impl PySetterDef {
}
}

/// Unwraps the result of __traverse__ for tp_traverse
/// Calls an implementation of __traverse__ for tp_traverse
#[doc(hidden)]
#[inline]
pub fn unwrap_traverse_result(result: Result<(), PyTraverseError>) -> c_int {
match result {
Ok(()) => 0,
Err(PyTraverseError(value)) => value,
}
pub unsafe fn call_traverse_impl<T>(
slf: *mut ffi::PyObject,
impl_: fn(&T, PyVisit<'_>) -> Result<(), PyTraverseError>,
visit: ffi::visitproc,
arg: *mut c_void,
) -> c_int
where
T: PyClass,
{
// It is important the implementation of `__traverse__` cannot safely access the GIL,
// c.f. https://github.com/PyO3/pyo3/issues/3165, and hence we do not expose our GIL
// token to the user code and lock safe methods for acquiring the GIL.
// (This includes enforcing the `&self` method receiver as e.g. `PyRef<Self>` could
// reconstruct a GIL token via `PyRef::py`.)
// Since we do not create a `GILPool` at all, it is important that our usage of the GIL
// token does not produce any owned objects thereby calling into `register_owned`.
let trap = PanicTrap::new("uncaught panic inside __traverse__ handler");

let py = Python::assume_gil_acquired();
let slf = py.from_borrowed_ptr::<PyCell<T>>(slf);
let borrow = slf.try_borrow();
let visit = PyVisit::from_raw(visit, arg, py);

let retval = if let Ok(borrow) = borrow {
let _lock = LockGIL::during_traverse();

match catch_unwind(AssertUnwindSafe(move || impl_(&*borrow, visit))) {
Ok(res) => match res {
Ok(()) => 0,
Err(PyTraverseError(value)) => value,
},
Err(_err) => -1,
}
} else {
0
};
trap.disarm();
retval
}

pub(crate) struct PyMethodDefDestructor {
Expand Down
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ fn _test_compile_errors() {
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
t.compile_fail("tests/ui/invalid_pymodule_args.rs");
t.compile_fail("tests/ui/reject_generics.rs");
t.compile_fail("tests/ui/traverse_bare_self.rs");

tests_rust_1_49(&t);
tests_rust_1_56(&t);
Expand Down
31 changes: 31 additions & 0 deletions tests/test_gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,34 @@ fn traverse_error() {
);
})
}

#[pyclass]
struct TriesGILInTraverse {}

#[pymethods]
impl TriesGILInTraverse {
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Python::with_gil(|_py| Ok(()))
}
}

#[test]
fn tries_gil_in_traverse() {
Python::with_gil(|py| unsafe {
// declare a visitor function which errors (returns nonzero code)
extern "C" fn novisit(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
0
}

// get the traverse function
let ty = py.get_type::<TriesGILInTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();

// confirm that traversing panicks
let obj = Py::new(py, TriesGILInTraverse {}).unwrap();
assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1);
})
}
12 changes: 12 additions & 0 deletions tests/ui/traverse_bare_self.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use pyo3::prelude::*;
use pyo3::PyVisit;

#[pyclass]
struct TraverseTriesToTakePyRef {}

#[pymethods]
impl TraverseTriesToTakePyRef {
fn __traverse__(slf: PyRef<Self>, visit: PyVisit) {}
}

fn main() {}
17 changes: 17 additions & 0 deletions tests/ui/traverse_bare_self.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
error[E0308]: mismatched types
--> tests/ui/traverse_bare_self.rs:8:6
|
7 | #[pymethods]
| ------------ arguments to this function are incorrect
8 | impl TraverseTriesToTakePyRef {
| ______^
9 | | fn __traverse__(slf: PyRef<Self>, visit: PyVisit) {}
| |___________________^ expected fn pointer, found fn item
|
= note: expected fn pointer `for<'a, 'b> fn(&'a TraverseTriesToTakePyRef, PyVisit<'b>) -> Result<(), PyTraverseError>`
found fn item `for<'a, 'b> fn(pyo3::PyRef<'a, TraverseTriesToTakePyRef>, PyVisit<'b>) {TraverseTriesToTakePyRef::__traverse__}`
note: function defined here
--> src/impl_/pymethods.rs
|
| pub unsafe fn call_traverse_impl<T>(
| ^^^^^^^^^^^^^^^^^^

0 comments on commit efaf768

Please sign in to comment.