Skip to content

Commit

Permalink
First pass over generics (#1079)
Browse files Browse the repository at this point in the history
* First pass over generics

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Reformat comment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* More work on generics

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add test case

* Tweak condition

* Remove redundant code

* Add test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hynek Schlawack <hs@ox.cx>
  • Loading branch information
3 people committed Feb 5, 2023
1 parent 9cf2ed5 commit 4fcd15b
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 3 deletions.
8 changes: 8 additions & 0 deletions src/attr/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings

from collections.abc import Mapping, Sequence # noqa
from typing import _GenericAlias


PYPY = platform.python_implementation() == "PyPy"
Expand Down Expand Up @@ -174,3 +175,10 @@ def func():
# don't have a direct reference to the thread-local in their globals dict.
# If they have such a reference, it breaks cloudpickle.
repr_context = threading.local()


def get_generic_base(cl):
"""If this is a generic class (A[str]), return the generic base for it."""
if cl.__class__ is _GenericAlias:
return cl.__origin__
return None
15 changes: 14 additions & 1 deletion src/attr/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import copy

from ._compat import get_generic_base
from ._make import NOTHING, _obj_setattr, fields
from .exceptions import AttrsAttributeNotFoundError

Expand Down Expand Up @@ -296,7 +297,19 @@ def has(cls):
:rtype: bool
"""
return getattr(cls, "__attrs_attrs__", None) is not None
attrs = getattr(cls, "__attrs_attrs__", None)
if attrs is not None:
return True

# No attrs, maybe it's a specialized generic (A[str])?
generic_base = get_generic_base(cls)
if generic_base is not None:
generic_attrs = getattr(generic_base, "__attrs_attrs__", None)
if generic_attrs is not None:
# Stick it on here for speed next time.
cls.__attrs_attrs__ = generic_attrs
return generic_attrs is not None
return False


def assoc(inst, **changes):
Expand Down
23 changes: 21 additions & 2 deletions src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# We need to import _compat itself in addition to the _compat members to avoid
# having the thread-local in the globals here.
from . import _compat, _config, setters
from ._compat import PY310, _AnnotationExtractor, set_closure_cell
from ._compat import (
PY310,
_AnnotationExtractor,
get_generic_base,
set_closure_cell,
)
from .exceptions import (
DefaultAlreadySetError,
FrozenInstanceError,
Expand Down Expand Up @@ -1918,12 +1923,26 @@ def fields(cls):
.. versionchanged:: 16.2.0 Returned tuple allows accessing the fields
by name.
.. versionchanged:: 23.1.0 Add support for generic classes.
"""
if not isinstance(cls, type):
generic_base = get_generic_base(cls)

if generic_base is None and not isinstance(cls, type):
raise TypeError("Passed object must be a class.")

attrs = getattr(cls, "__attrs_attrs__", None)

if attrs is None:
if generic_base is not None:
attrs = getattr(generic_base, "__attrs_attrs__", None)
if attrs is not None:
# Even though this is global state, stick it on here to speed
# it up. We rely on `cls` being cached for this to be
# efficient.
cls.__attrs_attrs__ = attrs
return attrs
raise NotAnAttrsClassError(f"{cls!r} is not an attrs-decorated class.")

return attrs


Expand Down
32 changes: 32 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


from collections import OrderedDict
from typing import Generic, TypeVar

import pytest

Expand Down Expand Up @@ -418,6 +419,37 @@ def test_negative(self):
"""
assert not has(object)

def test_generics(self):
"""
Works with generic classes.
"""
T = TypeVar("T")

@attr.define
class A(Generic[T]):
a: T

assert has(A)

assert has(A[str])
# Verify twice, since there's caching going on.
assert has(A[str])

def test_generics_negative(self):
"""
Returns `False` on non-decorated generic classes.
"""
T = TypeVar("T")

class A(Generic[T]):
a: T

assert not has(A)

assert not has(A[str])
# Verify twice, since there's caching going on.
assert not has(A[str])


class TestAssoc:
"""
Expand Down
35 changes: 35 additions & 0 deletions tests/test_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sys

from operator import attrgetter
from typing import Generic, TypeVar

import pytest

Expand Down Expand Up @@ -1114,6 +1115,22 @@ def test_handler_non_attrs_class(self):
f"{object!r} is not an attrs-decorated class."
) == e.value.args[0]

def test_handler_non_attrs_generic_class(self):
"""
Raises `ValueError` if passed a non-*attrs* generic class.
"""
T = TypeVar("T")

class B(Generic[T]):
pass

with pytest.raises(NotAnAttrsClassError) as e:
fields(B[str])

assert (
f"{B[str]!r} is not an attrs-decorated class."
) == e.value.args[0]

@given(simple_classes())
def test_fields(self, C):
"""
Expand All @@ -1129,6 +1146,24 @@ def test_fields_properties(self, C):
for attribute in fields(C):
assert getattr(fields(C), attribute.name) is attribute

def test_generics(self):
"""
Fields work with generic classes.
"""
T = TypeVar("T")

@attr.define
class A(Generic[T]):
a: T

assert len(fields(A)) == 1
assert fields(A).a.name == "a"
assert fields(A).a.default is attr.NOTHING

assert len(fields(A[str])) == 1
assert fields(A[str]).a.name == "a"
assert fields(A[str]).a.default is attr.NOTHING


class TestFieldsDict:
"""
Expand Down

0 comments on commit 4fcd15b

Please sign in to comment.