Skip to content

Commit

Permalink
Add support for combining the #[new] and #[classmethod] method ty…
Browse files Browse the repository at this point in the history
…pes.
  • Loading branch information
stuhood committed May 16, 2023
1 parent edb9522 commit 329ef94
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 10 deletions.
20 changes: 20 additions & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,26 @@ created from Rust, but not from Python.

For arguments, see the [`Method arguments`](#method-arguments) section below.

### Constructors which accept a class argument

To create a constructor which takes a positional class argument, you can additionally mark your constructor with `#[classmethod]`:
```rust
# use pyo3::prelude::*;
# use pyo3::types::PyType;
# #[pyclass]
# struct BaseClass(PyObject);
#
#[pymethods]
impl BaseClass {
#[new]
fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult<Self> {
// Get an abstract attribute (presumably) declared on a subclass of this class.
let subclass_attr = cls.getattr("a_class_attr")?;
Ok(Self(subclass_attr.to_object(py)))
}
}
```

## Adding the class to a module

The next step is to create the module initializer and add our class to it:
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3157.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow combining `#[new]` and `#[classmethod]` to create a constructor which receives a (subtype's) class/`PyType` as its first argument.
38 changes: 30 additions & 8 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
pub enum MethodTypeAttribute {
/// `#[new]`
New,
/// `#[new]` && `#[classmethod]`
NewClassMethod,
/// `#[classmethod]`
ClassMethod,
/// `#[classattr]`
Expand All @@ -102,6 +104,7 @@ pub enum FnType {
Setter(SelfType),
Fn(SelfType),
FnNew,
FnNewClass,
FnClass,
FnStatic,
FnModule,
Expand All @@ -122,7 +125,7 @@ impl FnType {
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => {
quote!()
}
FnType::FnClass => {
FnType::FnClass | FnType::FnNewClass => {
quote! {
let _slf = _pyo3::types::PyType::from_type_ptr(_py, _slf as *mut _pyo3::ffi::PyTypeObject);
}
Expand Down Expand Up @@ -368,12 +371,16 @@ impl<'a> FnSpec<'a> {
let (fn_type, skip_first_arg, fixed_convention) = match fn_type_attr {
Some(MethodTypeAttribute::StaticMethod) => (FnType::FnStatic, false, None),
Some(MethodTypeAttribute::ClassAttribute) => (FnType::ClassAttribute, false, None),
Some(MethodTypeAttribute::New) => {
Some(MethodTypeAttribute::New | MethodTypeAttribute::NewClassMethod) => {
if let Some(name) = &python_name {
bail_spanned!(name.span() => "`name` not allowed with `#[new]`");
}
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
(FnType::FnNew, false, Some(CallingConvention::TpNew))
if matches!(fn_type_attr, Some(MethodTypeAttribute::New)) {
(FnType::FnNew, false, Some(CallingConvention::TpNew))
} else {
(FnType::FnNewClass, true, Some(CallingConvention::TpNew))
}
}
Some(MethodTypeAttribute::ClassMethod) => (FnType::FnClass, true, None),
Some(MethodTypeAttribute::Getter) => {
Expand Down Expand Up @@ -496,7 +503,11 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::TpNew => {
let (arg_convert, args) = impl_arg_params(self, cls, &py, false)?;
let call = quote! { #rust_name(#(#args),*) };
let call = match &self.tp {
FnType::FnNew => quote! { #rust_name(#(#args),*) },
FnType::FnNewClass => quote! { #rust_name(PyType::from_type_ptr(#py, subtype), #(#args),*) },
x => panic!("Only `FnNew` or `FnNewClass` may use the `TpNew` calling convention. Got: {:?}", x),
};
quote! {
unsafe fn #ident(
#py: _pyo3::Python<'_>,
Expand Down Expand Up @@ -609,7 +620,7 @@ impl<'a> FnSpec<'a> {
FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => return None,
FnType::Fn(_) => Some("self"),
FnType::FnModule => Some("module"),
FnType::FnClass => Some("cls"),
FnType::FnClass | FnType::FnNewClass => Some("cls"),
FnType::FnStatic | FnType::FnNew => None,
};

Expand Down Expand Up @@ -637,11 +648,22 @@ fn parse_method_attributes(
let mut deprecated_args = None;
let mut ty: Option<MethodTypeAttribute> = None;

macro_rules! set_compound_ty {
($new_ty:expr, $ident:expr) => {
ty = match (ty, $new_ty) {
(None, new_ty) => Some(new_ty),
(Some(MethodTypeAttribute::ClassMethod), MethodTypeAttribute::New) => Some(MethodTypeAttribute::NewClassMethod),
(Some(MethodTypeAttribute::New), MethodTypeAttribute::ClassMethod) => Some(MethodTypeAttribute::NewClassMethod),
(Some(_), _) => bail_spanned!($ident.span() => "can only combine `new` and `classmethod`"),
};
};
}

macro_rules! set_ty {
($new_ty:expr, $ident:expr) => {
ensure_spanned!(
ty.replace($new_ty).is_none(),
$ident.span() => "cannot specify a second method type"
$ident.span() => "cannot combine these method types"
);
};
}
Expand All @@ -650,13 +672,13 @@ fn parse_method_attributes(
match attr.parse_meta() {
Ok(syn::Meta::Path(name)) => {
if name.is_ident("new") || name.is_ident("__new__") {
set_ty!(MethodTypeAttribute::New, name);
set_compound_ty!(MethodTypeAttribute::New, name);
} else if name.is_ident("init") || name.is_ident("__init__") {
bail_spanned!(name.span() => "#[init] is disabled since PyO3 0.9.0");
} else if name.is_ident("call") || name.is_ident("__call__") {
bail_spanned!(name.span() => "use `fn __call__` instead of `#[call]` attribute since PyO3 0.15.0");
} else if name.is_ident("classmethod") {
set_ty!(MethodTypeAttribute::ClassMethod, name);
set_compound_ty!(MethodTypeAttribute::ClassMethod, name);
} else if name.is_ident("staticmethod") {
set_ty!(MethodTypeAttribute::StaticMethod, name);
} else if name.is_ident("classattr") {
Expand Down
4 changes: 3 additions & 1 deletion pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ pub fn gen_py_method(
Some(quote!(_pyo3::ffi::METH_STATIC)),
)?),
// special prototypes
(_, FnType::FnNew) => GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?),
(_, FnType::FnNew | FnType::FnNewClass) => {
GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?)
}

(_, FnType::Getter(self_type)) => GeneratedPyMethod::Method(impl_py_getter_def(
cls,
Expand Down
28 changes: 27 additions & 1 deletion tests/test_class_new.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use pyo3::types::{IntoPyDict, PyType};

#[pyclass]
struct EmptyClassWithNew {}
Expand Down Expand Up @@ -204,3 +204,29 @@ fn new_with_custom_error() {
assert_eq!(err.to_string(), "ValueError: custom error");
});
}

#[pyclass]
#[derive(Clone, Debug)]
struct NewWithClassMethod;

#[pymethods]
impl NewWithClassMethod {
#[new]
#[classmethod]
fn new(cls: &PyType) -> PyResult<Self> {
assert!(cls.is_subclass_of::<NewWithClassMethod>()?);
Ok(Self)
}
}

#[test]
fn new_with_class_method() {
Python::with_gil(|py| {
let typeobj = py.get_type::<NewWithClassMethod>();
typeobj
.call0()
.unwrap()
.extract::<NewWithClassMethod>()
.unwrap();
});
}

0 comments on commit 329ef94

Please sign in to comment.