Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First pass over generics #1079

Merged
merged 14 commits into from Feb 5, 2023
8 changes: 8 additions & 0 deletions src/attr/_compat.py
Expand Up @@ -9,6 +9,7 @@
import warnings

from collections.abc import Mapping, Sequence # noqa
from typing import _GenericAlias


PYPY = platform.python_implementation() == "PyPy"
Expand Down Expand Up @@ -174,3 +175,10 @@ def func():
# don't have a direct reference to the thread-local in their globals dict.
# If they have such a reference, it breaks cloudpickle.
repr_context = threading.local()


def get_generic_base(cl):
"""If this is a generic class (A[str]), return the generic base for it."""
if cl.__class__ is _GenericAlias:
return cl.__origin__
return None
15 changes: 14 additions & 1 deletion src/attr/_funcs.py
Expand Up @@ -3,6 +3,7 @@

import copy

from ._compat import get_generic_base
from ._make import NOTHING, _obj_setattr, fields
from .exceptions import AttrsAttributeNotFoundError

Expand Down Expand Up @@ -296,7 +297,19 @@ def has(cls):

:rtype: bool
"""
return getattr(cls, "__attrs_attrs__", None) is not None
attrs = getattr(cls, "__attrs_attrs__", None)
if attrs is not None:
return True

# No attrs, maybe it's a specialized generic (A[str])?
generic_base = get_generic_base(cls)
if generic_base is not None:
generic_attrs = getattr(generic_base, "__attrs_attrs__", None)
if generic_attrs is not None:
# Stick it on here for speed next time.
cls.__attrs_attrs__ = generic_attrs
return generic_attrs is not None
return False


def assoc(inst, **changes):
Expand Down
23 changes: 21 additions & 2 deletions src/attr/_make.py
Expand Up @@ -12,7 +12,12 @@
# We need to import _compat itself in addition to the _compat members to avoid
# having the thread-local in the globals here.
from . import _compat, _config, setters
from ._compat import PY310, _AnnotationExtractor, set_closure_cell
from ._compat import (
PY310,
_AnnotationExtractor,
get_generic_base,
set_closure_cell,
)
from .exceptions import (
DefaultAlreadySetError,
FrozenInstanceError,
Expand Down Expand Up @@ -1918,12 +1923,26 @@ def fields(cls):

.. versionchanged:: 16.2.0 Returned tuple allows accessing the fields
by name.
.. versionchanged:: 23.1.0 Add support for generic classes.
"""
if not isinstance(cls, type):
generic_base = get_generic_base(cls)

if generic_base is None and not isinstance(cls, type):
raise TypeError("Passed object must be a class.")

attrs = getattr(cls, "__attrs_attrs__", None)

if attrs is None:
if generic_base is not None:
attrs = getattr(generic_base, "__attrs_attrs__", None)
if attrs is not None:
# Even though this is global state, stick it on here to speed
# it up. We rely on `cls` being cached for this to be
# efficient.
cls.__attrs_attrs__ = attrs
return attrs
raise NotAnAttrsClassError(f"{cls!r} is not an attrs-decorated class.")

return attrs


Expand Down
32 changes: 32 additions & 0 deletions tests/test_funcs.py
Expand Up @@ -6,6 +6,7 @@


from collections import OrderedDict
from typing import Generic, TypeVar

import pytest

Expand Down Expand Up @@ -418,6 +419,37 @@ def test_negative(self):
"""
assert not has(object)

def test_generics(self):
"""
Works with generic classes.
"""
T = TypeVar("T")

@attr.define
class A(Generic[T]):
a: T

assert has(A)

assert has(A[str])
# Verify twice, since there's caching going on.
assert has(A[str])

def test_generics_negative(self):
"""
Returns `False` on non-decorated generic classes.
"""
T = TypeVar("T")

class A(Generic[T]):
a: T

assert not has(A)

assert not has(A[str])
# Verify twice, since there's caching going on.
assert not has(A[str])


class TestAssoc:
"""
Expand Down
35 changes: 35 additions & 0 deletions tests/test_make.py
Expand Up @@ -13,6 +13,7 @@
import sys

from operator import attrgetter
from typing import Generic, TypeVar

import pytest

Expand Down Expand Up @@ -1114,6 +1115,22 @@ def test_handler_non_attrs_class(self):
f"{object!r} is not an attrs-decorated class."
) == e.value.args[0]

def test_handler_non_attrs_generic_class(self):
"""
Raises `ValueError` if passed a non-*attrs* generic class.
"""
T = TypeVar("T")

class B(Generic[T]):
pass

with pytest.raises(NotAnAttrsClassError) as e:
fields(B[str])

assert (
f"{B[str]!r} is not an attrs-decorated class."
) == e.value.args[0]

@given(simple_classes())
def test_fields(self, C):
"""
Expand All @@ -1129,6 +1146,24 @@ def test_fields_properties(self, C):
for attribute in fields(C):
assert getattr(fields(C), attribute.name) is attribute

def test_generics(self):
"""
Fields work with generic classes.
"""
T = TypeVar("T")

@attr.define
class A(Generic[T]):
a: T

assert len(fields(A)) == 1
assert fields(A).a.name == "a"
assert fields(A).a.default is attr.NOTHING

assert len(fields(A[str])) == 1
assert fields(A[str]).a.name == "a"
assert fields(A[str]).a.default is attr.NOTHING


class TestFieldsDict:
"""
Expand Down