diff --git a/serde/src/private/de.rs b/serde/src/private/de.rs index 48ccb5bce..bbe04599b 100644 --- a/serde/src/private/de.rs +++ b/serde/src/private/de.rs @@ -1,10 +1,13 @@ use crate::lib::*; use crate::de::value::{BorrowedBytesDeserializer, BytesDeserializer}; -use crate::de::{Deserialize, Deserializer, Error, IntoDeserializer, Visitor}; +use crate::de::{ + Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error, IntoDeserializer, VariantAccess, + Visitor, +}; #[cfg(any(feature = "std", feature = "alloc"))] -use crate::de::{DeserializeSeed, MapAccess, Unexpected}; +use crate::de::{MapAccess, Unexpected}; #[cfg(any(feature = "std", feature = "alloc"))] pub use self::content::{ @@ -2836,3 +2839,60 @@ fn flat_map_take_entry<'de>( None } } + +pub struct AdjacentlyTaggedEnumVariantSeed { + pub tag: &'static str, + pub variants: &'static [&'static str], + pub fields_enum: PhantomData, +} + +pub struct AdjacentlyTaggedEnumVariantVisitor { + tag: &'static str, + fields_enum: PhantomData, +} + +impl<'de, F> Visitor<'de> for AdjacentlyTaggedEnumVariantVisitor +where + F: Deserialize<'de>, +{ + type Value = F; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "enum {}", self.tag) + } + + fn visit_enum(self, data: A) -> Result + where + A: EnumAccess<'de>, + { + let (variant, variant_access) = match data.variant() { + Ok(values) => values, + Err(err) => return Err(err), + }; + if let Err(err) = variant_access.unit_variant() { + return Err(err); + } + Ok(variant) + } +} + +impl<'de, F> DeserializeSeed<'de> for AdjacentlyTaggedEnumVariantSeed +where + F: Deserialize<'de>, +{ + type Value = F; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_enum( + self.tag, + self.variants, + AdjacentlyTaggedEnumVariantVisitor { + tag: self.tag, + fields_enum: PhantomData, + }, + ) + } +} diff --git a/serde/src/private/ser.rs b/serde/src/private/ser.rs index 016f647ee..4845fef02 100644 --- a/serde/src/private/ser.rs +++ b/serde/src/private/ser.rs @@ -1355,3 +1355,28 @@ where Ok(()) } } + +pub struct AdjacentlyTaggedEnumVariantSerializer { + tag: &'static str, + variant_index: u32, + variant_name: &'static str, +} + +impl AdjacentlyTaggedEnumVariantSerializer { + pub fn new(tag: &'static str, variant_index: u32, variant_name: &'static str) -> Self { + AdjacentlyTaggedEnumVariantSerializer { + tag, + variant_index, + variant_name, + } + } +} + +impl Serialize for AdjacentlyTaggedEnumVariantSerializer { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_unit_variant(self.tag, self.variant_index, self.variant_name) + } +} diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 1434680b8..cd9be5ea7 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1510,6 +1510,14 @@ fn deserialize_adjacently_tagged_enum( } }; + let variant_seed = quote! { + _serde::__private::de::AdjacentlyTaggedEnumVariantSeed::<__Field> { + tag: #tag, + variants: &VARIANTS, + fields_enum: _serde::__private::PhantomData + } + }; + let mut missing_content = quote! { _serde::__private::Err(<__A::Error as _serde::de::Error>::missing_field(#content)) }; @@ -1557,6 +1565,10 @@ fn deserialize_adjacently_tagged_enum( _serde::de::MapAccess::next_key_seed(&mut __map, #tag_or_content)? }; + let variant_from_map = quote! { + _serde::de::MapAccess::next_value_seed(&mut __map, #variant_seed)? + }; + // When allowing unknown fields, we want to transparently step through keys // we don't care about until we find `tag`, `content`, or run out of keys. let next_relevant_key = if deny_unknown_fields { @@ -1602,11 +1614,11 @@ fn deserialize_adjacently_tagged_enum( let finish_content_then_tag = if variant_arms.is_empty() { quote! { - match _serde::de::MapAccess::next_value::<__Field>(&mut __map)? {} + match #variant_from_map {} } } else { quote! { - let __ret = match _serde::de::MapAccess::next_value(&mut __map)? { + let __ret = match #variant_from_map { // Deserialize the buffered content now that we know the variant. #(#variant_arms)* }?; @@ -1662,7 +1674,7 @@ fn deserialize_adjacently_tagged_enum( // First key is the tag. _serde::__private::Some(_serde::__private::de::TagOrContentField::Tag) => { // Parse the tag. - let __field = _serde::de::MapAccess::next_value(&mut __map)?; + let __field = #variant_from_map; // Visit the second key. match #next_relevant_key { // Second key is a duplicate of the tag. diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index 1792d3fec..32ba4529d 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -478,7 +478,14 @@ fn serialize_variant( serialize_internally_tagged_variant(params, variant, cattrs, tag) } (attr::TagType::Adjacent { tag, content }, false) => { - serialize_adjacently_tagged_variant(params, variant, cattrs, tag, content) + serialize_adjacently_tagged_variant( + params, + variant, + cattrs, + variant_index, + tag, + content, + ) } (attr::TagType::None, _) | (_, true) => { serialize_untagged_variant(params, variant, cattrs) @@ -634,12 +641,14 @@ fn serialize_adjacently_tagged_variant( params: &Parameters, variant: &Variant, cattrs: &attr::Container, + variant_index: u32, tag: &str, content: &str, ) -> Fragment { let this_type = ¶ms.this_type; let type_name = cattrs.name().serialize_name(); let variant_name = variant.attrs.name().serialize_name(); + let variant_serializer = quote!(&_serde::__private::ser::AdjacentlyTaggedEnumVariantSerializer::new(#tag, #variant_index, #variant_name)); let inner = Stmts(if let Some(path) = variant.attrs.serialize_with() { let ser = wrap_serialize_variant_with(params, path, variant); @@ -653,7 +662,7 @@ fn serialize_adjacently_tagged_variant( let mut __struct = _serde::Serializer::serialize_struct( __serializer, #type_name, 1)?; _serde::ser::SerializeStruct::serialize_field( - &mut __struct, #tag, #variant_name)?; + &mut __struct, #tag, #variant_serializer)?; _serde::ser::SerializeStruct::end(__struct) }; } @@ -670,7 +679,7 @@ fn serialize_adjacently_tagged_variant( let mut __struct = _serde::Serializer::serialize_struct( __serializer, #type_name, 2)?; _serde::ser::SerializeStruct::serialize_field( - &mut __struct, #tag, #variant_name)?; + &mut __struct, #tag, #variant_serializer)?; #func( &mut __struct, #content, #field_expr)?; _serde::ser::SerializeStruct::end(__struct) @@ -735,7 +744,7 @@ fn serialize_adjacently_tagged_variant( let mut __struct = _serde::Serializer::serialize_struct( __serializer, #type_name, 2)?; _serde::ser::SerializeStruct::serialize_field( - &mut __struct, #tag, #variant_name)?; + &mut __struct, #tag, #variant_serializer)?; _serde::ser::SerializeStruct::serialize_field( &mut __struct, #content, &__AdjacentlyTagged { data: (#(#fields_ident,)*), diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 9d301c4cf..10f1d05e9 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -2108,7 +2108,10 @@ fn test_adjacently_tagged_enum_bytes() { len: 2, }, Token::Str("t"), - Token::Str("A"), + Token::UnitVariant { + name: "t", + variant: "A", + }, Token::Str("c"), Token::Struct { name: "A", len: 1 }, Token::Str("a"), @@ -2126,7 +2129,10 @@ fn test_adjacently_tagged_enum_bytes() { len: 2, }, Token::Bytes(b"t"), - Token::Str("A"), + Token::UnitVariant { + name: "t", + variant: "A", + }, Token::Bytes(b"c"), Token::Struct { name: "A", len: 1 }, Token::Str("a"), @@ -2167,7 +2173,10 @@ fn test_adjacently_tagged_enum_containing_flatten() { len: 2, }, Token::Str("t"), - Token::Str("A"), + Token::UnitVariant { + name: "t", + variant: "A", + }, Token::Str("c"), Token::Map { len: None }, Token::Str("a"), @@ -2757,7 +2766,7 @@ fn test_expecting_message_adjacently_tagged_enum() { // Check that #[serde(expecting = "...")] doesn't affect variant identifier error message assert_de_tokens_error::( &[Token::Map { len: None }, Token::Str("tag"), Token::Unit], - r#"invalid type: unit value, expected variant identifier"#, + r#"invalid type: unit value, expected enum tag"#, ); } @@ -2992,7 +3001,10 @@ mod flatten { Token::Str("outer"), Token::U32(42), Token::Str("tag"), - Token::Str("Struct"), + Token::UnitVariant { + name: "tag", + variant: "Struct", + }, Token::Str("content"), Token::Struct { len: 2, @@ -3020,7 +3032,10 @@ mod flatten { Token::Str("outer"), Token::U32(42), Token::Str("tag"), - Token::Str("Newtype"), + Token::UnitVariant { + name: "tag", + variant: "Newtype", + }, Token::Str("content"), Token::Struct { len: 1, diff --git a/test_suite/tests/test_macros.rs b/test_suite/tests/test_macros.rs index 99a78ab44..cbd0d9b75 100644 --- a/test_suite/tests/test_macros.rs +++ b/test_suite/tests/test_macros.rs @@ -472,7 +472,10 @@ fn test_adjacently_tagged_newtype_struct() { }, Token::U32(5), Token::Str("t"), - Token::Str("Newtype"), + Token::UnitVariant { + name: "t", + variant: "Newtype", + }, Token::StructEnd, ], ); @@ -1066,7 +1069,10 @@ fn test_adjacently_tagged_enum() { len: 1, }, Token::Str("t"), - Token::Str("Unit"), + Token::UnitVariant { + name: "t", + variant: "Unit", + }, Token::StructEnd, ], ); @@ -1080,7 +1086,10 @@ fn test_adjacently_tagged_enum() { len: 2, }, Token::Str("t"), - Token::Str("Unit"), + Token::UnitVariant { + name: "t", + variant: "Unit", + }, Token::StructEnd, ], ); @@ -1094,7 +1103,10 @@ fn test_adjacently_tagged_enum() { len: 2, }, Token::Str("t"), - Token::Str("Unit"), + Token::UnitVariant { + name: "t", + variant: "Unit", + }, Token::Str("c"), Token::Unit, Token::StructEnd, @@ -1112,7 +1124,10 @@ fn test_adjacently_tagged_enum() { Token::Str("c"), Token::Unit, Token::Str("t"), - Token::Str("Unit"), + Token::UnitVariant { + name: "t", + variant: "Unit", + }, Token::StructEnd, ], ); @@ -1128,7 +1143,10 @@ fn test_adjacently_tagged_enum() { Token::Str("f"), Token::Unit, Token::Str("t"), - Token::Str("Unit"), + Token::UnitVariant { + name: "t", + variant: "Unit", + }, Token::Str("g"), Token::Unit, Token::Str("c"), @@ -1148,7 +1166,10 @@ fn test_adjacently_tagged_enum() { len: 2, }, Token::Str("t"), - Token::Str("Newtype"), + Token::UnitVariant { + name: "t", + variant: "Newtype", + }, Token::Str("c"), Token::U8(1), Token::StructEnd, @@ -1166,7 +1187,10 @@ fn test_adjacently_tagged_enum() { Token::Str("c"), Token::U8(1), Token::Str("t"), - Token::Str("Newtype"), + Token::UnitVariant { + name: "t", + variant: "Newtype", + }, Token::StructEnd, ], ); @@ -1180,7 +1204,10 @@ fn test_adjacently_tagged_enum() { len: 1, }, Token::Str("t"), - Token::Str("Newtype"), + Token::UnitVariant { + name: "t", + variant: "Newtype", + }, Token::StructEnd, ], ); @@ -1194,7 +1221,10 @@ fn test_adjacently_tagged_enum() { len: 2, }, Token::Str("t"), - Token::Str("Tuple"), + Token::UnitVariant { + name: "t", + variant: "Tuple", + }, Token::Str("c"), Token::Tuple { len: 2 }, Token::U8(1), @@ -1218,7 +1248,10 @@ fn test_adjacently_tagged_enum() { Token::U8(1), Token::TupleEnd, Token::Str("t"), - Token::Str("Tuple"), + Token::UnitVariant { + name: "t", + variant: "Tuple", + }, Token::StructEnd, ], ); @@ -1232,7 +1265,10 @@ fn test_adjacently_tagged_enum() { len: 2, }, Token::Str("t"), - Token::Str("Struct"), + Token::UnitVariant { + name: "t", + variant: "Struct", + }, Token::Str("c"), Token::Struct { name: "Struct", @@ -1262,7 +1298,10 @@ fn test_adjacently_tagged_enum() { Token::U8(1), Token::StructEnd, Token::Str("t"), - Token::Str("Struct"), + Token::UnitVariant { + name: "t", + variant: "Struct", + }, Token::StructEnd, ], ); @@ -1278,7 +1317,10 @@ fn test_adjacently_tagged_enum() { Token::U64(1), // content field Token::U8(1), Token::U64(0), // tag field - Token::Str("Newtype"), + Token::UnitVariant { + name: "t", + variant: "Newtype", + }, Token::StructEnd, ], ); @@ -1294,7 +1336,10 @@ fn test_adjacently_tagged_enum() { Token::Bytes(b"c"), Token::U8(1), Token::Bytes(b"t"), - Token::Str("Newtype"), + Token::UnitVariant { + name: "t", + variant: "Newtype", + }, Token::StructEnd, ], ); @@ -1316,7 +1361,10 @@ fn test_adjacently_tagged_enum_deny_unknown_fields() { len: 2, }, Token::Str("t"), - Token::Str("Unit"), + Token::UnitVariant { + name: "t", + variant: "Unit", + }, Token::Str("c"), Token::Unit, Token::StructEnd, @@ -1330,7 +1378,10 @@ fn test_adjacently_tagged_enum_deny_unknown_fields() { len: 2, }, Token::Str("t"), - Token::Str("Unit"), + Token::UnitVariant { + name: "t", + variant: "Unit", + }, Token::Str("c"), Token::Unit, Token::Str("h"), @@ -1369,7 +1420,10 @@ fn test_adjacently_tagged_enum_deny_unknown_fields() { len: 2, }, Token::U64(0), // tag field - Token::Str("Unit"), + Token::UnitVariant { + name: "t", + variant: "Unit", + }, Token::U64(3), ], r#"invalid value: integer `3`, expected "t" or "c""#, @@ -1565,7 +1619,10 @@ fn test_internally_tagged_struct_with_flattened_field() { Token::Str("tag_struct"), Token::Str("Struct"), Token::Str("tag_enum"), - Token::Str("A"), + Token::UnitVariant { + name: "tag_enum", + variant: "A", + }, Token::Str("content"), Token::U64(0), Token::MapEnd,