Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revisit of the representation of adjacently tagged enums tag #2505

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 62 additions & 2 deletions serde/src/private/de.rs
@@ -1,10 +1,13 @@
use crate::lib::*;

use crate::de::value::{BorrowedBytesDeserializer, BytesDeserializer};
use crate::de::{Deserialize, Deserializer, Error, IntoDeserializer, Visitor};
use crate::de::{
Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error, IntoDeserializer, VariantAccess,
Visitor,
};

#[cfg(any(feature = "std", feature = "alloc"))]
use crate::de::{DeserializeSeed, MapAccess, Unexpected};
use crate::de::{MapAccess, Unexpected};

#[cfg(any(feature = "std", feature = "alloc"))]
pub use self::content::{
Expand Down Expand Up @@ -2836,3 +2839,60 @@ fn flat_map_take_entry<'de>(
None
}
}

pub struct AdjacentlyTaggedEnumVariantSeed<F> {
pub tag: &'static str,
pub variants: &'static [&'static str],
pub fields_enum: PhantomData<F>,
}

pub struct AdjacentlyTaggedEnumVariantVisitor<F> {
tag: &'static str,
fields_enum: PhantomData<F>,
}

impl<'de, F> Visitor<'de> for AdjacentlyTaggedEnumVariantVisitor<F>
where
F: Deserialize<'de>,
{
type Value = F;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "enum {}", self.tag)
}

fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
where
A: EnumAccess<'de>,
{
let (variant, variant_access) = match data.variant() {
Ok(values) => values,
Err(err) => return Err(err),
};
if let Err(err) = variant_access.unit_variant() {
return Err(err);
}
Ok(variant)
}
}

impl<'de, F> DeserializeSeed<'de> for AdjacentlyTaggedEnumVariantSeed<F>
where
F: Deserialize<'de>,
{
type Value = F;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_enum(
self.tag,
self.variants,
AdjacentlyTaggedEnumVariantVisitor {
tag: self.tag,
fields_enum: PhantomData,
},
)
}
}
25 changes: 25 additions & 0 deletions serde/src/private/ser.rs
Expand Up @@ -1355,3 +1355,28 @@ where
Ok(())
}
}

pub struct AdjacentlyTaggedEnumVariantSerializer {
tag: &'static str,
variant_index: u32,
variant_name: &'static str,
}

impl AdjacentlyTaggedEnumVariantSerializer {
pub fn new(tag: &'static str, variant_index: u32, variant_name: &'static str) -> Self {
AdjacentlyTaggedEnumVariantSerializer {
tag,
variant_index,
variant_name,
}
}
}

impl Serialize for AdjacentlyTaggedEnumVariantSerializer {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_unit_variant(self.tag, self.variant_index, self.variant_name)
}
}
18 changes: 15 additions & 3 deletions serde_derive/src/de.rs
Expand Up @@ -1510,6 +1510,14 @@ fn deserialize_adjacently_tagged_enum(
}
};

let variant_seed = quote! {
_serde::__private::de::AdjacentlyTaggedEnumVariantSeed::<__Field> {
tag: #tag,
variants: &VARIANTS,
fields_enum: _serde::__private::PhantomData
}
};

let mut missing_content = quote! {
_serde::__private::Err(<__A::Error as _serde::de::Error>::missing_field(#content))
};
Expand Down Expand Up @@ -1557,6 +1565,10 @@ fn deserialize_adjacently_tagged_enum(
_serde::de::MapAccess::next_key_seed(&mut __map, #tag_or_content)?
};

let variant_from_map = quote! {
_serde::de::MapAccess::next_value_seed(&mut __map, #variant_seed)?
};

// When allowing unknown fields, we want to transparently step through keys
// we don't care about until we find `tag`, `content`, or run out of keys.
let next_relevant_key = if deny_unknown_fields {
Expand Down Expand Up @@ -1602,11 +1614,11 @@ fn deserialize_adjacently_tagged_enum(

let finish_content_then_tag = if variant_arms.is_empty() {
quote! {
match _serde::de::MapAccess::next_value::<__Field>(&mut __map)? {}
match #variant_from_map {}
}
} else {
quote! {
let __ret = match _serde::de::MapAccess::next_value(&mut __map)? {
let __ret = match #variant_from_map {
// Deserialize the buffered content now that we know the variant.
#(#variant_arms)*
}?;
Expand Down Expand Up @@ -1662,7 +1674,7 @@ fn deserialize_adjacently_tagged_enum(
// First key is the tag.
_serde::__private::Some(_serde::__private::de::TagOrContentField::Tag) => {
// Parse the tag.
let __field = _serde::de::MapAccess::next_value(&mut __map)?;
let __field = #variant_from_map;
// Visit the second key.
match #next_relevant_key {
// Second key is a duplicate of the tag.
Expand Down
17 changes: 13 additions & 4 deletions serde_derive/src/ser.rs
Expand Up @@ -478,7 +478,14 @@ fn serialize_variant(
serialize_internally_tagged_variant(params, variant, cattrs, tag)
}
(attr::TagType::Adjacent { tag, content }, false) => {
serialize_adjacently_tagged_variant(params, variant, cattrs, tag, content)
serialize_adjacently_tagged_variant(
params,
variant,
cattrs,
variant_index,
tag,
content,
)
}
(attr::TagType::None, _) | (_, true) => {
serialize_untagged_variant(params, variant, cattrs)
Expand Down Expand Up @@ -634,12 +641,14 @@ fn serialize_adjacently_tagged_variant(
params: &Parameters,
variant: &Variant,
cattrs: &attr::Container,
variant_index: u32,
tag: &str,
content: &str,
) -> Fragment {
let this_type = &params.this_type;
let type_name = cattrs.name().serialize_name();
let variant_name = variant.attrs.name().serialize_name();
let variant_serializer = quote!(&_serde::__private::ser::AdjacentlyTaggedEnumVariantSerializer::new(#tag, #variant_index, #variant_name));

let inner = Stmts(if let Some(path) = variant.attrs.serialize_with() {
let ser = wrap_serialize_variant_with(params, path, variant);
Expand All @@ -653,7 +662,7 @@ fn serialize_adjacently_tagged_variant(
let mut __struct = _serde::Serializer::serialize_struct(
__serializer, #type_name, 1)?;
_serde::ser::SerializeStruct::serialize_field(
&mut __struct, #tag, #variant_name)?;
&mut __struct, #tag, #variant_serializer)?;
_serde::ser::SerializeStruct::end(__struct)
};
}
Expand All @@ -670,7 +679,7 @@ fn serialize_adjacently_tagged_variant(
let mut __struct = _serde::Serializer::serialize_struct(
__serializer, #type_name, 2)?;
_serde::ser::SerializeStruct::serialize_field(
&mut __struct, #tag, #variant_name)?;
&mut __struct, #tag, #variant_serializer)?;
#func(
&mut __struct, #content, #field_expr)?;
_serde::ser::SerializeStruct::end(__struct)
Expand Down Expand Up @@ -735,7 +744,7 @@ fn serialize_adjacently_tagged_variant(
let mut __struct = _serde::Serializer::serialize_struct(
__serializer, #type_name, 2)?;
_serde::ser::SerializeStruct::serialize_field(
&mut __struct, #tag, #variant_name)?;
&mut __struct, #tag, #variant_serializer)?;
_serde::ser::SerializeStruct::serialize_field(
&mut __struct, #content, &__AdjacentlyTagged {
data: (#(#fields_ident,)*),
Expand Down
27 changes: 21 additions & 6 deletions test_suite/tests/test_annotations.rs
Expand Up @@ -2108,7 +2108,10 @@ fn test_adjacently_tagged_enum_bytes() {
len: 2,
},
Token::Str("t"),
Token::Str("A"),
Token::UnitVariant {
name: "t",
variant: "A",
},
Token::Str("c"),
Token::Struct { name: "A", len: 1 },
Token::Str("a"),
Expand All @@ -2126,7 +2129,10 @@ fn test_adjacently_tagged_enum_bytes() {
len: 2,
},
Token::Bytes(b"t"),
Token::Str("A"),
Token::UnitVariant {
name: "t",
variant: "A",
},
Token::Bytes(b"c"),
Token::Struct { name: "A", len: 1 },
Token::Str("a"),
Expand Down Expand Up @@ -2167,7 +2173,10 @@ fn test_adjacently_tagged_enum_containing_flatten() {
len: 2,
},
Token::Str("t"),
Token::Str("A"),
Token::UnitVariant {
name: "t",
variant: "A",
},
Token::Str("c"),
Token::Map { len: None },
Token::Str("a"),
Expand Down Expand Up @@ -2757,7 +2766,7 @@ fn test_expecting_message_adjacently_tagged_enum() {
// Check that #[serde(expecting = "...")] doesn't affect variant identifier error message
assert_de_tokens_error::<Enum>(
&[Token::Map { len: None }, Token::Str("tag"), Token::Unit],
r#"invalid type: unit value, expected variant identifier"#,
r#"invalid type: unit value, expected enum tag"#,
);
}

Expand Down Expand Up @@ -2992,7 +3001,10 @@ mod flatten {
Token::Str("outer"),
Token::U32(42),
Token::Str("tag"),
Token::Str("Struct"),
Token::UnitVariant {
name: "tag",
variant: "Struct",
},
Token::Str("content"),
Token::Struct {
len: 2,
Expand Down Expand Up @@ -3020,7 +3032,10 @@ mod flatten {
Token::Str("outer"),
Token::U32(42),
Token::Str("tag"),
Token::Str("Newtype"),
Token::UnitVariant {
name: "tag",
variant: "Newtype",
},
Token::Str("content"),
Token::Struct {
len: 1,
Expand Down