Skip to content

Commit

Permalink
Have Protocol inherit from typing.Generic on 3.8+ (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed May 24, 2023
1 parent b306e56 commit 88a7f68
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 129 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

- Change deprecated `@runtime` to formal API `@runtime_checkable` in the error
message. Patch by Xuehai Pan.
- Fix regression in 4.6.0 where attempting to define a `Protocol` that was
generic over a `ParamSpec` or a `TypeVarTuple` would cause `TypeError` to be
raised. Patch by Alex Waygood.

# Release 4.6.0 (May 22, 2023)

Expand Down
120 changes: 101 additions & 19 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2613,6 +2613,62 @@ class CustomProtocolWithoutInitB(Protocol):

self.assertEqual(CustomProtocolWithoutInitA.__init__, CustomProtocolWithoutInitB.__init__)

def test_protocol_generic_over_paramspec(self):
P = ParamSpec("P")
T = TypeVar("T")
T2 = TypeVar("T2")

class MemoizedFunc(Protocol[P, T, T2]):
cache: typing.Dict[T2, T]
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

self.assertEqual(MemoizedFunc.__parameters__, (P, T, T2))
self.assertTrue(MemoizedFunc._is_protocol)

with self.assertRaises(TypeError):
MemoizedFunc[[int, str, str]]

if sys.version_info >= (3, 10):
# These unfortunately don't pass on <=3.9,
# due to typing._type_check on older Python versions
X = MemoizedFunc[[int, str, str], T, T2]
self.assertEqual(X.__parameters__, (T, T2))
self.assertEqual(X.__args__, ((int, str, str), T, T2))

Y = X[bytes, memoryview]
self.assertEqual(Y.__parameters__, ())
self.assertEqual(Y.__args__, ((int, str, str), bytes, memoryview))

def test_protocol_generic_over_typevartuple(self):
Ts = TypeVarTuple("Ts")
T = TypeVar("T")
T2 = TypeVar("T2")

class MemoizedFunc(Protocol[Unpack[Ts], T, T2]):
cache: typing.Dict[T2, T]
def __call__(self, *args: Unpack[Ts]) -> T: ...

self.assertEqual(MemoizedFunc.__parameters__, (Ts, T, T2))
self.assertTrue(MemoizedFunc._is_protocol)

things = "arguments" if sys.version_info >= (3, 11) else "parameters"

# A bug was fixed in 3.11.1
# (https://github.com/python/cpython/commit/74920aa27d0c57443dd7f704d6272cca9c507ab3)
# That means this assertion doesn't pass on 3.11.0,
# but it passes on all other Python versions
if sys.version_info[:3] != (3, 11, 0):
with self.assertRaisesRegex(TypeError, f"Too few {things}"):
MemoizedFunc[int]

X = MemoizedFunc[int, T, T2]
self.assertEqual(X.__parameters__, (T, T2))
self.assertEqual(X.__args__, (int, T, T2))

Y = X[bytes, memoryview]
self.assertEqual(Y.__parameters__, ())
self.assertEqual(Y.__args__, (int, bytes, memoryview))


class Point2DGeneric(Generic[T], TypedDict):
a: T
Expand Down Expand Up @@ -3402,13 +3458,18 @@ def test_user_generics(self):
class X(Generic[T, P]):
pass

G1 = X[int, P_2]
self.assertEqual(G1.__args__, (int, P_2))
self.assertEqual(G1.__parameters__, (P_2,))
class Y(Protocol[T, P]):
pass

for klass in X, Y:
with self.subTest(klass=klass.__name__):
G1 = klass[int, P_2]
self.assertEqual(G1.__args__, (int, P_2))
self.assertEqual(G1.__parameters__, (P_2,))

G2 = X[int, Concatenate[int, P_2]]
self.assertEqual(G2.__args__, (int, Concatenate[int, P_2]))
self.assertEqual(G2.__parameters__, (P_2,))
G2 = klass[int, Concatenate[int, P_2]]
self.assertEqual(G2.__args__, (int, Concatenate[int, P_2]))
self.assertEqual(G2.__parameters__, (P_2,))

# The following are some valid uses cases in PEP 612 that don't work:
# These do not work in 3.9, _type_check blocks the list and ellipsis.
Expand All @@ -3421,6 +3482,9 @@ class X(Generic[T, P]):
class Z(Generic[P]):
pass

class ProtoZ(Protocol[P]):
pass

def test_pickle(self):
global P, P_co, P_contra, P_default
P = ParamSpec('P')
Expand Down Expand Up @@ -3727,31 +3791,49 @@ def test_concatenation(self):
self.assertEqual(Tuple[int, Unpack[Xs], str].__args__,
(int, Unpack[Xs], str))
class C(Generic[Unpack[Xs]]): pass
self.assertEqual(C[int, Unpack[Xs]].__args__, (int, Unpack[Xs]))
self.assertEqual(C[Unpack[Xs], int].__args__, (Unpack[Xs], int))
self.assertEqual(C[int, Unpack[Xs], str].__args__,
(int, Unpack[Xs], str))
class D(Protocol[Unpack[Xs]]): pass
for klass in C, D:
with self.subTest(klass=klass.__name__):
self.assertEqual(klass[int, Unpack[Xs]].__args__, (int, Unpack[Xs]))
self.assertEqual(klass[Unpack[Xs], int].__args__, (Unpack[Xs], int))
self.assertEqual(klass[int, Unpack[Xs], str].__args__,
(int, Unpack[Xs], str))

def test_class(self):
Ts = TypeVarTuple('Ts')

class C(Generic[Unpack[Ts]]): pass
self.assertEqual(C[int].__args__, (int,))
self.assertEqual(C[int, str].__args__, (int, str))
class D(Protocol[Unpack[Ts]]): pass

for klass in C, D:
with self.subTest(klass=klass.__name__):
self.assertEqual(klass[int].__args__, (int,))
self.assertEqual(klass[int, str].__args__, (int, str))

with self.assertRaises(TypeError):
class C(Generic[Unpack[Ts], int]): pass

with self.assertRaises(TypeError):
class D(Protocol[Unpack[Ts], int]): pass

T1 = TypeVar('T')
T2 = TypeVar('T')
class C(Generic[T1, T2, Unpack[Ts]]): pass
self.assertEqual(C[int, str].__args__, (int, str))
self.assertEqual(C[int, str, float].__args__, (int, str, float))
self.assertEqual(C[int, str, float, bool].__args__, (int, str, float, bool))
# TODO This should probably also fail on 3.11, pending changes to CPython.
if not TYPING_3_11_0:
with self.assertRaises(TypeError):
C[int]
class D(Protocol[T1, T2, Unpack[Ts]]): pass
for klass in C, D:
with self.subTest(klass=klass.__name__):
self.assertEqual(klass[int, str].__args__, (int, str))
self.assertEqual(klass[int, str, float].__args__, (int, str, float))
self.assertEqual(
klass[int, str, float, bool].__args__, (int, str, float, bool)
)
# A bug was fixed in 3.11.1
# (https://github.com/python/cpython/commit/74920aa27d0c57443dd7f704d6272cca9c507ab3)
# That means this assertion doesn't pass on 3.11.0,
# but it passes on all other Python versions
if sys.version_info[:3] != (3, 11, 0):
with self.assertRaises(TypeError):
klass[int]


class TypeVarTupleTests(BaseTestCase):
Expand Down

0 comments on commit 88a7f68

Please sign in to comment.