Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JSON support to FromRow derive #2121

Merged
merged 2 commits into from Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 43 additions & 2 deletions sqlx-core/src/from_row.rs
@@ -1,5 +1,4 @@
use crate::error::Error;
use crate::row::Row;
use crate::{error::Error, row::Row};

/// A record that can be built from a row returned by the database.
///
Expand Down Expand Up @@ -210,6 +209,48 @@ use crate::row::Row;
///
/// In MySql, `BigInt` type matches `i64`, but you can convert it to `u64` by `try_from`.
///
/// #### `json`
///
/// If your database supports a JSON type, you can leverage `#[sqlx(json)]`
/// to automatically integrate JSON deserialization in your [`FromRow`] implementation using [`serde`](https://docs.rs/serde/latest/serde/).
///
/// ```rust,ignore
/// #[derive(serde::Deserialize)]
/// struct Data {
/// field1: String,
/// field2: u64
/// }
///
/// #[derive(sqlx::FromRow)]
/// struct User {
/// id: i32,
/// name: String,
/// #[sqlx(json)]
/// metadata: Data
/// }
/// ```
///
/// Given a query like the following:
///
/// ```sql
/// SELECT
/// 1 AS id,
/// 'Name' AS name,
/// JSON_OBJECT('field1', 'value1', 'field2', 42) AS metadata
/// ```
///
/// The `metadata` field will be deserialized used its `serde::Deserialize` implementation:
///
/// ```rust,ignore
/// User {
/// id: 1,
/// name: "Name",
/// metadata: Data {
/// field1: "value1",
/// field2: 42
/// }
/// }
/// ```
pub trait FromRow<'r, R: Row>: Sized {
fn from_row(row: &'r R) -> Result<Self, Error>;
}
Expand Down
19 changes: 15 additions & 4 deletions sqlx-macros-core/src/derives/attributes.rs
@@ -1,9 +1,9 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{Attribute, DeriveInput, Field, Lit, Meta, MetaNameValue, NestedMeta, Type, Variant};
use syn::{
punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, DeriveInput, Field, Lit,
Meta, MetaNameValue, NestedMeta, Type, Variant,
};

macro_rules! assert_attribute {
($e:expr, $err:expr, $input:expr) => {
Expand Down Expand Up @@ -65,6 +65,7 @@ pub struct SqlxChildAttributes {
pub flatten: bool,
pub try_from: Option<Type>,
pub skip: bool,
pub json: bool,
}

pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
Expand Down Expand Up @@ -164,6 +165,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
let mut try_from = None;
let mut flatten = false;
let mut skip: bool = false;
let mut json = false;

for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) {
let meta = attr
Expand All @@ -187,12 +189,20 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
Meta::Path(path) if path.is_ident("default") => default = true,
Meta::Path(path) if path.is_ident("flatten") => flatten = true,
Meta::Path(path) if path.is_ident("skip") => skip = true,
Meta::Path(path) if path.is_ident("json") => json = true,
u => fail!(u, "unexpected attribute"),
},
u => fail!(u, "unexpected attribute"),
}
}
}

if json && flatten {
fail!(
attr,
"Cannot use `json` and `flatten` together on the same field"
);
}
}

Ok(SqlxChildAttributes {
Expand All @@ -201,6 +211,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
flatten,
try_from,
skip,
json,
})
}

Expand Down
70 changes: 46 additions & 24 deletions sqlx-macros-core/src/derives/row.rs
Expand Up @@ -78,45 +78,67 @@ fn expand_derive_from_row_struct(
));
}

let expr: Expr = match (attributes.flatten, attributes.try_from) {
(true, None) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
(false, None) => {
let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();

let expr: Expr = match (attributes.flatten, attributes.try_from, attributes.json) {
// <No attributes>
(false, None, false) => {
predicates
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
}
(true,Some(try_from)) => {
// Flatten
(true, None, false) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
// Flatten + Try from
(true, Some(try_from), false) => {
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
(false,Some(try_from)) => {
// Flatten + Json
(true, _, true) => {
panic!("Cannot use both flatten and json")
}
// Try from
(false, Some(try_from), false) => {
predicates
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
// Try from + Json
(false, Some(try_from), true) => {
predicates
.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::types::Type<R::Database>));

parse_quote!(
row.try_get::<::sqlx::types::Json<_>, _>(#id_s).and_then(|v|
<#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v.0)
.map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))
)
)
},
// Json
(false, None, true) => {
predicates
.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::types::Type<R::Database>));

parse_quote!(row.try_get::<::sqlx::types::Json<_>, _>(#id_s).map(|x| x.0))
},
};

if attributes.default {
Expand Down
62 changes: 62 additions & 0 deletions tests/mysql/macros.rs
Expand Up @@ -468,4 +468,66 @@ async fn test_try_from_attr_with_complex_type() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_from_row_json_attr() -> anyhow::Result<()> {
#[derive(serde::Deserialize)]
struct J {
a: u32,
b: u32,
}

#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(json)]
j: J,
}

let mut conn = new::<MySql>().await?;

let record = sqlx::query_as::<_, Record>("select json_object('a', 1, 'b', 2) as j")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.j.a, 1);
assert_eq!(record.j.b, 2);

Ok(())
}

#[sqlx_macros::test]
async fn test_from_row_json_try_from_attr() -> anyhow::Result<()> {
#[derive(serde::Deserialize)]
struct J {
a: u32,
b: u32,
}

// Non-deserializable
struct J2 {
sum: u32,
}

impl std::convert::From<J> for J2 {
fn from(j: J) -> Self {
Self { sum: j.a + j.b }
}
}

#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(json, try_from = "J")]
j: J2,
}

let mut conn = new::<MySql>().await?;

let record = sqlx::query_as::<_, Record>("select json_object('a', 1, 'b', 2) as j")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.j.sum, 3);

Ok(())
}

// we don't emit bind parameter type-checks for MySQL so testing the overrides is redundant