Skip to content

Commit

Permalink
Fixes discriminated unions not working on aliased literal fields. Issue
Browse files Browse the repository at this point in the history
  • Loading branch information
benwah committed May 10, 2023
1 parent 8131196 commit 4d2a64c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
1 change: 1 addition & 0 deletions changes/5736-benwah.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This solves the (closed) issue #3849 where aliased fields that use discriminated union fail to validate when the data contains the non-aliased field name.
5 changes: 4 additions & 1 deletion pydantic/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,10 @@ def _validate_discriminated_union(
assert self.discriminator_alias is not None

try:
discriminator_value = v[self.discriminator_alias]
try:
discriminator_value = v[self.discriminator_alias]
except KeyError:
discriminator_value = v[self.discriminator_key]
except KeyError:
return v, ErrorWrapper(MissingDiscriminator(discriminator_key=self.discriminator_key), loc)
except TypeError:
Expand Down
58 changes: 58 additions & 0 deletions tests/test_discrimated_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,64 @@ class Top(BaseModel):
assert Top(sub=B(literal='b')).sub.literal == 'b'


def test_discriminated_union_model_with_alias():
class A(BaseModel):
literal: Literal['a'] = Field(alias='lit')

class B(BaseModel):
literal: Literal['b'] = Field(alias='lit')

class Config:
allow_population_by_field_name = True

class TopDisallow(BaseModel):
sub: Union[A, B] = Field(..., discriminator='literal', alias='s')

class TopAllow(BaseModel):
sub: Union[A, B] = Field(..., discriminator='literal', alias='s')

class Config:
allow_population_by_field_name = True

assert TopDisallow.parse_obj({'s': {'lit': 'a'}}).sub.literal == 'a'
assert TopDisallow.parse_obj({'s': {'literal': 'b'}}).sub.literal == 'b'

with pytest.raises(ValidationError) as exc_info:
TopDisallow.parse_obj({'s': {'literal': 'a'}})

assert exc_info.value.errors() == [
{'loc': ('s', 'A', 'lit'), 'msg': 'field required', 'type': 'value_error.missing'},
]

with pytest.raises(ValidationError) as exc_info:
TopDisallow.parse_obj({'sub': {'lit': 'a'}})

assert exc_info.value.errors() == [
{'loc': ('s',), 'msg': 'field required', 'type': 'value_error.missing'},
]

assert TopAllow.parse_obj({'s': {'lit': 'a'}}).sub.literal == 'a'
assert TopAllow.parse_obj({'s': {'lit': 'b'}}).sub.literal == 'b'
assert TopAllow.parse_obj({'s': {'literal': 'b'}}).sub.literal == 'b'
assert TopAllow.parse_obj({'sub': {'lit': 'a'}}).sub.literal == 'a'
assert TopAllow.parse_obj({'sub': {'lit': 'b'}}).sub.literal == 'b'
assert TopAllow.parse_obj({'sub': {'literal': 'b'}}).sub.literal == 'b'

with pytest.raises(ValidationError) as exc_info:
TopAllow.parse_obj({'s': {'literal': 'a'}})

assert exc_info.value.errors() == [
{'loc': ('s', 'A', 'lit'), 'msg': 'field required', 'type': 'value_error.missing'},
]

with pytest.raises(ValidationError) as exc_info:
TopAllow.parse_obj({'sub': {'literal': 'a'}})

assert exc_info.value.errors() == [
{'loc': ('s', 'A', 'lit'), 'msg': 'field required', 'type': 'value_error.missing'},
]


def test_discriminated_union_int():
class A(BaseModel):
l: Literal[1]
Expand Down

0 comments on commit 4d2a64c

Please sign in to comment.