Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dataclass_transform] support default parameters #14580

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
73 changes: 64 additions & 9 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,13 +480,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_import_all(self)


FUNCBASE_FLAGS: Final = [
"is_property",
"is_class",
"is_static",
"is_final",
"is_dataclass_transform",
]
FUNCBASE_FLAGS: Final = ["is_property", "is_class", "is_static", "is_final"]


class FuncBase(Node):
Expand All @@ -512,7 +506,6 @@ class FuncBase(Node):
"is_static", # Uses "@staticmethod"
"is_final", # Uses "@final"
"_fullname",
"is_dataclass_transform", # Is decorated with "@typing.dataclass_transform" or similar
)

def __init__(self) -> None:
Expand All @@ -531,7 +524,6 @@ def __init__(self) -> None:
self.is_final = False
# Name with module prefix
self._fullname = ""
self.is_dataclass_transform = False

@property
@abstractmethod
Expand Down Expand Up @@ -758,6 +750,8 @@ class FuncDef(FuncItem, SymbolNode, Statement):
"deco_line",
"is_trivial_body",
"is_mypy_only",
# Present only when a function is decorated with @typing.datasclass_transform or similar
"dataclass_transform_spec",
)

__match_args__ = ("name", "arguments", "type", "body")
Expand Down Expand Up @@ -785,6 +779,7 @@ def __init__(
self.deco_line: int | None = None
# Definitions that appear in if TYPE_CHECKING are marked with this flag.
self.is_mypy_only = False
self.dataclass_transform_spec: DataclassTransformSpec | None = None

@property
def name(self) -> str:
Expand All @@ -810,6 +805,11 @@ def serialize(self) -> JsonDict:
"flags": get_flags(self, FUNCDEF_FLAGS),
"abstract_status": self.abstract_status,
# TODO: Do we need expanded, original_def?
"dataclass_transform_spec": (
None
if self.dataclass_transform_spec is None
else self.dataclass_transform_spec.serialize()
),
}

@classmethod
Expand All @@ -832,6 +832,11 @@ def deserialize(cls, data: JsonDict) -> FuncDef:
ret.arg_names = data["arg_names"]
ret.arg_kinds = [ArgKind(x) for x in data["arg_kinds"]]
ret.abstract_status = data["abstract_status"]
ret.dataclass_transform_spec = (
DataclassTransformSpec.deserialize(data["dataclass_transform_spec"])
if data["dataclass_transform_spec"] is not None
else None
)
# Leave these uninitialized so that future uses will trigger an error
del ret.arguments
del ret.max_pos
Expand Down Expand Up @@ -3857,6 +3862,56 @@ def deserialize(cls, data: JsonDict) -> SymbolTable:
return st


class DataclassTransformSpec:
wesleywright marked this conversation as resolved.
Show resolved Hide resolved
"""Specifies how a dataclass-like transform should be applied. The fields here are based on the
parameters accepted by `typing.dataclass_transform`."""

__slots__ = (
"eq_default",
"order_default",
"kw_only_default",
"frozen_default",
"field_specifiers",
)

def __init__(
self,
*,
eq_default: bool | None = None,
order_default: bool | None = None,
kw_only_default: bool | None = None,
field_specifiers: tuple[str, ...] | None = None,
# Specified outside of PEP 681:
# frozen_default was added to CPythonin https://github.com/python/cpython/pull/99958 citing
# positive discussion in typing-sig
frozen_default: bool | None = None,
):
self.eq_default = eq_default if eq_default is not None else True
self.order_default = order_default if order_default is not None else False
self.kw_only_default = kw_only_default if kw_only_default is not None else False
self.frozen_default = frozen_default if frozen_default is not None else False
self.field_specifiers = field_specifiers if field_specifiers is not None else ()

def serialize(self) -> JsonDict:
return {
"eq_default": self.eq_default,
"order_default": self.order_default,
"kw_only_default": self.kw_only_default,
"frozen_only_default": self.frozen_default,
"field_specifiers": self.field_specifiers,
}

@classmethod
def deserialize(cls, data: JsonDict) -> DataclassTransformSpec:
return DataclassTransformSpec(
eq_default=data.get("eq_default"),
order_default=data.get("order_default"),
kw_only_default=data.get("kw_only_default"),
frozen_default=data.get("frozen_default"),
field_specifiers=data.get("field_specifiers"),
)


def get_flags(node: Node, names: list[str]) -> list[str]:
return [name for name in names if getattr(node, name)]

Expand Down
106 changes: 71 additions & 35 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
AssignmentStmt,
CallExpr,
Context,
DataclassTransformSpec,
Expression,
JsonDict,
NameExpr,
Node,
PlaceholderNode,
RefExpr,
SymbolTableNode,
Expand All @@ -37,6 +39,7 @@
add_method,
deserialize_and_fixup_type,
)
from mypy.semanal_shared import find_dataclass_transform_spec
from mypy.server.trigger import make_wildcard_trigger
from mypy.state import state
from mypy.typeops import map_type_from_supertype
Expand All @@ -56,11 +59,16 @@

# The set of decorators that generate dataclasses.
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}
# The set of functions that generate dataclass fields.
field_makers: Final = {"dataclasses.field"}


SELF_TVAR_NAME: Final = "_DT"
_TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec(
eq_default=True,
order_default=False,
kw_only_default=False,
frozen_default=False,
field_specifiers=("dataclasses.Field", "dataclasses.field"),
)


class DataclassAttribute:
Expand Down Expand Up @@ -155,6 +163,7 @@ class DataclassTransformer:

def __init__(self, ctx: ClassDefContext) -> None:
self._ctx = ctx
self._spec = _get_transform_spec(ctx.reason)

def transform(self) -> bool:
"""Apply all the necessary transformations to the underlying
Expand All @@ -172,9 +181,9 @@ def transform(self) -> bool:
return False
decorator_arguments = {
"init": _get_decorator_bool_argument(self._ctx, "init", True),
"eq": _get_decorator_bool_argument(self._ctx, "eq", True),
"order": _get_decorator_bool_argument(self._ctx, "order", False),
"frozen": _get_decorator_bool_argument(self._ctx, "frozen", False),
"eq": _get_decorator_bool_argument(self._ctx, "eq", self._spec.eq_default),
"order": _get_decorator_bool_argument(self._ctx, "order", self._spec.order_default),
"frozen": _get_decorator_bool_argument(self._ctx, "frozen", self._spec.frozen_default),
"slots": _get_decorator_bool_argument(self._ctx, "slots", False),
"match_args": _get_decorator_bool_argument(self._ctx, "match_args", True),
}
Expand Down Expand Up @@ -411,7 +420,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:

# Second, collect attributes belonging to the current class.
current_attr_names: set[str] = set()
kw_only = _get_decorator_bool_argument(ctx, "kw_only", False)
kw_only = _get_decorator_bool_argument(ctx, "kw_only", self._spec.kw_only_default)
for stmt in cls.defs.body:
# Any assignment that doesn't use the new type declaration
# syntax can be ignored out of hand.
Expand Down Expand Up @@ -461,7 +470,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
if self._is_kw_only_type(node_type):
kw_only = True

has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx)
has_field_call, field_args = self._collect_field_args(stmt.rvalue, ctx)

is_in_init_param = field_args.get("init")
if is_in_init_param is None:
Expand Down Expand Up @@ -614,6 +623,36 @@ def _add_dataclass_fields_magic_attribute(self) -> None:
kind=MDEF, node=var, plugin_generated=True
)

def _collect_field_args(
self, expr: Expression, ctx: ClassDefContext
) -> tuple[bool, dict[str, Expression]]:
"""Returns a tuple where the first value represents whether or not
the expression is a call to dataclass.field and the second is a
dictionary of the keyword arguments that field() was called with.
"""
if (
isinstance(expr, CallExpr)
and isinstance(expr.callee, RefExpr)
and expr.callee.fullname in self._spec.field_specifiers
):
# field() only takes keyword arguments.
args = {}
for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
if not kind.is_named():
if kind.is_named(star=True):
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
message = 'Unpacking **kwargs in "field()" is not supported'
else:
message = '"field()" does not accept positional arguments'
ctx.api.fail(message, expr)
return True, {}
assert name is not None
args[name] = arg
return True, args
return False, {}


def dataclass_tag_callback(ctx: ClassDefContext) -> None:
"""Record that we have a dataclass in the main semantic analysis pass.
Expand All @@ -631,32 +670,29 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool:
return transformer.transform()


def _collect_field_args(
expr: Expression, ctx: ClassDefContext
) -> tuple[bool, dict[str, Expression]]:
"""Returns a tuple where the first value represents whether or not
the expression is a call to dataclass.field and the second is a
dictionary of the keyword arguments that field() was called with.
def _get_transform_spec(reason: Expression) -> DataclassTransformSpec:
"""Find the relevant transform parameters from the decorator/parent class/metaclass that
triggered the dataclasses plugin.

Although the resulting DataclassTransformSpec is based on the typing.dataclass_transform
function, we also use it for traditional dataclasses.dataclass classes as well for simplicity.
In those cases, we return a default spec rather than one based on a call to
`typing.dataclass_transform`.
"""
if (
isinstance(expr, CallExpr)
and isinstance(expr.callee, RefExpr)
and expr.callee.fullname in field_makers
):
# field() only takes keyword arguments.
args = {}
for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
if not kind.is_named():
if kind.is_named(star=True):
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
message = 'Unpacking **kwargs in "field()" is not supported'
else:
message = '"field()" does not accept positional arguments'
ctx.api.fail(message, expr)
return True, {}
assert name is not None
args[name] = arg
return True, args
return False, {}
if _is_dataclasses_decorator(reason):
return _TRANSFORM_SPEC_FOR_DATACLASSES

spec = find_dataclass_transform_spec(reason)
assert spec is not None, (
"trying to find dataclass transform spec, but reason is neither dataclasses.dataclass nor "
"decorated with typing.dataclass_transform"
)
return spec


def _is_dataclasses_decorator(node: Node) -> bool:
if isinstance(node, CallExpr):
node = node.callee
if isinstance(node, RefExpr):
return node.fullname in dataclass_makers
return False
53 changes: 33 additions & 20 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
ConditionalExpr,
Context,
ContinueStmt,
DataclassTransformSpec,
Decorator,
DelStmt,
DictExpr,
Expand Down Expand Up @@ -213,6 +214,7 @@
PRIORITY_FALLBACKS,
SemanticAnalyzerInterface,
calculate_tuple_fallback,
find_dataclass_transform_spec,
has_placeholder,
set_callable_name as set_callable_name,
)
Expand Down Expand Up @@ -1524,7 +1526,7 @@ def visit_decorator(self, dec: Decorator) -> None:
elif isinstance(d, CallExpr) and refers_to_fullname(
d.callee, DATACLASS_TRANSFORM_NAMES
):
dec.func.is_dataclass_transform = True
dec.func.dataclass_transform_spec = self.parse_dataclass_transform_spec(d)
elif not dec.var.is_property:
# We have seen a "non-trivial" decorator before seeing @property, if
# we will see a @property later, give an error, as we don't support this.
Expand Down Expand Up @@ -1729,7 +1731,7 @@ def apply_class_plugin_hooks(self, defn: ClassDef) -> None:
# Special case: if the decorator is itself decorated with
# typing.dataclass_transform, apply the hook for the dataclasses plugin
# TODO: remove special casing here
if hook is None and is_dataclass_transform_decorator(decorator):
if hook is None and find_dataclass_transform_spec(decorator):
hook = dataclasses_plugin.dataclass_tag_callback
if hook:
hook(ClassDefContext(defn, decorator, self))
Expand Down Expand Up @@ -6462,6 +6464,35 @@ def set_future_import_flags(self, module_name: str) -> None:
def is_future_flag_set(self, flag: str) -> bool:
return self.modules[self.cur_mod_id].is_future_flag_set(flag)

def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSpec:
"""Build a DataclassTransformSpec from the arguments passed to the given call to
typing.dataclass_transform."""
parameters = DataclassTransformSpec()
for name, value in zip(call.arg_names, call.args):
# field_specifiers is currently the only non-boolean argument; check for it first so
# so the rest of the block can fail through to handling booleans
if name == "field_specifiers":
self.fail('"field_specifiers" support is currently unimplemented', call)
continue

boolean = self.parse_bool(value)
if boolean is None:
self.fail(f'"{name}" argument must be a True or False literal', call)
continue

if name == "eq_default":
parameters.eq_default = boolean
elif name == "order_default":
parameters.order_default = boolean
elif name == "kw_only_default":
parameters.kw_only_default = boolean
wesleywright marked this conversation as resolved.
Show resolved Hide resolved
elif name == "frozen_default":
parameters.frozen_default = boolean
else:
self.fail(f'Unrecognized dataclass_transform parameter "{name}"', call)

return parameters


def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
if isinstance(sig, CallableType):
Expand Down Expand Up @@ -6651,21 +6682,3 @@ def halt(self, reason: str = ...) -> NoReturn:
return isinstance(stmt, PassStmt) or (
isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr)
)


def is_dataclass_transform_decorator(node: Node | None) -> bool:
if isinstance(node, RefExpr):
return is_dataclass_transform_decorator(node.node)
if isinstance(node, CallExpr):
# Like dataclasses.dataclass, transform-based decorators can be applied either with or
# without parameters; ie, both of these forms are accepted:
#
# @typing.dataclass_transform
# class Foo: ...
# @typing.dataclass_transform(eq=True, order=True, ...)
# class Bar: ...
#
# We need to unwrap the call for the second variant.
return is_dataclass_transform_decorator(node.callee)

return isinstance(node, Decorator) and node.func.is_dataclass_transform