Skip to content

Commit

Permalink
Fix Literal bug with typing-extension==4.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
hramezani committed May 23, 2023
1 parent 12ba3f4 commit 7f0a4d5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 2 deletions.
10 changes: 9 additions & 1 deletion pydantic/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import typing
from collections.abc import Callable
from os import PathLike
from typing import ( # type: ignore
Expand All @@ -22,6 +23,7 @@
TypeVar,
Union,
_eval_type,
_SpecialForm as TypingSpecialForm,
cast,
get_type_hints,
)
Expand All @@ -32,6 +34,7 @@
Literal,
NotRequired as TypedDictNotRequired,
Required as TypedDictRequired,
_SpecialForm as TypingExtensionSpecialForm,
)

try:
Expand Down Expand Up @@ -91,6 +94,11 @@ def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> A
AnnotatedTypeNames = {'AnnotatedMeta', '_AnnotatedAlias'}


LITERAL_TYPES: Set[Union[TypingSpecialForm, TypingExtensionSpecialForm]] = {Literal}
if hasattr(typing, 'Literal'):
LITERAL_TYPES.add(typing.Literal)


if sys.version_info < (3, 8):

def get_origin(t: Type[Any]) -> Optional[Type[Any]]:
Expand Down Expand Up @@ -415,7 +423,7 @@ def is_callable_type(type_: Type[Any]) -> bool:


def is_literal_type(type_: Type[Any]) -> bool:
return Literal is not None and get_origin(type_) is Literal
return Literal is not None and get_origin(type_) in LITERAL_TYPES


def literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3440,3 +3440,22 @@ class Model(BaseModel):
'type': 'value_error.date.not_in_the_future',
}
]


def test_typing_extension_literal_field():
from typing_extensions import Literal

class Model(BaseModel):
foo: Literal['foo']

assert Model(foo='foo').foo == 'foo'


@pytest.mark.skipif(sys.version_info < (3, 8), reason='`typing.Literal` is available for python 3.8 and above.')
def test_typing_literal_field():
from typing import Literal

class Model(BaseModel):
foo: Literal['foo']

assert Model(foo='foo').foo == 'foo'
17 changes: 16 additions & 1 deletion tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Annotated # noqa: F401

from pydantic import Field # noqa: F401
from pydantic.typing import Literal, convert_generics, is_namedtuple, is_none_type, is_typeddict
from pydantic.typing import Literal, convert_generics, is_literal_type, is_namedtuple, is_none_type, is_typeddict

try:
from typing import TypedDict as typing_TypedDict
Expand Down Expand Up @@ -124,3 +124,18 @@ def test_convert_generics_pep604():
assert (
convert_generics(dict['Hero', list['Team']] | int) == dict[ForwardRef('Hero'), list[ForwardRef('Team')]] | int
)


def test_is_literal__with_typing_extension_literal():
from typing_extensions import Literal

assert is_literal_type(Literal) is False
assert is_literal_type(Literal['foo']) is True


@pytest.mark.skipif(sys.version_info < (3, 8), reason='`typing.Literal` is available for python 3.8 and above.')
def test_is_literal_with_typing_literal():
from typing import Literal

assert is_literal_type(Literal) is False
assert is_literal_type(Literal['foo']) is True

0 comments on commit 7f0a4d5

Please sign in to comment.