Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement #[key = ...] #2698

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions serde/src/ser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1871,6 +1871,25 @@ pub trait SerializeStruct {
Ok(())
}

/// Serialize a struct field consisting of a typed key and value.
///
/// This method is similar to `serialize_field`, but it allows the key to be
/// a type that is not a string. This is useful for formats that support
/// non-string keys such as CBOR.
fn serialize_typed_field<K: ?Sized, V: ?Sized>(
&mut self,
key: &K,
value: &V,
) -> Result<(), Self::Error>
where
K: Serialize,
V: Serialize,
{
let _ = key;
let _ = value;
Ok(())
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the default implementation error instead of silently doing nothing?

}

/// Finish serializing a struct.
fn end(self) -> Result<Self::Ok, Self::Error>;
}
Expand Down
117 changes: 103 additions & 14 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,7 @@ fn deserialize_struct(
field.attrs.name().deserialize_name(),
field_i(i),
field.attrs.aliases(),
field.attrs.key(),
)
})
.collect();
Expand Down Expand Up @@ -1010,7 +1011,7 @@ fn deserialize_struct(
} else {
let field_names = field_names_idents
.iter()
.flat_map(|&(_, _, aliases)| aliases);
.flat_map(|&(_, _, aliases, _)| aliases);

Some(quote! {
#[doc(hidden)]
Expand Down Expand Up @@ -1112,6 +1113,7 @@ fn deserialize_struct_in_place(
field.attrs.name().deserialize_name(),
field_i(i),
field.attrs.aliases(),
field.attrs.key(),
)
})
.collect();
Expand All @@ -1127,7 +1129,7 @@ fn deserialize_struct_in_place(
let visit_map = Stmts(deserialize_map_in_place(params, fields, cattrs));
let field_names = field_names_idents
.iter()
.flat_map(|&(_, _, aliases)| aliases);
.flat_map(|&(_, _, aliases, _)| aliases);
let type_name = cattrs.name().deserialize_name();

let in_place_impl_generics = de_impl_generics.in_place();
Expand Down Expand Up @@ -1226,6 +1228,7 @@ fn prepare_enum_variant_enum(
variant.attrs.name().deserialize_name(),
field_i(i),
variant.attrs.aliases(),
None,
)
})
.collect();
Expand All @@ -1238,7 +1241,7 @@ fn prepare_enum_variant_enum(
});

let variants_stmt = {
let variant_names = variant_names_idents.iter().map(|(name, _, _)| name);
let variant_names = variant_names_idents.iter().map(|(name, _, _, _)| name);
quote! {
#[doc(hidden)]
const VARIANTS: &'static [&'static str] = &[ #(#variant_names),* ];
Expand Down Expand Up @@ -1984,14 +1987,14 @@ fn deserialize_untagged_newtype_variant(
}

fn deserialize_generated_identifier(
fields: &[(&str, Ident, &BTreeSet<String>)],
fields: &[(&str, Ident, &BTreeSet<String>, Option<&syn::Lit>)],
cattrs: &attr::Container,
is_variant: bool,
ignore_variant: Option<TokenStream>,
fallthrough: Option<TokenStream>,
) -> Fragment {
let this_value = quote!(__Field);
let field_idents: &Vec<_> = &fields.iter().map(|(_, ident, _)| ident).collect();
let field_idents: &Vec<_> = &fields.iter().map(|(_, ident, _, _)| ident).collect();

let visitor_impl = Stmts(deserialize_identifier(
&this_value,
Expand Down Expand Up @@ -2041,7 +2044,7 @@ fn deserialize_generated_identifier(
/// Generates enum and its `Deserialize` implementation that represents each
/// non-skipped field of the struct
fn deserialize_field_identifier(
fields: &[(&str, Ident, &BTreeSet<String>)],
fields: &[(&str, Ident, &BTreeSet<String>, Option<&syn::Lit>)],
cattrs: &attr::Container,
) -> Stmts {
let (ignore_variant, fallthrough) = if cattrs.has_flatten() {
Expand Down Expand Up @@ -2122,11 +2125,12 @@ fn deserialize_custom_identifier(
variant.attrs.name().deserialize_name(),
variant.ident.clone(),
variant.attrs.aliases(),
None,
)
})
.collect();

let names = names_idents.iter().flat_map(|&(_, _, aliases)| aliases);
let names = names_idents.iter().flat_map(|&(_, _, aliases, _)| aliases);

let names_const = if fallthrough.is_some() {
None
Expand Down Expand Up @@ -2182,23 +2186,27 @@ fn deserialize_custom_identifier(

fn deserialize_identifier(
this_value: &TokenStream,
fields: &[(&str, Ident, &BTreeSet<String>)],
fields: &[(&str, Ident, &BTreeSet<String>, Option<&syn::Lit>)],
is_variant: bool,
fallthrough: Option<TokenStream>,
fallthrough_borrowed: Option<TokenStream>,
collect_other_fields: bool,
expecting: Option<&str>,
) -> Fragment {
let str_mapping = fields.iter().map(|(_, ident, aliases)| {
let str_mapping = fields.iter().map(|(_, ident, aliases, _)| {
// `aliases` also contains a main name
quote!(#(#aliases)|* => _serde::__private::Ok(#this_value::#ident))
});
let bytes_mapping = fields.iter().map(|(_, ident, aliases)| {
// `aliases` also contains a main name
let bytes_mapping = fields.iter().map(|(_, ident, aliases, key)| {
let aliases = aliases
.iter()
.map(|alias| Literal::byte_string(alias.as_bytes()));
quote!(#(#aliases)|* => _serde::__private::Ok(#this_value::#ident))
// `aliases` also contains a main name
if let Some(syn::Lit::ByteStr(key)) = key {
quote!(#key | #(#aliases)|* => _serde::__private::Ok(#this_value::#ident))
} else {
quote!(#(#aliases)|* => _serde::__private::Ok(#this_value::#ident))
}
});

let expecting = expecting.unwrap_or(if is_variant {
Expand Down Expand Up @@ -2348,8 +2356,87 @@ fn deserialize_identifier(
}
}
} else {
let u64_mapping = fields.iter().enumerate().map(|(i, (_, ident, _))| {
let i = i as u64;
let mut visits = Vec::new();

let char_mapping = fields
.iter()
.enumerate()
.filter_map(|(_, (_, ident, _, key))| {
if let Some(syn::Lit::Char(key)) = key {
Some(quote!(#key => _serde::__private::Ok(#this_value::#ident)))
} else {
None
}
})
.collect::<Vec<_>>();
if !char_mapping.is_empty() {
visits.push(quote! {
fn visit_char<__E>(self, __value: char) -> _serde::__private::Result<Self::Value, __E>
where
__E: _serde::de::Error,
{
match __value {
#(#char_mapping,)*
_ => #fallthrough_arm,
}
}
});
}

let bool_mapping = fields
.iter()
.enumerate()
.filter_map(|(_, (_, ident, _, key))| {
if let Some(syn::Lit::Bool(key)) = key {
return Some(quote!(#key => _serde::__private::Ok(#this_value::#ident)));
}
None
})
.collect::<Vec<_>>();
if !bool_mapping.is_empty() {
visits.push(quote! {
fn visit_bool<__E>(self, __value: bool) -> _serde::__private::Result<Self::Value, __E>
where
__E: _serde::de::Error,
{
match __value {
#(#bool_mapping,)*
_ => #fallthrough_arm,
}
}
});
}

let byte_mapping = fields
.iter()
.enumerate()
.filter_map(|(_, (_, ident, _, key))| {
if let Some(syn::Lit::Byte(key)) = key {
return Some(quote!(#key => _serde::__private::Ok(#this_value::#ident)));
}
None
})
.collect::<Vec<_>>();
if !byte_mapping.is_empty() {
visits.push(quote! {
fn visit_u8<__E>(self, __value: u8) -> _serde::__private::Result<Self::Value, __E>
where
__E: _serde::de::Error,
{
match __value {
#(#byte_mapping,)*
_ => #fallthrough_arm,
}
}
});
}

let u64_mapping = fields.iter().enumerate().map(|(i, (_, ident, _, key))| {
let i = if let Some(syn::Lit::Int(key)) = key {
key.base10_parse().unwrap_or(i as u64)
} else {
i as u64
};
quote!(#i => _serde::__private::Ok(#this_value::#ident))
});

Expand Down Expand Up @@ -2378,6 +2465,8 @@ fn deserialize_identifier(
_ => #u64_fallthrough_arm,
}
}

#( #visits )*
}
};

Expand Down
36 changes: 36 additions & 0 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,7 @@ pub struct Field {
getter: Option<syn::ExprPath>,
flatten: bool,
transparent: bool,
key: Option<syn::Lit>,
}

/// Represents the default to use for a field when deserializing.
Expand Down Expand Up @@ -1089,6 +1090,7 @@ impl Field {
let mut borrowed_lifetimes = Attr::none(cx, BORROW);
let mut getter = Attr::none(cx, GETTER);
let mut flatten = BoolAttr::none(cx, FLATTEN);
let mut key = Attr::none(cx, KEY);

let ident = match &field.ident {
Some(ident) => unraw(ident),
Expand Down Expand Up @@ -1225,6 +1227,35 @@ impl Field {
} else if meta.path == FLATTEN {
// #[serde(flatten)]
flatten.set_true(&meta.path);
} else if meta.path == KEY {
// #[serde(key = ...)]
let expr: syn::Expr = meta.value()?.parse()?;
if let syn::Expr::Lit(syn::ExprLit { lit, .. }) = expr {
match lit {
syn::Lit::ByteStr(_)
| syn::Lit::Bool(_)
| syn::Lit::Char(_)
| syn::Lit::Byte(_)
| syn::Lit::Int(_) => key.set(&meta.path, lit),
syn::Lit::Str(_) => cx.error_spanned_by(
lit,
format!("use the serde {} attribute instead", RENAME),
),
syn::Lit::Float(_) => cx.error_spanned_by(
lit,
format!("serde {} attribute cannot be a float", KEY),
),
_ => cx.error_spanned_by(
lit,
format!("expected serde {} attribute to be a literal", KEY),
),
}
} else {
cx.error_spanned_by(
expr,
format!("expected serde {} attribute to be a literal", KEY),
);
}
} else {
let path = meta.path.to_token_stream().to_string().replace(' ', "");
return Err(
Expand Down Expand Up @@ -1312,6 +1343,7 @@ impl Field {
getter: getter.get(),
flatten: flatten.get(),
transparent: false,
key: key.get(),
}
}

Expand Down Expand Up @@ -1386,6 +1418,10 @@ impl Field {
pub fn mark_transparent(&mut self) {
self.transparent = true;
}

pub fn key(&self) -> Option<&syn::Lit> {
self.key.as_ref()
}
}

type SerAndDe<T> = (Option<T>, Option<T>);
Expand Down
1 change: 1 addition & 0 deletions serde_derive/src/internals/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub const FLATTEN: Symbol = Symbol("flatten");
pub const FROM: Symbol = Symbol("from");
pub const GETTER: Symbol = Symbol("getter");
pub const INTO: Symbol = Symbol("into");
pub const KEY: Symbol = Symbol("key");
pub const NON_EXHAUSTIVE: Symbol = Symbol("non_exhaustive");
pub const OTHER: Symbol = Symbol("other");
pub const REMOTE: Symbol = Symbol("remote");
Expand Down
19 changes: 14 additions & 5 deletions serde_derive/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ fn serialize_struct_tag_field(cattrs: &attr::Container, struct_trait: &StructTra
match cattrs.tag() {
attr::TagType::Internal { tag } => {
let type_name = cattrs.name().serialize_name();
let func = struct_trait.serialize_field(Span::call_site());
let func = struct_trait.serialize_field(Span::call_site(), None);
quote! {
#func(&mut __serde_state, #tag, #type_name)?;
}
Expand Down Expand Up @@ -1122,7 +1122,12 @@ fn serialize_struct_visitor(
get_member(params, field, member)
};

let key_expr = field.attrs.name().serialize_name();
let key_expr = if let Some(key) = field.attrs.key() {
quote!(&#key)
} else {
let name = field.attrs.name().serialize_name();
quote!(#name)
};

let skip = field
.attrs
Expand All @@ -1140,7 +1145,7 @@ fn serialize_struct_visitor(
#func(&#field_expr, _serde::__private::ser::FlatMapSerializer(&mut __serde_state))?;
}
} else {
let func = struct_trait.serialize_field(span);
let func = struct_trait.serialize_field(span, field.attrs.key());
quote! {
#func(&mut __serde_state, #key_expr, #field_expr)?;
}
Expand Down Expand Up @@ -1309,13 +1314,17 @@ enum StructTrait {
}

impl StructTrait {
fn serialize_field(&self, span: Span) -> TokenStream {
fn serialize_field(&self, span: Span, key_hint: Option<&syn::Lit>) -> TokenStream {
match *self {
StructTrait::SerializeMap => {
quote_spanned!(span=> _serde::ser::SerializeMap::serialize_entry)
}
StructTrait::SerializeStruct => {
quote_spanned!(span=> _serde::ser::SerializeStruct::serialize_field)
if key_hint.is_some() {
quote_spanned!(span=> _serde::ser::SerializeStruct::serialize_typed_field)
} else {
quote_spanned!(span=> _serde::ser::SerializeStruct::serialize_field)
}
}
StructTrait::SerializeStructVariant => {
quote_spanned!(span=> _serde::ser::SerializeStructVariant::serialize_field)
Expand Down
12 changes: 12 additions & 0 deletions test_suite/tests/test_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,18 @@ fn test_gen() {
#[derive(Deserialize)]
#[serde(bound(deserialize = "[&'de str; N]: Copy"))]
struct GenericUnitStruct<const N: usize>;

#[derive(Serialize, Deserialize)]
struct KeyedStruct<'se> {
#[serde(key = b'a')]
a: &'se str,
#[serde(key = 8)]
b: u16,
#[serde(key = 8i16)]
c: u32,
#[serde(key = 'c')]
d: u64,
}
}

//////////////////////////////////////////////////////////////////////////
Expand Down