Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dataclass_transform] support default parameters #14580

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
71 changes: 62 additions & 9 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,13 +480,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_import_all(self)


FUNCBASE_FLAGS: Final = [
"is_property",
"is_class",
"is_static",
"is_final",
"is_dataclass_transform",
]
FUNCBASE_FLAGS: Final = ["is_property", "is_class", "is_static", "is_final"]


class FuncBase(Node):
Expand All @@ -512,7 +506,9 @@ class FuncBase(Node):
"is_static", # Uses "@staticmethod"
"is_final", # Uses "@final"
"_fullname",
"is_dataclass_transform", # Is decorated with "@typing.dataclass_transform" or similar
# Present when a function is decorated with "@typing.dataclass_transform" or similar, and
# records the parameters passed to typing.dataclass_transform for later use
"dataclass_transform_spec",
wesleywright marked this conversation as resolved.
Show resolved Hide resolved
)

def __init__(self) -> None:
Expand All @@ -531,7 +527,7 @@ def __init__(self) -> None:
self.is_final = False
# Name with module prefix
self._fullname = ""
self.is_dataclass_transform = False
self.dataclass_transform_spec: DataclassTransformSpec | None = None

@property
@abstractmethod
Expand Down Expand Up @@ -592,6 +588,11 @@ def serialize(self) -> JsonDict:
"fullname": self._fullname,
"impl": None if self.impl is None else self.impl.serialize(),
"flags": get_flags(self, FUNCBASE_FLAGS),
"dataclass_transform_spec": (
None
if self.dataclass_transform_spec is None
else self.dataclass_transform_spec.serialize()
),
}

@classmethod
Expand All @@ -610,6 +611,11 @@ def deserialize(cls, data: JsonDict) -> OverloadedFuncDef:
assert isinstance(typ, mypy.types.ProperType)
res.type = typ
res._fullname = data["fullname"]
res.dataclass_transform_spec = (
DataclassTransformSpec.deserialize(data["dataclass_transform_spec"])
if data["dataclass_transform_spec"] is not None
else None
)
set_flags(res, data["flags"])
# NOTE: res.info will be set in the fixup phase.
return res
Expand Down Expand Up @@ -810,6 +816,11 @@ def serialize(self) -> JsonDict:
"flags": get_flags(self, FUNCDEF_FLAGS),
"abstract_status": self.abstract_status,
# TODO: Do we need expanded, original_def?
"dataclass_transform_spec": (
None
if self.dataclass_transform_spec is None
else self.dataclass_transform_spec.serialize()
),
}

@classmethod
Expand All @@ -832,6 +843,11 @@ def deserialize(cls, data: JsonDict) -> FuncDef:
ret.arg_names = data["arg_names"]
ret.arg_kinds = [ArgKind(x) for x in data["arg_kinds"]]
ret.abstract_status = data["abstract_status"]
ret.dataclass_transform_spec = (
DataclassTransformSpec.deserialize(data["dataclass_transform_spec"])
if data["dataclass_transform_spec"] is not None
else None
)
# Leave these uninitialized so that future uses will trigger an error
del ret.arguments
del ret.max_pos
Expand Down Expand Up @@ -3851,6 +3867,43 @@ def deserialize(cls, data: JsonDict) -> SymbolTable:
return st


class DataclassTransformSpec:
wesleywright marked this conversation as resolved.
Show resolved Hide resolved
"""Specifies how a dataclass-like transform should be applied. The fields here are based on the
parameters accepted by `typing.dataclass_transform`."""

__slots__ = ("eq_default", "order_default", "kw_only_default", "field_specifiers")

def __init__(
self,
*,
eq_default: bool | None = None,
order_default: bool | None = None,
kw_only_default: bool | None = None,
field_specifiers: tuple[str, ...] | None = None,
):
self.eq_default = eq_default if eq_default is not None else True
self.order_default = order_default if order_default is not None else False
self.kw_only_default = kw_only_default if kw_only_default is not None else False
self.field_specifiers = field_specifiers if field_specifiers is not None else ()

def serialize(self) -> JsonDict:
return {
"eq_default": self.eq_default,
"order_default": self.order_default,
"kw_only_default": self.kw_only_default,
"field_specifiers": self.field_specifiers,
}

@classmethod
def deserialize(cls, data: JsonDict) -> DataclassTransformSpec:
return DataclassTransformSpec(
eq_default=data.get("eq_default"),
order_default=data.get("order_default"),
kw_only_default=data.get("kw_only_default"),
field_specifiers=data.get("field_specifiers"),
)


def get_flags(node: Node, names: list[str]) -> list[str]:
return [name for name in names if getattr(node, name)]

Expand Down
103 changes: 69 additions & 34 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
AssignmentStmt,
CallExpr,
Context,
DataclassTransformSpec,
Expression,
JsonDict,
NameExpr,
Node,
PlaceholderNode,
RefExpr,
SymbolTableNode,
Expand All @@ -37,6 +39,7 @@
add_method,
deserialize_and_fixup_type,
)
from mypy.semanal_shared import find_dataclass_transform_spec
from mypy.server.trigger import make_wildcard_trigger
from mypy.state import state
from mypy.typeops import map_type_from_supertype
Expand All @@ -56,11 +59,15 @@

# The set of decorators that generate dataclasses.
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}
# The set of functions that generate dataclass fields.
field_makers: Final = {"dataclasses.field"}


SELF_TVAR_NAME: Final = "_DT"
_TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec(
eq_default=True,
order_default=False,
kw_only_default=False,
field_specifiers=("dataclasses.Field", "dataclasses.field"),
)


class DataclassAttribute:
Expand Down Expand Up @@ -155,6 +162,7 @@ class DataclassTransformer:

def __init__(self, ctx: ClassDefContext) -> None:
self._ctx = ctx
self._spec = _get_transform_spec(ctx.reason)

def transform(self) -> bool:
"""Apply all the necessary transformations to the underlying
Expand All @@ -172,8 +180,8 @@ def transform(self) -> bool:
return False
decorator_arguments = {
"init": _get_decorator_bool_argument(self._ctx, "init", True),
"eq": _get_decorator_bool_argument(self._ctx, "eq", True),
"order": _get_decorator_bool_argument(self._ctx, "order", False),
"eq": _get_decorator_bool_argument(self._ctx, "eq", self._spec.eq_default),
"order": _get_decorator_bool_argument(self._ctx, "order", self._spec.order_default),
"frozen": _get_decorator_bool_argument(self._ctx, "frozen", False),
"slots": _get_decorator_bool_argument(self._ctx, "slots", False),
"match_args": _get_decorator_bool_argument(self._ctx, "match_args", True),
Expand Down Expand Up @@ -411,7 +419,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:

# Second, collect attributes belonging to the current class.
current_attr_names: set[str] = set()
kw_only = _get_decorator_bool_argument(ctx, "kw_only", False)
kw_only = _get_decorator_bool_argument(ctx, "kw_only", self._spec.kw_only_default)
for stmt in cls.defs.body:
# Any assignment that doesn't use the new type declaration
# syntax can be ignored out of hand.
Expand Down Expand Up @@ -461,7 +469,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
if self._is_kw_only_type(node_type):
kw_only = True

has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx)
has_field_call, field_args = self._collect_field_args(stmt.rvalue, ctx)

is_in_init_param = field_args.get("init")
if is_in_init_param is None:
Expand Down Expand Up @@ -614,6 +622,36 @@ def _add_dataclass_fields_magic_attribute(self) -> None:
kind=MDEF, node=var, plugin_generated=True
)

def _collect_field_args(
self, expr: Expression, ctx: ClassDefContext
) -> tuple[bool, dict[str, Expression]]:
"""Returns a tuple where the first value represents whether or not
the expression is a call to dataclass.field and the second is a
dictionary of the keyword arguments that field() was called with.
"""
if (
isinstance(expr, CallExpr)
and isinstance(expr.callee, RefExpr)
and expr.callee.fullname in self._spec.field_specifiers
):
# field() only takes keyword arguments.
args = {}
for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
if not kind.is_named():
if kind.is_named(star=True):
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
message = 'Unpacking **kwargs in "field()" is not supported'
else:
message = '"field()" does not accept positional arguments'
ctx.api.fail(message, expr)
return True, {}
assert name is not None
args[name] = arg
return True, args
return False, {}


def dataclass_tag_callback(ctx: ClassDefContext) -> None:
"""Record that we have a dataclass in the main semantic analysis pass.
Expand All @@ -631,32 +669,29 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool:
return transformer.transform()


def _collect_field_args(
expr: Expression, ctx: ClassDefContext
) -> tuple[bool, dict[str, Expression]]:
"""Returns a tuple where the first value represents whether or not
the expression is a call to dataclass.field and the second is a
dictionary of the keyword arguments that field() was called with.
def _get_transform_spec(reason: Expression) -> DataclassTransformSpec:
"""Find the relevant transform parameters from the decorator/parent class/metaclass that
triggered the dataclasses plugin.

Although the resulting DataclassTransformSpec is based on the typing.dataclass_transform
function, we also use it for traditional dataclasses.dataclass classes as well for simplicity.
In those cases, we return a default spec rather than one based on a call to
`typing.dataclass_transform`.
"""
if (
isinstance(expr, CallExpr)
and isinstance(expr.callee, RefExpr)
and expr.callee.fullname in field_makers
):
# field() only takes keyword arguments.
args = {}
for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
if not kind.is_named():
if kind.is_named(star=True):
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
message = 'Unpacking **kwargs in "field()" is not supported'
else:
message = '"field()" does not accept positional arguments'
ctx.api.fail(message, expr)
return True, {}
assert name is not None
args[name] = arg
return True, args
return False, {}
if _is_dataclasses_decorator(reason):
return _TRANSFORM_SPEC_FOR_DATACLASSES

spec = find_dataclass_transform_spec(reason)
assert spec is not None, (
"trying to find dataclass transform spec, but reason is neither dataclasses.dataclass nor "
"decorated with typing.dataclass_transform"
)
return spec


def _is_dataclasses_decorator(node: Node) -> bool:
if isinstance(node, CallExpr):
node = node.callee
if isinstance(node, RefExpr):
return node.fullname in dataclass_makers
return False