Skip to content

Commit

Permalink
fix(serde): Newtype variant support
Browse files Browse the repository at this point in the history
Fixes #553
  • Loading branch information
epage committed May 18, 2023
1 parent 87836d9 commit 816a63e
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 14 deletions.
165 changes: 156 additions & 9 deletions crates/toml/src/value.rs
Expand Up @@ -584,7 +584,7 @@ impl<'de> de::Deserializer<'de> for Value {
#[inline]
fn deserialize_enum<V>(
self,
_name: &str,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value, crate::de::Error>
Expand All @@ -593,6 +593,21 @@ impl<'de> de::Deserializer<'de> for Value {
{
match self {
Value::String(variant) => visitor.visit_enum(variant.into_deserializer()),
Value::Table(variant) => {
use de::Error;
if variant.is_empty() {
Err(crate::de::Error::custom(
"wanted exactly 1 element, found 0 elements",
))
} else if variant.len() != 1 {
Err(crate::de::Error::custom(
"wanted exactly 1 element, more than 1 element",
))
} else {
let mut deserializer = MapDeserializer::new(variant);
visitor.visit_enum(deserializer)
}
}
_ => Err(de::Error::invalid_type(
de::Unexpected::UnitVariant,
&"string only",
Expand Down Expand Up @@ -712,6 +727,132 @@ impl<'de> de::MapAccess<'de> for MapDeserializer {
}
}

impl<'de> de::EnumAccess<'de> for MapDeserializer {
type Error = crate::de::Error;
type Variant = MapEnumDeserializer;

fn variant_seed<V>(mut self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
where
V: serde::de::DeserializeSeed<'de>,
{
use de::Error;
let (key, value) = match self.iter.next() {
Some(pair) => pair,
None => {
return Err(Error::custom(
"expected table with exactly 1 entry, found empty table",
));
}
};

let val = seed.deserialize(key.into_deserializer())?;

let variant = MapEnumDeserializer::new(value);

Ok((val, variant))
}
}

/// Deserializes table values into enum variants.
pub(crate) struct MapEnumDeserializer {
value: Value,
}

impl MapEnumDeserializer {
pub(crate) fn new(value: Value) -> Self {
MapEnumDeserializer { value }
}
}

impl<'de> serde::de::VariantAccess<'de> for MapEnumDeserializer {
type Error = crate::de::Error;

fn unit_variant(self) -> Result<(), Self::Error> {
use de::Error;
match self.value {
Value::Table(values) => {
if values.is_empty() {
Ok(())
} else {
Err(Error::custom("expected empty table"))
}
}
e => Err(Error::custom(format!(
"expected table, found {}",
e.type_str()
))),
}
}

fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
where
T: serde::de::DeserializeSeed<'de>,
{
seed.deserialize(self.value.into_deserializer())
}

fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
use de::Error;
match self.value {
Value::Table(values) => {
let tuple_values = values
.into_iter()
.enumerate()
.map(|(index, (key, value))| match key.parse::<usize>() {
Ok(key_index) if key_index == index => Ok(value),
Ok(_) | Err(_) => Err(Error::custom(format!(
"expected table key `{}`, but was `{}`",
index, key
))),
})
// Fold all values into a `Vec`, or return the first error.
.fold(Ok(Vec::with_capacity(len)), |result, value_result| {
result.and_then(move |mut tuple_values| match value_result {
Ok(value) => {
tuple_values.push(value);
Ok(tuple_values)
}
// `Result<de::Value, Self::Error>` to `Result<Vec<_>, Self::Error>`
Err(e) => Err(e),
})
})?;

if tuple_values.len() == len {
serde::de::Deserializer::deserialize_seq(
tuple_values.into_deserializer(),
visitor,
)
} else {
Err(Error::custom(format!("expected tuple with length {}", len)))
}
}
e => Err(Error::custom(format!(
"expected table, found {}",
e.type_str()
))),
}
}

fn struct_variant<V>(
self,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
serde::de::Deserializer::deserialize_struct(
self.value.into_deserializer(),
"", // TODO: this should be the variant name
fields,
visitor,
)
}
}

impl<'de> de::IntoDeserializer<'de, crate::de::Error> for Value {
type Deserializer = Self;

Expand Down Expand Up @@ -827,15 +968,18 @@ impl ser::Serializer for ValueSerializer {

fn serialize_newtype_variant<T: ?Sized>(
self,
name: &'static str,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
variant: &'static str,
value: &T,
) -> Result<Value, crate::ser::Error>
where
T: ser::Serialize,
{
Err(crate::ser::Error::unsupported_type(Some(name)))
let value = value.serialize(ValueSerializer)?;
let mut table = Table::new();
table.insert(variant.to_owned(), value);
Ok(table.into())
}

fn serialize_none(self) -> Result<Value, crate::ser::Error> {
Expand Down Expand Up @@ -1005,15 +1149,18 @@ impl ser::Serializer for TableSerializer {

fn serialize_newtype_variant<T: ?Sized>(
self,
name: &'static str,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
variant: &'static str,
value: &T,
) -> Result<Table, crate::ser::Error>
where
T: ser::Serialize,
{
Err(crate::ser::Error::unsupported_type(Some(name)))
let value = value.serialize(ValueSerializer)?;
let mut table = Table::new();
table.insert(variant.to_owned(), value);
Ok(table)
}

fn serialize_none(self) -> Result<Table, crate::ser::Error> {
Expand Down
1 change: 0 additions & 1 deletion crates/toml/tests/testsuite/serde.rs
Expand Up @@ -614,7 +614,6 @@ fn newtypes2() {
}

#[test]
#[should_panic]
fn newtype_variant() {
#[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
struct Struct {
Expand Down
11 changes: 7 additions & 4 deletions crates/toml_edit/src/ser/value.rs
Expand Up @@ -167,15 +167,18 @@ impl serde::ser::Serializer for ValueSerializer {

fn serialize_newtype_variant<T: ?Sized>(
self,
name: &'static str,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
variant: &'static str,
value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: serde::ser::Serialize,
{
Err(Error::UnsupportedType(Some(name)))
let value = value.serialize(self)?;
let mut table = crate::InlineTable::new();
table.insert(variant, value);
Ok(table.into())
}

fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
Expand Down

0 comments on commit 816a63e

Please sign in to comment.