Skip to content

Commit

Permalink
[3.9] bpo-42345: Fix three issues with typing.Literal parameters (GH-…
Browse files Browse the repository at this point in the history
…23294) (GH-23335)

Literal equality no longer depends on the order of arguments.

Fix issue related to `typing.Literal` caching by adding `typed` parameter to `typing._tp_cache` function.

Add deduplication of `typing.Literal` arguments.

(cherry picked from commit f03d318)

Co-authored-by: Yurii Karabas <1998uriyyo@gmail.com>
  • Loading branch information
uriyyo committed Nov 17, 2020
1 parent 656d50f commit ac472b3
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 23 deletions.
25 changes: 25 additions & 0 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def test_repr(self):
self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
self.assertEqual(repr(Literal), "typing.Literal")
self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]")

def test_cannot_init(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -563,6 +564,30 @@ def test_no_multiple_subscripts(self):
with self.assertRaises(TypeError):
Literal[1][1]

def test_equal(self):
self.assertNotEqual(Literal[0], Literal[False])
self.assertNotEqual(Literal[True], Literal[1])
self.assertNotEqual(Literal[1], Literal[2])
self.assertNotEqual(Literal[1, True], Literal[1])
self.assertEqual(Literal[1], Literal[1])
self.assertEqual(Literal[1, 2], Literal[2, 1])
self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3])

def test_args(self):
self.assertEqual(Literal[1, 2, 3].__args__, (1, 2, 3))
self.assertEqual(Literal[1, 2, 3, 3].__args__, (1, 2, 3))
self.assertEqual(Literal[1, Literal[2], Literal[3, 4]].__args__, (1, 2, 3, 4))
# Mutable arguments will not be deduplicated
self.assertEqual(Literal[[], []].__args__, ([], []))

def test_flatten(self):
l1 = Literal[Literal[1], Literal[2], Literal[3]]
l2 = Literal[Literal[1, 2], 3]
l3 = Literal[Literal[1, 2, 3]]
for l in l1, l2, l3:
self.assertEqual(l, Literal[1, 2, 3])
self.assertEqual(l.__args__, (1, 2, 3))


XK = TypeVar('XK', str, bytes)
XV = TypeVar('XV')
Expand Down
100 changes: 77 additions & 23 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,20 @@ def _check_generic(cls, parameters, elen):
f" actual {alen}, expected {elen}")


def _deduplicate(params):
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params
return params


def _remove_dups_flatten(parameters):
"""An internal helper for Union creation and substitution: flatten Unions
among parameters, then remove duplicates.
Expand All @@ -213,38 +227,45 @@ def _remove_dups_flatten(parameters):
params.extend(p[1:])
else:
params.append(p)
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params

return tuple(_deduplicate(params))


def _flatten_literal_params(parameters):
"""An internal helper for Literal creation: flatten Literals among parameters"""
params = []
for p in parameters:
if isinstance(p, _LiteralGenericAlias):
params.extend(p.__args__)
else:
params.append(p)
return tuple(params)


_cleanups = []


def _tp_cache(func):
def _tp_cache(func=None, /, *, typed=False):
"""Internal wrapper caching __getitem__ of generic types with a fallback to
original function for non-hashable arguments.
"""
cached = functools.lru_cache()(func)
_cleanups.append(cached.cache_clear)
def decorator(func):
cached = functools.lru_cache(typed=typed)(func)
_cleanups.append(cached.cache_clear)

@functools.wraps(func)
def inner(*args, **kwds):
try:
return cached(*args, **kwds)
except TypeError:
pass # All real errors (not unhashable args) are raised below.
return func(*args, **kwds)
return inner
@functools.wraps(func)
def inner(*args, **kwds):
try:
return cached(*args, **kwds)
except TypeError:
pass # All real errors (not unhashable args) are raised below.
return func(*args, **kwds)
return inner

if func is not None:
return decorator(func)

return decorator

def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
"""Evaluate all forward references in the given type t.
Expand Down Expand Up @@ -317,6 +338,13 @@ def __subclasscheck__(self, cls):
def __getitem__(self, parameters):
return self._getitem(self, parameters)


class _LiteralSpecialForm(_SpecialForm, _root=True):
@_tp_cache(typed=True)
def __getitem__(self, parameters):
return self._getitem(self, parameters)


@_SpecialForm
def Any(self, parameters):
"""Special type indicating an unconstrained type.
Expand Down Expand Up @@ -434,7 +462,7 @@ def Optional(self, parameters):
arg = _type_check(parameters, f"{self} requires a single type.")
return Union[arg, type(None)]

@_SpecialForm
@_LiteralSpecialForm
def Literal(self, parameters):
"""Special typing form to define literal types (a.k.a. value types).
Expand All @@ -458,7 +486,17 @@ def open_helper(file: str, mode: MODE) -> str:
"""
# There is no '_type_check' call because arguments to Literal[...] are
# values, not types.
return _GenericAlias(self, parameters)
if not isinstance(parameters, tuple):
parameters = (parameters,)

parameters = _flatten_literal_params(parameters)

try:
parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
except TypeError: # unhashable parameters
pass

return _LiteralGenericAlias(self, parameters)


class ForwardRef(_Final, _root=True):
Expand Down Expand Up @@ -881,6 +919,22 @@ def __repr__(self):
return super().__repr__()


def _value_and_type_iter(parameters):
return ((p, type(p)) for p in parameters)


class _LiteralGenericAlias(_GenericAlias, _root=True):

def __eq__(self, other):
if not isinstance(other, _LiteralGenericAlias):
return NotImplemented

return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))

def __hash__(self):
return hash(tuple(_value_and_type_iter(self.__args__)))


class Generic:
"""Abstract base class for generic types.
Expand Down
1 change: 1 addition & 0 deletions Misc/ACKS
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ Jan Kanis
Rafe Kaplan
Jacob Kaplan-Moss
Allison Kaptur
Yurii Karabas
Janne Karila
Per Øyvind Karlsen
Anton Kasyanov
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix various issues with ``typing.Literal`` parameter handling (flatten,
deduplicate, use type to cache key). Patch provided by Yurii Karabas.

0 comments on commit ac472b3

Please sign in to comment.