Skip to content

Commit

Permalink
Merge pull request #2505 from Baptistemontan/rework_adjacently_tagged…
Browse files Browse the repository at this point in the history
…_enum

Revisit of the representation of adjacently tagged enums tag
  • Loading branch information
dtolnay committed Aug 2, 2023
2 parents 83b1a3d + 957ef20 commit 43035f6
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 34 deletions.
64 changes: 62 additions & 2 deletions serde/src/private/de.rs
@@ -1,10 +1,13 @@
use crate::lib::*;

use crate::de::value::{BorrowedBytesDeserializer, BytesDeserializer};
use crate::de::{Deserialize, Deserializer, Error, IntoDeserializer, Visitor};
use crate::de::{
Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error, IntoDeserializer, VariantAccess,
Visitor,
};

#[cfg(any(feature = "std", feature = "alloc"))]
use crate::de::{DeserializeSeed, MapAccess, Unexpected};
use crate::de::{MapAccess, Unexpected};

#[cfg(any(feature = "std", feature = "alloc"))]
pub use self::content::{
Expand Down Expand Up @@ -2836,3 +2839,60 @@ fn flat_map_take_entry<'de>(
None
}
}

pub struct AdjacentlyTaggedEnumVariantSeed<F> {
pub tag: &'static str,
pub variants: &'static [&'static str],
pub fields_enum: PhantomData<F>,
}

pub struct AdjacentlyTaggedEnumVariantVisitor<F> {
tag: &'static str,
fields_enum: PhantomData<F>,
}

impl<'de, F> Visitor<'de> for AdjacentlyTaggedEnumVariantVisitor<F>
where
F: Deserialize<'de>,
{
type Value = F;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "enum {}", self.tag)
}

fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
where
A: EnumAccess<'de>,
{
let (variant, variant_access) = match data.variant() {
Ok(values) => values,
Err(err) => return Err(err),
};
if let Err(err) = variant_access.unit_variant() {
return Err(err);
}
Ok(variant)
}
}

impl<'de, F> DeserializeSeed<'de> for AdjacentlyTaggedEnumVariantSeed<F>
where
F: Deserialize<'de>,
{
type Value = F;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_enum(
self.tag,
self.variants,
AdjacentlyTaggedEnumVariantVisitor {
tag: self.tag,
fields_enum: PhantomData,
},
)
}
}
25 changes: 25 additions & 0 deletions serde/src/private/ser.rs
Expand Up @@ -1355,3 +1355,28 @@ where
Ok(())
}
}

pub struct AdjacentlyTaggedEnumVariantSerializer {
tag: &'static str,
variant_index: u32,
variant_name: &'static str,
}

impl AdjacentlyTaggedEnumVariantSerializer {
pub fn new(tag: &'static str, variant_index: u32, variant_name: &'static str) -> Self {
AdjacentlyTaggedEnumVariantSerializer {
tag,
variant_index,
variant_name,
}
}
}

impl Serialize for AdjacentlyTaggedEnumVariantSerializer {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_unit_variant(self.tag, self.variant_index, self.variant_name)
}
}
18 changes: 15 additions & 3 deletions serde_derive/src/de.rs
Expand Up @@ -1480,6 +1480,14 @@ fn deserialize_adjacently_tagged_enum(
}
};

let variant_seed = quote! {
_serde::__private::de::AdjacentlyTaggedEnumVariantSeed::<__Field> {
tag: #tag,
variants: &VARIANTS,
fields_enum: _serde::__private::PhantomData
}
};

let mut missing_content = quote! {
_serde::__private::Err(<__A::Error as _serde::de::Error>::missing_field(#content))
};
Expand Down Expand Up @@ -1527,6 +1535,10 @@ fn deserialize_adjacently_tagged_enum(
_serde::de::MapAccess::next_key_seed(&mut __map, #tag_or_content)?
};

let variant_from_map = quote! {
_serde::de::MapAccess::next_value_seed(&mut __map, #variant_seed)?
};

// When allowing unknown fields, we want to transparently step through keys
// we don't care about until we find `tag`, `content`, or run out of keys.
let next_relevant_key = if deny_unknown_fields {
Expand Down Expand Up @@ -1572,11 +1584,11 @@ fn deserialize_adjacently_tagged_enum(

let finish_content_then_tag = if variant_arms.is_empty() {
quote! {
match _serde::de::MapAccess::next_value::<__Field>(&mut __map)? {}
match #variant_from_map {}
}
} else {
quote! {
let __ret = match _serde::de::MapAccess::next_value(&mut __map)? {
let __ret = match #variant_from_map {
// Deserialize the buffered content now that we know the variant.
#(#variant_arms)*
}?;
Expand Down Expand Up @@ -1632,7 +1644,7 @@ fn deserialize_adjacently_tagged_enum(
// First key is the tag.
_serde::__private::Some(_serde::__private::de::TagOrContentField::Tag) => {
// Parse the tag.
let __field = _serde::de::MapAccess::next_value(&mut __map)?;
let __field = #variant_from_map;
// Visit the second key.
match #next_relevant_key {
// Second key is a duplicate of the tag.
Expand Down
17 changes: 13 additions & 4 deletions serde_derive/src/ser.rs
Expand Up @@ -478,7 +478,14 @@ fn serialize_variant(
serialize_internally_tagged_variant(params, variant, cattrs, tag)
}
(attr::TagType::Adjacent { tag, content }, false) => {
serialize_adjacently_tagged_variant(params, variant, cattrs, tag, content)
serialize_adjacently_tagged_variant(
params,
variant,
cattrs,
variant_index,
tag,
content,
)
}
(attr::TagType::None, _) | (_, true) => {
serialize_untagged_variant(params, variant, cattrs)
Expand Down Expand Up @@ -634,12 +641,14 @@ fn serialize_adjacently_tagged_variant(
params: &Parameters,
variant: &Variant,
cattrs: &attr::Container,
variant_index: u32,
tag: &str,
content: &str,
) -> Fragment {
let this_type = &params.this_type;
let type_name = cattrs.name().serialize_name();
let variant_name = variant.attrs.name().serialize_name();
let variant_serializer = quote!(&_serde::__private::ser::AdjacentlyTaggedEnumVariantSerializer::new(#tag, #variant_index, #variant_name));

let inner = Stmts(if let Some(path) = variant.attrs.serialize_with() {
let ser = wrap_serialize_variant_with(params, path, variant);
Expand All @@ -653,7 +662,7 @@ fn serialize_adjacently_tagged_variant(
let mut __struct = _serde::Serializer::serialize_struct(
__serializer, #type_name, 1)?;
_serde::ser::SerializeStruct::serialize_field(
&mut __struct, #tag, #variant_name)?;
&mut __struct, #tag, #variant_serializer)?;
_serde::ser::SerializeStruct::end(__struct)
};
}
Expand All @@ -670,7 +679,7 @@ fn serialize_adjacently_tagged_variant(
let mut __struct = _serde::Serializer::serialize_struct(
__serializer, #type_name, 2)?;
_serde::ser::SerializeStruct::serialize_field(
&mut __struct, #tag, #variant_name)?;
&mut __struct, #tag, #variant_serializer)?;
#func(
&mut __struct, #content, #field_expr)?;
_serde::ser::SerializeStruct::end(__struct)
Expand Down Expand Up @@ -735,7 +744,7 @@ fn serialize_adjacently_tagged_variant(
let mut __struct = _serde::Serializer::serialize_struct(
__serializer, #type_name, 2)?;
_serde::ser::SerializeStruct::serialize_field(
&mut __struct, #tag, #variant_name)?;
&mut __struct, #tag, #variant_serializer)?;
_serde::ser::SerializeStruct::serialize_field(
&mut __struct, #content, &__AdjacentlyTagged {
data: (#(#fields_ident,)*),
Expand Down
27 changes: 21 additions & 6 deletions test_suite/tests/test_annotations.rs
Expand Up @@ -2108,7 +2108,10 @@ fn test_adjacently_tagged_enum_bytes() {
len: 2,
},
Token::Str("t"),
Token::Str("A"),
Token::UnitVariant {
name: "t",
variant: "A",
},
Token::Str("c"),
Token::Struct { name: "A", len: 1 },
Token::Str("a"),
Expand All @@ -2126,7 +2129,10 @@ fn test_adjacently_tagged_enum_bytes() {
len: 2,
},
Token::Bytes(b"t"),
Token::Str("A"),
Token::UnitVariant {
name: "t",
variant: "A",
},
Token::Bytes(b"c"),
Token::Struct { name: "A", len: 1 },
Token::Str("a"),
Expand Down Expand Up @@ -2167,7 +2173,10 @@ fn test_adjacently_tagged_enum_containing_flatten() {
len: 2,
},
Token::Str("t"),
Token::Str("A"),
Token::UnitVariant {
name: "t",
variant: "A",
},
Token::Str("c"),
Token::Map { len: None },
Token::Str("a"),
Expand Down Expand Up @@ -2757,7 +2766,7 @@ fn test_expecting_message_adjacently_tagged_enum() {
// Check that #[serde(expecting = "...")] doesn't affect variant identifier error message
assert_de_tokens_error::<Enum>(
&[Token::Map { len: None }, Token::Str("tag"), Token::Unit],
r#"invalid type: unit value, expected variant identifier"#,
r#"invalid type: unit value, expected enum tag"#,
);
}

Expand Down Expand Up @@ -2992,7 +3001,10 @@ mod flatten {
Token::Str("outer"),
Token::U32(42),
Token::Str("tag"),
Token::Str("Struct"),
Token::UnitVariant {
name: "tag",
variant: "Struct",
},
Token::Str("content"),
Token::Struct {
len: 2,
Expand Down Expand Up @@ -3020,7 +3032,10 @@ mod flatten {
Token::Str("outer"),
Token::U32(42),
Token::Str("tag"),
Token::Str("Newtype"),
Token::UnitVariant {
name: "tag",
variant: "Newtype",
},
Token::Str("content"),
Token::Struct {
len: 1,
Expand Down

0 comments on commit 43035f6

Please sign in to comment.