Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow f32 and f64 map keys #1027

Merged
merged 11 commits into from Jul 11, 2023
43 changes: 23 additions & 20 deletions src/de.rs
Expand Up @@ -2118,20 +2118,21 @@ struct MapKey<'a, R: 'a> {
de: &'a mut Deserializer<R>,
}

macro_rules! deserialize_integer_key {
($method:ident => $visit:ident) => {
macro_rules! deserialize_numeric_key {
($method:ident) => {
fn $method<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.de.eat_char();
self.de.scratch.clear();
let string = tri!(self.de.read.parse_str(&mut self.de.scratch));
match (string.parse(), string) {
(Ok(integer), _) => visitor.$visit(integer),
(Err(_), Reference::Borrowed(s)) => visitor.visit_borrowed_str(s),
(Err(_), Reference::Copied(s)) => visitor.visit_str(s),
let value = self.de.$method(visitor)?;

match self.de.peek()? {
Some(b'"') => self.de.eat_char(),
_ => return Err(self.de.peek_error(ErrorCode::ExpectedDoubleQuote)),
}

Ok(value)
}
};
}
Expand All @@ -2155,16 +2156,18 @@ where
}
}

deserialize_integer_key!(deserialize_i8 => visit_i8);
deserialize_integer_key!(deserialize_i16 => visit_i16);
deserialize_integer_key!(deserialize_i32 => visit_i32);
deserialize_integer_key!(deserialize_i64 => visit_i64);
deserialize_integer_key!(deserialize_i128 => visit_i128);
deserialize_integer_key!(deserialize_u8 => visit_u8);
deserialize_integer_key!(deserialize_u16 => visit_u16);
deserialize_integer_key!(deserialize_u32 => visit_u32);
deserialize_integer_key!(deserialize_u64 => visit_u64);
deserialize_integer_key!(deserialize_u128 => visit_u128);
deserialize_numeric_key!(deserialize_i8);
deserialize_numeric_key!(deserialize_i16);
deserialize_numeric_key!(deserialize_i32);
deserialize_numeric_key!(deserialize_i64);
deserialize_numeric_key!(deserialize_i128);
deserialize_numeric_key!(deserialize_u8);
deserialize_numeric_key!(deserialize_u16);
deserialize_numeric_key!(deserialize_u32);
deserialize_numeric_key!(deserialize_u64);
deserialize_numeric_key!(deserialize_u128);
deserialize_numeric_key!(deserialize_f32);
deserialize_numeric_key!(deserialize_f64);

#[inline]
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
Expand Down Expand Up @@ -2221,8 +2224,8 @@ where
}

forward_to_deserialize_any! {
bool f32 f64 char str string unit unit_struct seq tuple tuple_struct map
struct identifier ignored_any
bool char str string unit unit_struct seq tuple tuple_struct map struct
identifier ignored_any
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/error.rs
Expand Up @@ -64,12 +64,14 @@ impl Error {
| ErrorCode::ExpectedObjectCommaOrEnd
| ErrorCode::ExpectedSomeIdent
| ErrorCode::ExpectedSomeValue
| ErrorCode::ExpectedDoubleQuote
| ErrorCode::InvalidEscape
| ErrorCode::InvalidNumber
| ErrorCode::NumberOutOfRange
| ErrorCode::InvalidUnicodeCodePoint
| ErrorCode::ControlCharacterWhileParsingString
| ErrorCode::KeyMustBeAString
| ErrorCode::FloatKeyMustBeFinite
| ErrorCode::LoneLeadingSurrogateInHexEscape
| ErrorCode::TrailingComma
| ErrorCode::TrailingCharacters
Expand Down Expand Up @@ -264,6 +266,9 @@ pub(crate) enum ErrorCode {
/// Expected this character to start a JSON value.
ExpectedSomeValue,

/// Expected this character to be a `"`.
ExpectedDoubleQuote,

/// Invalid hex escape code.
InvalidEscape,

Expand All @@ -282,6 +287,9 @@ pub(crate) enum ErrorCode {
/// Object key is not a string.
KeyMustBeAString,

/// Object key is a non-finite float value.
FloatKeyMustBeFinite,

/// Lone leading surrogate in hex escape.
LoneLeadingSurrogateInHexEscape,

Expand Down Expand Up @@ -348,6 +356,7 @@ impl Display for ErrorCode {
ErrorCode::ExpectedObjectCommaOrEnd => f.write_str("expected `,` or `}`"),
ErrorCode::ExpectedSomeIdent => f.write_str("expected ident"),
ErrorCode::ExpectedSomeValue => f.write_str("expected value"),
ErrorCode::ExpectedDoubleQuote => f.write_str("expected `\"`"),
ErrorCode::InvalidEscape => f.write_str("invalid escape"),
ErrorCode::InvalidNumber => f.write_str("invalid number"),
ErrorCode::NumberOutOfRange => f.write_str("number out of range"),
Expand All @@ -356,6 +365,9 @@ impl Display for ErrorCode {
f.write_str("control character (\\u0000-\\u001F) found while parsing a string")
}
ErrorCode::KeyMustBeAString => f.write_str("key must be a string"),
ErrorCode::FloatKeyMustBeFinite => {
f.write_str("float key must be finite (got NaN or +/-inf)")
}
ErrorCode::LoneLeadingSurrogateInHexEscape => {
f.write_str("lone leading surrogate in hex escape")
}
Expand Down
46 changes: 42 additions & 4 deletions src/ser.rs
Expand Up @@ -789,6 +789,10 @@ fn key_must_be_a_string() -> Error {
Error::syntax(ErrorCode::KeyMustBeAString, 0, 0)
}

fn float_key_must_be_finite() -> Error {
Error::syntax(ErrorCode::FloatKeyMustBeFinite, 0, 0)
}

impl<'a, W, F> ser::Serializer for MapKeySerializer<'a, W, F>
where
W: io::Write,
Expand Down Expand Up @@ -1002,12 +1006,46 @@ where
.map_err(Error::io)
}

fn serialize_f32(self, _value: f32) -> Result<()> {
Err(key_must_be_a_string())
fn serialize_f32(self, value: f32) -> Result<()> {
if !value.is_finite() {
return Err(float_key_must_be_finite());
}

tri!(self
.ser
.formatter
.begin_string(&mut self.ser.writer)
.map_err(Error::io));
tri!(self
.ser
.formatter
.write_f32(&mut self.ser.writer, value)
.map_err(Error::io));
dtolnay marked this conversation as resolved.
Show resolved Hide resolved
self.ser
.formatter
.end_string(&mut self.ser.writer)
.map_err(Error::io)
}

fn serialize_f64(self, _value: f64) -> Result<()> {
Err(key_must_be_a_string())
fn serialize_f64(self, value: f64) -> Result<()> {
if !value.is_finite() {
return Err(float_key_must_be_finite());
}

tri!(self
.ser
.formatter
.begin_string(&mut self.ser.writer)
.map_err(Error::io));
tri!(self
.ser
.formatter
.write_f64(&mut self.ser.writer, value)
.map_err(Error::io));
self.ser
.formatter
.end_string(&mut self.ser.writer)
.map_err(Error::io)
}

fn serialize_char(self, value: char) -> Result<()> {
Expand Down
29 changes: 16 additions & 13 deletions src/value/de.rs
Expand Up @@ -1120,13 +1120,14 @@ struct MapKeyDeserializer<'de> {
key: Cow<'de, str>,
}

macro_rules! deserialize_integer_key {
macro_rules! deserialize_numeric_key {
($method:ident => $visit:ident) => {
fn $method<V>(self, visitor: V) -> Result<V::Value, Error>
where
V: Visitor<'de>,
{
match (self.key.parse(), self.key) {
let parsed = crate::from_str(&self.key);
match (parsed, self.key) {
(Ok(integer), _) => visitor.$visit(integer),
(Err(_), Cow::Borrowed(s)) => visitor.visit_borrowed_str(s),
#[cfg(any(feature = "std", feature = "alloc"))]
Expand All @@ -1146,16 +1147,18 @@ impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> {
BorrowedCowStrDeserializer::new(self.key).deserialize_any(visitor)
}

deserialize_integer_key!(deserialize_i8 => visit_i8);
deserialize_integer_key!(deserialize_i16 => visit_i16);
deserialize_integer_key!(deserialize_i32 => visit_i32);
deserialize_integer_key!(deserialize_i64 => visit_i64);
deserialize_integer_key!(deserialize_i128 => visit_i128);
deserialize_integer_key!(deserialize_u8 => visit_u8);
deserialize_integer_key!(deserialize_u16 => visit_u16);
deserialize_integer_key!(deserialize_u32 => visit_u32);
deserialize_integer_key!(deserialize_u64 => visit_u64);
deserialize_integer_key!(deserialize_u128 => visit_u128);
deserialize_numeric_key!(deserialize_i8 => visit_i8);
deserialize_numeric_key!(deserialize_i16 => visit_i16);
deserialize_numeric_key!(deserialize_i32 => visit_i32);
deserialize_numeric_key!(deserialize_i64 => visit_i64);
deserialize_numeric_key!(deserialize_i128 => visit_i128);
deserialize_numeric_key!(deserialize_u8 => visit_u8);
deserialize_numeric_key!(deserialize_u16 => visit_u16);
deserialize_numeric_key!(deserialize_u32 => visit_u32);
deserialize_numeric_key!(deserialize_u64 => visit_u64);
deserialize_numeric_key!(deserialize_u128 => visit_u128);
deserialize_numeric_key!(deserialize_f32 => visit_f32);
deserialize_numeric_key!(deserialize_f64 => visit_f64);
dtolnay marked this conversation as resolved.
Show resolved Hide resolved

#[inline]
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Error>
Expand Down Expand Up @@ -1193,7 +1196,7 @@ impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> {
}

forward_to_deserialize_any! {
bool f32 f64 char str string bytes byte_buf unit unit_struct seq tuple
bool char str string bytes byte_buf unit unit_struct seq tuple
tuple_struct map struct identifier ignored_any
}
}
Expand Down
20 changes: 16 additions & 4 deletions src/value/ser.rs
Expand Up @@ -451,6 +451,10 @@ fn key_must_be_a_string() -> Error {
Error::syntax(ErrorCode::KeyMustBeAString, 0, 0)
}

fn float_key_must_be_finite() -> Error {
Error::syntax(ErrorCode::FloatKeyMustBeFinite, 0, 0)
}

impl serde::Serializer for MapKeySerializer {
type Ok = String;
type Error = Error;
Expand Down Expand Up @@ -517,12 +521,20 @@ impl serde::Serializer for MapKeySerializer {
Ok(value.to_string())
}

fn serialize_f32(self, _value: f32) -> Result<String> {
Err(key_must_be_a_string())
fn serialize_f32(self, value: f32) -> Result<String> {
if value.is_finite() {
Ok(ryu::Buffer::new().format_finite(value).to_owned())
} else {
Err(float_key_must_be_finite())
}
}

fn serialize_f64(self, _value: f64) -> Result<String> {
Err(key_must_be_a_string())
fn serialize_f64(self, value: f64) -> Result<String> {
if value.is_finite() {
Ok(ryu::Buffer::new().format_finite(value).to_owned())
} else {
Err(float_key_must_be_finite())
}
}

#[inline]
Expand Down
85 changes: 77 additions & 8 deletions tests/test.rs
Expand Up @@ -1897,10 +1897,7 @@ fn test_integer_key() {
test_parse_ok(vec![(j, map)]);

let j = r#"{"x":null}"#;
test_parse_err::<BTreeMap<i32, ()>>(&[(
j,
"invalid type: string \"x\", expected i32 at line 1 column 4",
)]);
test_parse_err::<BTreeMap<i32, ()>>(&[(j, "expected value at line 1 column 3")]);
}

#[test]
Expand All @@ -1914,20 +1911,92 @@ fn test_integer128_key() {
}

#[test]
fn test_deny_float_key() {
#[derive(Eq, PartialEq, Ord, PartialOrd)]
fn test_float_key() {
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone)]
struct Float;
impl Serialize for Float {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_f32(1.0)
serializer.serialize_f32(1.23)
}
}
impl<'de> Deserialize<'de> for Float {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: de::Deserializer<'de>,
{
f32::deserialize(deserializer).map(|_| Float)
}
}

// map with float key
let map = treemap!(Float => "x");
let map = treemap!(Float => "x".to_owned());
let j = r#"{"1.23":"x"}"#;

test_encode_ok(&[(&map, j)]);
test_parse_ok(vec![(j, map)]);

let j = r#"{"x": null}"#;
test_parse_err::<BTreeMap<Float, ()>>(&[(j, "expected value at line 1 column 3")]);
}

#[test]
fn test_deny_non_finite_f32_key() {
// We store float bits so that we can derive `Ord`, and other traits. In a real context, we
// would use a crate like `ordered-float` instead.

#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone)]
struct F32Bits(u32);
impl Serialize for F32Bits {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_f32(f32::from_bits(self.0))
}
}

let map = treemap!(F32Bits(f32::INFINITY.to_bits()) => "x".to_owned());
assert!(serde_json::to_string(&map).is_err());
assert!(serde_json::to_value(map).is_err());

let map = treemap!(F32Bits(f32::NEG_INFINITY.to_bits()) => "x".to_owned());
assert!(serde_json::to_string(&map).is_err());
assert!(serde_json::to_value(map).is_err());

let map = treemap!(F32Bits(f32::NAN.to_bits()) => "x".to_owned());
assert!(serde_json::to_string(&map).is_err());
assert!(serde_json::to_value(map).is_err());
}

#[test]
fn test_deny_non_finite_f64_key() {
// We store float bits so that we can derive `Ord`, and other traits. In a real context, we
// would use a crate like `ordered-float` instead.

#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone)]
struct F64Bits(u64);
impl Serialize for F64Bits {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_f64(f64::from_bits(self.0))
}
}

let map = treemap!(F64Bits(f64::INFINITY.to_bits()) => "x".to_owned());
assert!(serde_json::to_string(&map).is_err());
assert!(serde_json::to_value(map).is_err());

let map = treemap!(F64Bits(f64::NEG_INFINITY.to_bits()) => "x".to_owned());
assert!(serde_json::to_string(&map).is_err());
assert!(serde_json::to_value(map).is_err());

let map = treemap!(F64Bits(f64::NAN.to_bits()) => "x".to_owned());
assert!(serde_json::to_string(&map).is_err());
assert!(serde_json::to_value(map).is_err());
}

Expand Down