diff --git a/crates/toml/src/value.rs b/crates/toml/src/value.rs index 8ed75e81..2785d9de 100644 --- a/crates/toml/src/value.rs +++ b/crates/toml/src/value.rs @@ -584,7 +584,7 @@ impl<'de> de::Deserializer<'de> for Value { #[inline] fn deserialize_enum( self, - _name: &str, + _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result @@ -593,6 +593,21 @@ impl<'de> de::Deserializer<'de> for Value { { match self { Value::String(variant) => visitor.visit_enum(variant.into_deserializer()), + Value::Table(variant) => { + use de::Error; + if variant.is_empty() { + Err(crate::de::Error::custom( + "wanted exactly 1 element, found 0 elements", + )) + } else if variant.len() != 1 { + Err(crate::de::Error::custom( + "wanted exactly 1 element, more than 1 element", + )) + } else { + let deserializer = MapDeserializer::new(variant); + visitor.visit_enum(deserializer) + } + } _ => Err(de::Error::invalid_type( de::Unexpected::UnitVariant, &"string only", @@ -712,6 +727,132 @@ impl<'de> de::MapAccess<'de> for MapDeserializer { } } +impl<'de> de::EnumAccess<'de> for MapDeserializer { + type Error = crate::de::Error; + type Variant = MapEnumDeserializer; + + fn variant_seed(mut self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: serde::de::DeserializeSeed<'de>, + { + use de::Error; + let (key, value) = match self.iter.next() { + Some(pair) => pair, + None => { + return Err(Error::custom( + "expected table with exactly 1 entry, found empty table", + )); + } + }; + + let val = seed.deserialize(key.into_deserializer())?; + + let variant = MapEnumDeserializer::new(value); + + Ok((val, variant)) + } +} + +/// Deserializes table values into enum variants. +pub(crate) struct MapEnumDeserializer { + value: Value, +} + +impl MapEnumDeserializer { + pub(crate) fn new(value: Value) -> Self { + MapEnumDeserializer { value } + } +} + +impl<'de> serde::de::VariantAccess<'de> for MapEnumDeserializer { + type Error = crate::de::Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + use de::Error; + match self.value { + Value::Table(values) => { + if values.is_empty() { + Ok(()) + } else { + Err(Error::custom("expected empty table")) + } + } + e => Err(Error::custom(format!( + "expected table, found {}", + e.type_str() + ))), + } + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(self.value.into_deserializer()) + } + + fn tuple_variant(self, len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + use de::Error; + match self.value { + Value::Table(values) => { + let tuple_values = values + .into_iter() + .enumerate() + .map(|(index, (key, value))| match key.parse::() { + Ok(key_index) if key_index == index => Ok(value), + Ok(_) | Err(_) => Err(Error::custom(format!( + "expected table key `{}`, but was `{}`", + index, key + ))), + }) + // Fold all values into a `Vec`, or return the first error. + .fold(Ok(Vec::with_capacity(len)), |result, value_result| { + result.and_then(move |mut tuple_values| match value_result { + Ok(value) => { + tuple_values.push(value); + Ok(tuple_values) + } + // `Result` to `Result, Self::Error>` + Err(e) => Err(e), + }) + })?; + + if tuple_values.len() == len { + serde::de::Deserializer::deserialize_seq( + tuple_values.into_deserializer(), + visitor, + ) + } else { + Err(Error::custom(format!("expected tuple with length {}", len))) + } + } + e => Err(Error::custom(format!( + "expected table, found {}", + e.type_str() + ))), + } + } + + fn struct_variant( + self, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + serde::de::Deserializer::deserialize_struct( + self.value.into_deserializer(), + "", // TODO: this should be the variant name + fields, + visitor, + ) + } +} + impl<'de> de::IntoDeserializer<'de, crate::de::Error> for Value { type Deserializer = Self; @@ -827,15 +968,18 @@ impl ser::Serializer for ValueSerializer { fn serialize_newtype_variant( self, - name: &'static str, + _name: &'static str, _variant_index: u32, - _variant: &'static str, - _value: &T, + variant: &'static str, + value: &T, ) -> Result where T: ser::Serialize, { - Err(crate::ser::Error::unsupported_type(Some(name))) + let value = value.serialize(ValueSerializer)?; + let mut table = Table::new(); + table.insert(variant.to_owned(), value); + Ok(table.into()) } fn serialize_none(self) -> Result { @@ -1005,15 +1149,18 @@ impl ser::Serializer for TableSerializer { fn serialize_newtype_variant( self, - name: &'static str, + _name: &'static str, _variant_index: u32, - _variant: &'static str, - _value: &T, + variant: &'static str, + value: &T, ) -> Result where T: ser::Serialize, { - Err(crate::ser::Error::unsupported_type(Some(name))) + let value = value.serialize(ValueSerializer)?; + let mut table = Table::new(); + table.insert(variant.to_owned(), value); + Ok(table) } fn serialize_none(self) -> Result { diff --git a/crates/toml/tests/testsuite/serde.rs b/crates/toml/tests/testsuite/serde.rs index 3b9f65a4..6e927f73 100644 --- a/crates/toml/tests/testsuite/serde.rs +++ b/crates/toml/tests/testsuite/serde.rs @@ -613,6 +613,28 @@ fn newtypes2() { } } +#[test] +fn newtype_variant() { + #[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] + struct Struct { + field: Enum, + } + + #[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] + enum Enum { + Variant(u8), + } + + equivalent! { + Struct { field: Enum::Variant(21) }, + map! { + field: map! { + Variant: Value::Integer(21) + } + }, + } +} + #[derive(Debug, Default, PartialEq, Serialize, Deserialize)] struct CanBeEmpty { a: Option, diff --git a/crates/toml_edit/src/ser/value.rs b/crates/toml_edit/src/ser/value.rs index cc7dfb77..d29390a4 100644 --- a/crates/toml_edit/src/ser/value.rs +++ b/crates/toml_edit/src/ser/value.rs @@ -167,15 +167,18 @@ impl serde::ser::Serializer for ValueSerializer { fn serialize_newtype_variant( self, - name: &'static str, + _name: &'static str, _variant_index: u32, - _variant: &'static str, - _value: &T, + variant: &'static str, + value: &T, ) -> Result where T: serde::ser::Serialize, { - Err(Error::UnsupportedType(Some(name))) + let value = value.serialize(self)?; + let mut table = crate::InlineTable::new(); + table.insert(variant, value); + Ok(table.into()) } fn serialize_seq(self, len: Option) -> Result {