Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not apply deferred ref count updates and prevent the GIL from being acquired inside of __traverse__ implementations. #3168

Merged
merged 4 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions guide/src/class/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ impl ClassWithGCSupport {
}
```

Usually, an implementation of `__traverse__` should do nothing but calls to `visit.call`.
Most importantly, safe access to the GIL is prohibited inside implementations of `__traverse__`,
i.e. `Python::with_gil` will panic.

> Note: these methods are part of the C API, PyPy does not necessarily honor them. If you are building for PyPy you should measure memory consumption to make sure you do not have runaway memory growth. See [this issue on the PyPy bug tracker](https://foss.heptapod.net/pypy/pypy/-/issues/3899).

[`IterNextOutput`]: {{#PYO3_DOCS_URL}}/pyo3/pyclass/enum.IterNextOutput.html
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3168.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Safe access to the GIL, for example via `Python::with_gil`, is now locked inside of implementations of the `__traverse__` slot.
1 change: 1 addition & 0 deletions newsfragments/3168.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Do not apply deferred reference count updates when entering a `__traverse__` implementation is it cannot alter any reference counts while the garbage collector is running.
18 changes: 2 additions & 16 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,22 +406,8 @@ fn impl_traverse_slot(cls: &syn::Type, rust_fn_ident: &syn::Ident) -> MethodAndS
slf: *mut _pyo3::ffi::PyObject,
visit: _pyo3::ffi::visitproc,
arg: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int
{
let trap = _pyo3::impl_::panic::PanicTrap::new("uncaught panic inside __traverse__ handler");
let pool = _pyo3::GILPool::new();
let py = pool.python();
let slf = py.from_borrowed_ptr::<_pyo3::PyCell<#cls>>(slf);

let visit = _pyo3::class::gc::PyVisit::from_raw(visit, arg, py);
let borrow = slf.try_borrow();
let retval = if let ::std::result::Result::Ok(borrow) = borrow {
_pyo3::impl_::pymethods::unwrap_traverse_result(borrow.#rust_fn_ident(visit))
} else {
0
};
trap.disarm();
retval
) -> ::std::os::raw::c_int {
_pyo3::impl_::pymethods::call_traverse_impl::<#cls>(slf, #cls::#rust_fn_ident, visit, arg)
adamreichold marked this conversation as resolved.
Show resolved Hide resolved
}
};
let slot_def = quote! {
Expand Down
51 changes: 48 additions & 3 deletions src/gil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,17 @@ thread_local_const_init! {
/// they are dropped.
///
/// As a result, if this thread has the GIL, GIL_COUNT is greater than zero.
static GIL_COUNT: Cell<usize> = const { Cell::new(0) };
///
/// Additionally, we sometimes need to prevent safe access to the GIL,
/// e.g. when implementing `__traverse__`, which is represented by a negative value.
static GIL_COUNT: Cell<isize> = const { Cell::new(0) };

/// Temporarily hold objects that will be released when the GILPool drops.
static OWNED_OBJECTS: RefCell<Vec<NonNull<ffi::PyObject>>> = const { RefCell::new(Vec::new()) };
}

const GIL_LOCKED_DURING_TRAVERSE: isize = -1;

/// Checks whether the GIL is acquired.
///
/// Note: This uses pyo3's internal count rather than PyGILState_Check for two reasons:
Expand Down Expand Up @@ -286,7 +291,7 @@ static POOL: ReferencePool = ReferencePool::new();

/// A guard which can be used to temporarily release the GIL and restore on `Drop`.
pub(crate) struct SuspendGIL {
count: usize,
count: isize,
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
tstate: *mut ffi::PyThreadState,
}

Expand All @@ -311,6 +316,40 @@ impl Drop for SuspendGIL {
}
}

/// Used to lock safe access to the GIL
pub(crate) struct LockGIL {
count: isize,
}

impl LockGIL {
/// Lock access to the GIL while an implementation of `__traverse__` is running
pub fn during_traverse() -> Self {
Self::new(GIL_LOCKED_DURING_TRAVERSE)
}

fn new(reason: isize) -> Self {
let count = GIL_COUNT.with(|c| c.replace(reason));

Self { count }
}

#[cold]
fn bail(current: isize) {
match current {
GIL_LOCKED_DURING_TRAVERSE => panic!(
"Access to the GIL is prohibited while a __traverse__ implmentation is running."
),
_ => panic!("Access to the GIL is currently prohibited."),
}
}
}

impl Drop for LockGIL {
fn drop(&mut self) {
GIL_COUNT.with(|c| c.set(self.count));
}
}

/// A RAII pool which PyO3 uses to store owned Python references.
///
/// See the [Memory Management] chapter of the guide for more information about how PyO3 uses
Expand Down Expand Up @@ -421,7 +460,13 @@ pub unsafe fn register_owned(_py: Python<'_>, obj: NonNull<ffi::PyObject>) {
#[inline(always)]
fn increment_gil_count() {
// Ignores the error in case this function called from `atexit`.
let _ = GIL_COUNT.try_with(|c| c.set(c.get() + 1));
let _ = GIL_COUNT.try_with(|c| {
let current = c.get();
if current < 0 {
LockGIL::bail(current);
}
c.set(current + 1);
});
}

/// Decrements pyo3's internal GIL count - to be called whenever GILPool or GILGuard is dropped.
Expand Down
56 changes: 47 additions & 9 deletions src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap;
use crate::internal_tricks::extract_c_string;
use crate::{ffi, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, PyTraverseError, Python};
use crate::{
ffi, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit,
Python,
};
use std::borrow::Cow;
use std::ffi::CStr;
use std::fmt;
use std::os::raw::c_int;
use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, AssertUnwindSafe};

/// Python 3.8 and up - __ipow__ has modulo argument correctly populated.
#[cfg(Py_3_8)]
Expand Down Expand Up @@ -239,14 +245,46 @@ impl PySetterDef {
}
}

/// Unwraps the result of __traverse__ for tp_traverse
/// Calls an implementation of __traverse__ for tp_traverse
#[doc(hidden)]
#[inline]
pub fn unwrap_traverse_result(result: Result<(), PyTraverseError>) -> c_int {
match result {
Ok(()) => 0,
Err(PyTraverseError(value)) => value,
}
pub unsafe fn call_traverse_impl<T>(
slf: *mut ffi::PyObject,
impl_: fn(&T, PyVisit<'_>) -> Result<(), PyTraverseError>,
visit: ffi::visitproc,
arg: *mut c_void,
) -> c_int
where
T: PyClass,
{
// It is important the implementation of `__traverse__` cannot safely access the GIL,
// c.f. https://github.com/PyO3/pyo3/issues/3165, and hence we do not expose our GIL
// token to the user code and lock safe methods for acquiring the GIL.
// (This includes enforcing the `&self` method receiver as e.g. `PyRef<Self>` could
// reconstruct a GIL token via `PyRef::py`.)
// Since we do not create a `GILPool` at all, it is important that our usage of the GIL
// token does not produce any owned objects thereby calling into `register_owned`.
let trap = PanicTrap::new("uncaught panic inside __traverse__ handler");

let py = Python::assume_gil_acquired();
let slf = py.from_borrowed_ptr::<PyCell<T>>(slf);
let borrow = slf.try_borrow();
let visit = PyVisit::from_raw(visit, arg, py);

let retval = if let Ok(borrow) = borrow {
let _lock = LockGIL::during_traverse();

match catch_unwind(AssertUnwindSafe(move || impl_(&*borrow, visit))) {
Ok(res) => match res {
Ok(()) => 0,
Err(PyTraverseError(value)) => value,
},
Err(_err) => -1,
}
} else {
0
};
trap.disarm();
retval
}

pub(crate) struct PyMethodDefDestructor {
Expand Down
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ fn _test_compile_errors() {
t.compile_fail("tests/ui/not_send2.rs");
t.compile_fail("tests/ui/not_send3.rs");
t.compile_fail("tests/ui/get_set_all.rs");
t.compile_fail("tests/ui/traverse_bare_self.rs");
}

#[rustversion::before(1.63)]
Expand Down
31 changes: 31 additions & 0 deletions tests/test_gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,34 @@ fn traverse_error() {
);
})
}

#[pyclass]
struct TriesGILInTraverse {}

#[pymethods]
impl TriesGILInTraverse {
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Python::with_gil(|_py| Ok(()))
}
}

#[test]
fn tries_gil_in_traverse() {
Python::with_gil(|py| unsafe {
// declare a visitor function which errors (returns nonzero code)
extern "C" fn novisit(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
0
}

// get the traverse function
let ty = py.get_type::<TriesGILInTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();

// confirm that traversing panicks
let obj = Py::new(py, TriesGILInTraverse {}).unwrap();
assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1);
})
}
12 changes: 12 additions & 0 deletions tests/ui/traverse_bare_self.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use pyo3::prelude::*;
use pyo3::PyVisit;

#[pyclass]
struct TraverseTriesToTakePyRef {}

#[pymethods]
impl TraverseTriesToTakePyRef {
fn __traverse__(slf: PyRef<Self>, visit: PyVisit) {}
}

fn main() {}
17 changes: 17 additions & 0 deletions tests/ui/traverse_bare_self.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
error[E0308]: mismatched types
--> tests/ui/traverse_bare_self.rs:8:6
|
7 | #[pymethods]
| ------------ arguments to this function are incorrect
8 | impl TraverseTriesToTakePyRef {
| ______^
9 | | fn __traverse__(slf: PyRef<Self>, visit: PyVisit) {}
| |___________________^ expected fn pointer, found fn item
|
= note: expected fn pointer `for<'a, 'b> fn(&'a TraverseTriesToTakePyRef, PyVisit<'b>) -> Result<(), PyTraverseError>`
found fn item `for<'a, 'b> fn(pyo3::PyRef<'a, TraverseTriesToTakePyRef>, PyVisit<'b>) {TraverseTriesToTakePyRef::__traverse__}`
note: function defined here
--> src/impl_/pymethods.rs
|
| pub unsafe fn call_traverse_impl<T>(
| ^^^^^^^^^^^^^^^^^^
adamreichold marked this conversation as resolved.
Show resolved Hide resolved