Skip to content

Commit

Permalink
Add json attribute to FromRow derive
Browse files Browse the repository at this point in the history
  • Loading branch information
95ulisse committed Jul 17, 2023
1 parent c70cfaf commit 16a1026
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 28 deletions.
19 changes: 15 additions & 4 deletions sqlx-macros-core/src/derives/attributes.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{Attribute, DeriveInput, Field, Lit, Meta, MetaNameValue, NestedMeta, Type, Variant};
use syn::{
punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, DeriveInput, Field, Lit,
Meta, MetaNameValue, NestedMeta, Type, Variant,
};

macro_rules! assert_attribute {
($e:expr, $err:expr, $input:expr) => {
Expand Down Expand Up @@ -65,6 +65,7 @@ pub struct SqlxChildAttributes {
pub flatten: bool,
pub try_from: Option<Type>,
pub skip: bool,
pub json: bool,
}

pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
Expand Down Expand Up @@ -164,6 +165,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
let mut try_from = None;
let mut flatten = false;
let mut skip: bool = false;
let mut json = false;

for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) {
let meta = attr
Expand All @@ -187,12 +189,20 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
Meta::Path(path) if path.is_ident("default") => default = true,
Meta::Path(path) if path.is_ident("flatten") => flatten = true,
Meta::Path(path) if path.is_ident("skip") => skip = true,
Meta::Path(path) if path.is_ident("json") => json = true,
u => fail!(u, "unexpected attribute"),
},
u => fail!(u, "unexpected attribute"),
}
}
}

if json && flatten {
fail!(
attr,
"Cannot use `json` and `flatten` together on the same field"
);
}
}

Ok(SqlxChildAttributes {
Expand All @@ -201,6 +211,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
flatten,
try_from,
skip,
json,
})
}

Expand Down
70 changes: 46 additions & 24 deletions sqlx-macros-core/src/derives/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,45 +78,67 @@ fn expand_derive_from_row_struct(
));
}

let expr: Expr = match (attributes.flatten, attributes.try_from) {
(true, None) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
(false, None) => {
let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();

let expr: Expr = match (attributes.flatten, attributes.try_from, attributes.json) {
// <No attributes>
(false, None, false) => {
predicates
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
}
(true,Some(try_from)) => {
// Flatten
(true, None, false) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
// Flatten + Try from
(true, Some(try_from), false) => {
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
(false,Some(try_from)) => {
// Flatten + Json
(true, _, true) => {
panic!("Cannot use both flatten and json")
}
// Try from
(false, Some(try_from), false) => {
predicates
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
// Try from + Json
(false, Some(try_from), true) => {
predicates
.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::types::Type<R::Database>));

parse_quote!(
row.try_get::<::sqlx::types::Json<_>, _>(#id_s).and_then(|v|
<#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v.0)
.map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))
)
)
},
// Json
(false, None, true) => {
predicates
.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::types::Type<R::Database>));

parse_quote!(row.try_get::<::sqlx::types::Json<_>, _>(#id_s).map(|x| x.0))
},
};

if attributes.default {
Expand Down
62 changes: 62 additions & 0 deletions tests/mysql/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,4 +468,66 @@ async fn test_try_from_attr_with_complex_type() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_from_row_json_attr() -> anyhow::Result<()> {
#[derive(serde::Deserialize)]
struct J {
a: u32,
b: u32,
}

#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(json)]
j: J,
}

let mut conn = new::<MySql>().await?;

let record = sqlx::query_as::<_, Record>("select json_object('a', 1, 'b', 2) as j")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.j.a, 1);
assert_eq!(record.j.b, 2);

Ok(())
}

#[sqlx_macros::test]
async fn test_from_row_json_try_from_attr() -> anyhow::Result<()> {
#[derive(serde::Deserialize)]
struct J {
a: u32,
b: u32,
}

// Non-deserializable
struct J2 {
sum: u32,
}

impl std::convert::From<J> for J2 {
fn from(j: J) -> Self {
Self { sum: j.a + j.b }
}
}

#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(json, try_from = "J")]
j: J2,
}

let mut conn = new::<MySql>().await?;

let record = sqlx::query_as::<_, Record>("select json_object('a', 1, 'b', 2) as j")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.j.sum, 3);

Ok(())
}

// we don't emit bind parameter type-checks for MySQL so testing the overrides is redundant

0 comments on commit 16a1026

Please sign in to comment.