Skip to content

Commit

Permalink
Refactor to support typevars, and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinche committed Apr 9, 2023
1 parent bd33c7a commit f13a463
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 28 deletions.
53 changes: 34 additions & 19 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Iterable, List, Optional, cast
from typing import Iterable, List, cast
from typing_extensions import Final, Literal

import mypy.plugin # To avoid circular imports.
Expand Down Expand Up @@ -43,7 +43,7 @@
Var,
is_class_var,
)
from mypy.plugin import FunctionContext, SemanticAnalyzerPluginInterface
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
_get_argument,
_get_bool_argument,
Expand Down Expand Up @@ -990,27 +990,42 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
)


def _get_cls_from_init(t: Type) -> Optional[TypeInfo]:
proper_type = get_proper_type(t)
if isinstance(proper_type, CallableType):
return proper_type.type_object()
return None
def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
"""Provide the proper signature for `attrs.fields`."""
if ctx.args and len(ctx.args) == 1 and ctx.args[0] and ctx.args[0][0]:

# <hack>
assert isinstance(ctx.api, TypeChecker)
inst_type = ctx.api.expr_checker.accept(ctx.args[0][0])
# </hack>
proper_type = get_proper_type(inst_type)

if isinstance(proper_type, AnyType): # fields(Any) -> Any
return ctx.default_signature

cls = None
arg_types = ctx.default_signature.arg_types

if isinstance(proper_type, TypeVarType):
inner = get_proper_type(proper_type.upper_bound)
if isinstance(inner, Instance):
# We need to work arg_types to compensate for the attrs stubs.
arg_types = [inst_type]
cls = inner.type
elif isinstance(proper_type, CallableType):
cls = proper_type.type_object()

def fields_function_callback(ctx: FunctionContext) -> Type:
"""Provide the proper return value for `attrs.fields`."""
if ctx.arg_types and ctx.arg_types[0] and ctx.arg_types[0][0]:
first_arg_type = ctx.arg_types[0][0]
cls = _get_cls_from_init(first_arg_type)
if cls is not None:
if MAGIC_ATTR_NAME in cls.names:
# This is a proper attrs class.
ret_type = cls.names[MAGIC_ATTR_NAME].type
if ret_type is not None:
return ret_type
else:
ctx.api.fail(
f'Argument 1 to "fields" has incompatible type "{format_type_bare(first_arg_type)}"; expected an attrs class',
ctx.context,
)
return ctx.default_return_type
return ctx.default_signature.copy_modified(
arg_types=arg_types, ret_type=ret_type
)

ctx.api.fail(
f'Argument 1 to "fields" has incompatible type "{format_type_bare(proper_type)}"; expected an attrs class',
ctx.context,
)
return ctx.default_signature
7 changes: 4 additions & 3 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,13 @@ class DefaultPlugin(Plugin):
"""Type checker plugin that is enabled by default."""

def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
from mypy.plugins import attrs, ctypes, singledispatch
from mypy.plugins import ctypes, singledispatch

if fullname == "ctypes.Array":
return ctypes.array_constructor_callback
elif fullname == "functools.singledispatch":
return singledispatch.create_singledispatch_function_callback
elif fullname in ("attr.fields", "attrs.fields"):
return attrs.fields_function_callback

return None

def get_function_signature_hook(
Expand All @@ -56,6 +55,8 @@ def get_function_signature_hook(

if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
return attrs.evolve_function_sig_callback
elif fullname in ("attr.fields", "attrs.fields"):
return attrs.fields_function_sig_callback
return None

def get_method_signature_hook(
Expand Down
33 changes: 29 additions & 4 deletions test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,24 @@ takes_attrs_instance(A) # E: Argument 1 to "takes_attrs_instance" has incompati
[builtins fixtures/attr.pyi]

[case testAttrsFields]
import attr
from attrs import fields as f # Common usage.

@attr.define
class A:
b: int
c: str

reveal_type(f(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
reveal_type(f(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
reveal_type(f(A).b) # N: Revealed type is "attr.Attribute[builtins.int]"
f(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x"

[builtins fixtures/attr.pyi]

[case testAttrsGenericFields]
from typing import TypeVar

import attr
from attrs import fields

Expand All @@ -1557,21 +1575,28 @@ class A:
b: int
c: str

reveal_type(fields(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
reveal_type(fields(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
reveal_type(fields(A).b) # N: Revealed type is "attr.Attribute[builtins.int]"
fields(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
TA = TypeVar('TA', bound=A)

def f(t: TA) -> None:
reveal_type(fields(t)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
reveal_type(fields(t)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
reveal_type(fields(t).b) # N: Revealed type is "attr.Attribute[builtins.int]"
fields(t).x # E: "____main___A_AttrsAttributes__" has no attribute "x"


[builtins fixtures/attr.pyi]

[case testNonattrsFields]
from typing import Any, cast
from attrs import fields

class A:
b: int
c: str

fields(A) # E: Argument 1 to "fields" has incompatible type "Type[A]"; expected an attrs class
fields(None) # E: Argument 1 to "fields" has incompatible type "None"; expected an attrs class
fields(cast(Any, 42))

[builtins fixtures/attr.pyi]

Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/lib-stub/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,4 @@ def field(
def evolve(inst: _T, **changes: Any) -> _T: ...
def assoc(inst: _T, **changes: Any) -> _T: ...

def fields(cls: _C) -> Any: ...
def fields(cls: type) -> Any: ...
2 changes: 1 addition & 1 deletion test-data/unit/lib-stub/attrs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,4 @@ def field(
def evolve(inst: _T, **changes: Any) -> _T: ...
def assoc(inst: _T, **changes: Any) -> _T: ...

def fields(cls: _C) -> Any: ...
def fields(cls: type) -> Any: ...

0 comments on commit f13a463

Please sign in to comment.