Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowed Enum variants to be individually marked as untagged #2403

Merged
merged 8 commits into from Jun 8, 2023
45 changes: 41 additions & 4 deletions serde_derive/src/de.rs
Expand Up @@ -1165,6 +1165,32 @@ fn deserialize_enum(
params: &Parameters,
variants: &[Variant],
cattrs: &attr::Container,
) -> Fragment {
// The variants have already been checked (in ast.rs) that all untagged variants appear at the end
match variants
.iter()
.enumerate()
.find(|(_, var)| var.attrs.untagged())
dewert99 marked this conversation as resolved.
Show resolved Hide resolved
{
Some((variant_idx, _)) => {
let (tagged, untagged) = variants.split_at(variant_idx);
let tagged_frag = Expr(deserialize_homogeneous_enum(params, tagged, cattrs));
let tagged_frag = |deserializer| {
Some(Expr(quote_block! {
let __deserializer = #deserializer;
#tagged_frag
}))
};
deserialize_untagged_enum_after(params, untagged, cattrs, tagged_frag)
}
None => deserialize_homogeneous_enum(params, variants, cattrs),
}
}

fn deserialize_homogeneous_enum(
params: &Parameters,
variants: &[Variant],
cattrs: &attr::Container,
) -> Fragment {
match cattrs.tag() {
attr::TagType::External => deserialize_externally_tagged_enum(params, variants, cattrs),
Expand Down Expand Up @@ -1661,6 +1687,17 @@ fn deserialize_untagged_enum(
variants: &[Variant],
cattrs: &attr::Container,
) -> Fragment {
deserialize_untagged_enum_after(params, variants, cattrs, |_| None)
}

fn deserialize_untagged_enum_after(
params: &Parameters,
variants: &[Variant],
cattrs: &attr::Container,
first_attempt: impl FnOnce(TokenStream) -> Option<Expr>,
) -> Fragment {
let deserializer =
quote!(_serde::__private::de::ContentRefDeserializer::<__D::Error>::new(&__content));
let attempts = variants
.iter()
.filter(|variant| !variant.attrs.skip_deserializing())
Expand All @@ -1669,12 +1706,12 @@ fn deserialize_untagged_enum(
params,
variant,
cattrs,
quote!(
_serde::__private::de::ContentRefDeserializer::<__D::Error>::new(&__content)
),
deserializer.clone(),
))
});

let attempts = first_attempt(deserializer.clone())
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
.into_iter()
.chain(attempts);
// TODO this message could be better by saving the errors from the failed
// attempts. The heuristic used by TOML was to count the number of fields
// processed before an error, and use the error that happened after the
Expand Down
9 changes: 7 additions & 2 deletions serde_derive/src/internals/ast.rs
Expand Up @@ -140,6 +140,7 @@ fn enum_from_ast<'a>(
variants: &'a Punctuated<syn::Variant, Token![,]>,
container_default: &attr::Default,
) -> Vec<Variant<'a>> {
let mut seen_untagged = false;
variants
.iter()
.map(|variant| {
Expand All @@ -153,8 +154,12 @@ fn enum_from_ast<'a>(
fields,
original: variant,
}
})
.collect()
}).inspect(|variant| {
if !variant.attrs.untagged() && seen_untagged {
cx.error_spanned_by(&variant.ident, "all variants with the #[serde(untagged)] attribute must be placed at the end of the enum")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why only allowing #[serde(untagged)] to be at the end?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This keeps the implementation more consistent with the deserialisation strategy of going though each variant in order and choosing the first variant that matches.

}
seen_untagged = variant.attrs.untagged()
}).collect()
}

fn struct_from_ast<'a>(
Expand Down
9 changes: 9 additions & 0 deletions serde_derive/src/internals/attr.rs
Expand Up @@ -734,6 +734,7 @@ pub struct Variant {
serialize_with: Option<syn::ExprPath>,
deserialize_with: Option<syn::ExprPath>,
borrow: Option<BorrowAttribute>,
untagged: bool,
}

struct BorrowAttribute {
Expand All @@ -756,6 +757,7 @@ impl Variant {
let mut serialize_with = Attr::none(cx, SERIALIZE_WITH);
let mut deserialize_with = Attr::none(cx, DESERIALIZE_WITH);
let mut borrow = Attr::none(cx, BORROW);
let mut untagged = BoolAttr::none(cx, UNTAGGED);

for attr in &variant.attrs {
if attr.path() != SERDE {
Expand Down Expand Up @@ -867,6 +869,8 @@ impl Variant {
cx.error_spanned_by(variant, msg);
}
}
} else if meta.path == UNTAGGED {
untagged.set_true(&meta.path);
} else {
let path = meta.path.to_token_stream().to_string().replace(' ', "");
return Err(
Expand All @@ -893,6 +897,7 @@ impl Variant {
serialize_with: serialize_with.get(),
deserialize_with: deserialize_with.get(),
borrow: borrow.get(),
untagged: untagged.get(),
}
}

Expand Down Expand Up @@ -944,6 +949,10 @@ impl Variant {
pub fn deserialize_with(&self) -> Option<&syn::ExprPath> {
self.deserialize_with.as_ref()
}

pub fn untagged(&self) -> bool {
self.untagged
}
}

/// Represents field attribute information
Expand Down
10 changes: 5 additions & 5 deletions serde_derive/src/ser.rs
Expand Up @@ -477,17 +477,17 @@ fn serialize_variant(
}
};

let body = Match(match cattrs.tag() {
attr::TagType::External => {
let body = Match(match (cattrs.tag(), variant.attrs.untagged()) {
(attr::TagType::External, false) => {
serialize_externally_tagged_variant(params, variant, variant_index, cattrs)
}
attr::TagType::Internal { tag } => {
(attr::TagType::Internal { tag }, false) => {
serialize_internally_tagged_variant(params, variant, cattrs, tag)
}
attr::TagType::Adjacent { tag, content } => {
(attr::TagType::Adjacent { tag, content }, false) => {
serialize_adjacently_tagged_variant(params, variant, cattrs, tag, content)
}
attr::TagType::None => serialize_untagged_variant(params, variant, cattrs),
(attr::TagType::None, _) | (_, true) => serialize_untagged_variant(params, variant, cattrs),
});

quote! {
Expand Down
153 changes: 153 additions & 0 deletions test_suite/tests/test_annotations.rs
Expand Up @@ -2395,6 +2395,159 @@ fn test_untagged_enum_containing_flatten() {
);
}

#[test]
fn test_partially_untagged_enum() {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
enum Exp {
Lambda(u32, Box<Exp>),
#[serde(untagged)]
App(Box<Exp>, Box<Exp>),
#[serde(untagged)]
Var(u32),
}
use Exp::*;
oli-obk marked this conversation as resolved.
Show resolved Hide resolved

let data = Lambda(0, Box::new(App(Box::new(Var(0)), Box::new(Var(0)))));
assert_tokens(
&data,
&[
Token::TupleVariant {
name: "Exp",
variant: "Lambda",
len: 2,
},
Token::U32(0),
Token::Tuple { len: 2 },
Token::U32(0),
Token::U32(0),
Token::TupleEnd,
Token::TupleVariantEnd,
],
);
}

#[test]
fn test_partially_untagged_enum_generic() {
trait Trait<T> {
type Assoc;
type Assoc2;
}

#[derive(Serialize, Deserialize, PartialEq, Debug)]
enum E<A, B, C> where A: Trait<C, Assoc2=B> {
A(A::Assoc),
#[serde(untagged)]
B(A::Assoc2),
}

impl<T> Trait<T> for () {
type Assoc = T;
type Assoc2 = bool;
}

type MyE = E<(), bool, u32>;
use E::*;

assert_tokens::<MyE>(&B(true), &[Token::Bool(true)]);

assert_tokens::<MyE>(
&A(5),
&[
Token::NewtypeVariant {
name: "E",
variant: "A",
},
Token::U32(5),
],
);
}

#[test]
fn test_partially_untagged_enum_desugared() {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
enum Test {
A(u32, u32),
B(u32),
#[serde(untagged)]
C(u32),
#[serde(untagged)]
D(u32, u32),
}
use Test::*;

mod desugared {
use super::*;
#[derive(Serialize, Deserialize, PartialEq, Debug)]
pub(super) enum Test {
A(u32, u32),
B(u32),
}
}
use desugared::Test as TestTagged;

#[derive(Serialize, Deserialize, PartialEq, Debug)]
#[serde(untagged)]
enum TestUntagged {
Tagged(TestTagged),
C(u32),
D(u32, u32),
}

impl From<Test> for TestUntagged {
fn from(test: Test) -> Self {
match test {
A(x, y) => TestUntagged::Tagged(TestTagged::A(x, y)),
B(x) => TestUntagged::Tagged(TestTagged::B(x)),
C(x) => TestUntagged::C(x),
D(x, y) => TestUntagged::D(x, y),
}
}
}

fn assert_tokens_desugared(value: Test, tokens: &[Token]) {
assert_tokens(&value, tokens);
let desugared: TestUntagged = value.into();
assert_tokens(&desugared, tokens);
}

assert_tokens_desugared(
A(0, 1),
&[
Token::TupleVariant {
name: "Test",
variant: "A",
len: 2,
},
Token::U32(0),
Token::U32(1),
Token::TupleVariantEnd,
],
);

assert_tokens_desugared(
B(1),
&[
Token::NewtypeVariant {
name: "Test",
variant: "B",
},
Token::U32(1),
],
);

assert_tokens_desugared(C(2), &[Token::U32(2)]);

assert_tokens_desugared(
D(3, 5),
&[
Token::Tuple { len: 2 },
Token::U32(3),
Token::U32(5),
Token::TupleEnd,
],
);
}

#[test]
fn test_flatten_untagged_enum() {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
Expand Down
@@ -0,0 +1,10 @@
use serde_derive::Serialize;

#[derive(Serialize)]
enum E {
#[serde(untagged)]
A(u8),
B(String),
}

fn main() {}
@@ -0,0 +1,5 @@
error: all variants with the #[serde(untagged)] attribute must be placed at the end of the enum
--> tests/ui/enum-representation/partially_tagged_wrong_order.rs:7:5
|
7 | B(String),
| ^