Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix broken parametrized bases with GenericModels #5052

Merged
merged 18 commits into from
Feb 15, 2023
Merged
1 change: 1 addition & 0 deletions changes/5052-MarkusSintonen.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix broken parametrized bases handling with `GenericModel`s with complex sets of models.
14 changes: 11 additions & 3 deletions pydantic/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Union,
cast,
)
from weakref import WeakKeyDictionary, WeakValueDictionary

from typing_extensions import Annotated

Expand All @@ -25,7 +26,7 @@
from .main import BaseModel, create_model
from .types import JsonWrapper
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
from .utils import LimitedDict, all_identical, lenient_issubclass
from .utils import all_identical, lenient_issubclass

if sys.version_info >= (3, 10):
from typing import _UnionGenericAlias
Expand All @@ -35,13 +36,20 @@

Parametrization = Mapping[TypeVarType, Type[Any]]

_generic_types_cache: LimitedDict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = LimitedDict()
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
GenericTypesCache = WeakValueDictionary[Tuple[Type[Any], Any, Tuple[Any, ...]], Type[BaseModel]]
AssignedParameters = WeakKeyDictionary[Type[BaseModel], Parametrization]
else:
GenericTypesCache = WeakValueDictionary
AssignedParameters = WeakKeyDictionary

_generic_types_cache = GenericTypesCache()
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations
# as captured during construction of the class (not instances).
# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`.
# (This information is only otherwise available after creation from the class name string).
_assigned_parameters: LimitedDict[Type[Any], Parametrization] = LimitedDict()
_assigned_parameters = AssignedParameters()


class GenericModel(BaseModel):
Expand Down
73 changes: 66 additions & 7 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gc
import json
import sys
from enum import Enum
Expand All @@ -21,8 +22,19 @@
import pytest
from typing_extensions import Annotated, Literal

from pydantic import BaseModel, Field, Json, ValidationError, root_validator, validator
from pydantic.generics import GenericModel, _generic_types_cache, iter_contained_typevars, replace_types
from pydantic import BaseModel, Field, Json, ValidationError, create_model, root_validator, validator
from pydantic.generics import (
GenericModel,
_assigned_parameters,
_generic_types_cache,
iter_contained_typevars,
replace_types,
)


@pytest.fixture(autouse=True)
def clean_cache():
gc.collect() # cleans up _generic_types_cache for checking item counts in the cache


def test_generic_name():
Expand Down Expand Up @@ -229,10 +241,13 @@ def test_cover_cache():
class Model(GenericModel, Generic[T]):
x: T

Model[int] # adds both with-tuple and without-tuple version to cache
models = [] # keep references to models to get cache size

models.append(Model[int]) # adds both with-tuple and without-tuple version to cache
assert len(_generic_types_cache) == cache_size + 2
Model[int] # uses the cache
models.append(Model[int]) # uses the cache
assert len(_generic_types_cache) == cache_size + 2
del models


def test_cache_keys_are_hashable():
Expand All @@ -246,19 +261,63 @@ class MyGenericModel(GenericModel, Generic[T]):
# Callable's first params get converted to a list, which is not hashable.
# Make sure we can handle that special case
Simple = MyGenericModel[Callable[[int], str]]
models = [] # keep references to models to get cache size
models.append(Simple)
assert len(_generic_types_cache) == cache_size + 2
# Nested Callables
MyGenericModel[Callable[[C], Iterable[str]]]
models.append(MyGenericModel[Callable[[C], Iterable[str]]])
assert len(_generic_types_cache) == cache_size + 4
MyGenericModel[Callable[[Simple], Iterable[int]]]
models.append(MyGenericModel[Callable[[Simple], Iterable[int]]])
assert len(_generic_types_cache) == cache_size + 6
MyGenericModel[Callable[[MyGenericModel[C]], Iterable[int]]]
models.append(MyGenericModel[Callable[[MyGenericModel[C]], Iterable[int]]])
assert len(_generic_types_cache) == cache_size + 10

class Model(BaseModel):
x: MyGenericModel[Callable[[C], Iterable[str]]] = Field(...)

models.append(Model)
assert len(_generic_types_cache) == cache_size + 10
del models


def test_cache_gets_cleaned_up():
cache_size = len(_generic_types_cache)
T = TypeVar('T')

class Model(GenericModel, Generic[T]):
x: T

model = Model[int]
assert len(_generic_types_cache) == cache_size + 2
del model
gc.collect()
assert len(_generic_types_cache) == cache_size


def test_generics_work_with_many_parametrized_base_models():
cache_size = len(_generic_types_cache)
params_size = len(_assigned_parameters)
count_create_models = 1000
T = TypeVar('T')
C = TypeVar('C')

class A(GenericModel, Generic[T, C]):
x: T
y: C

class B(A[int, C], GenericModel, Generic[C]):
pass

models = [create_model(f'M{i}') for i in range(count_create_models)]
generics = []
for m in models:
working = B[m]
generics.append(working)

assert len(_generic_types_cache) == cache_size + count_create_models * 5 + 1
assert len(_assigned_parameters) == params_size + count_create_models * 3 + 1
del models
del generics


def test_generic_config():
Expand Down