From 91c60bd184639586305fdb55260234f79174d71d Mon Sep 17 00:00:00 2001 From: Michele d'Amico Date: Sun, 19 Mar 2023 10:53:25 +0100 Subject: [PATCH] Auto await feature (#186) * WIP: (Refactoring) Implemented the future boilerplate in the rendering stage. Just render and test (missing the parse stage). * WIP: removed `ReplaceFutureAttribute` hack in `fixture` and moved the parsing code in `extend_with_function_attrs` trait impl * Move test in the right module and simplfied it * Add check also for no future args * Refactored and removed useless tests that're already tested in future module * Use Arguments also in tests and removed the old impl. Also recoverd tests for errors but implemented them in future module * WIP: Implementing await policy * WIP: implemented the render part. * WIP: implemented tests for is_future_await logic * WIP: Implemented future options parsing. Missed the global one * WIP implemented all global parser, pass arguments info down to the renderers and impleneted all tests in render code to check that the parsed info will used correctly. * Changed impl: we should await in the original method because we want the future in the signature. Now E2E tests for fixture are enabled and work. * Integration tests * Add documentation and changelog info --- CHANGELOG.md | 2 + README.md | 23 ++ playground/src/main.rs | 38 ++- rstest/Cargo.toml | 2 +- rstest/tests/fixture/mod.rs | 9 +- .../fixture/await_complete_fixture.rs | 105 ++++++ .../fixture/await_partial_fixture.rs | 94 ++++++ rstest/tests/resources/fixture/no_warning.rs | 13 +- .../tests/resources/rstest/cases/async_awt.rs | 28 ++ .../rstest/cases/async_awt_global.rs | 30 ++ .../resources/rstest/matrix/async_awt.rs | 12 + .../rstest/matrix/async_awt_global.rs | 13 + .../resources/rstest/single/async_awt.rs | 28 ++ .../rstest/single/async_awt_global.rs | 32 ++ rstest/tests/resources/rstest/timeout.rs | 30 +- rstest/tests/rstest/mod.rs | 28 +- rstest_macros/Cargo.toml | 2 +- rstest_macros/src/lib.rs | 37 ++- rstest_macros/src/parse/fixture.rs | 32 +- rstest_macros/src/parse/future.rs | 299 ++++++++++------- rstest_macros/src/parse/mod.rs | 135 +++++++- rstest_macros/src/parse/rstest.rs | 32 +- rstest_macros/src/render/apply_argumets.rs | 249 ++++++++++++++ rstest_macros/src/render/fixture.rs | 93 +++++- rstest_macros/src/render/mod.rs | 51 +-- rstest_macros/src/render/test.rs | 311 +++++++++++++++++- rstest_macros/src/test.rs | 14 +- 27 files changed, 1524 insertions(+), 218 deletions(-) create mode 100644 rstest/tests/resources/fixture/await_complete_fixture.rs create mode 100644 rstest/tests/resources/fixture/await_partial_fixture.rs create mode 100644 rstest/tests/resources/rstest/cases/async_awt.rs create mode 100644 rstest/tests/resources/rstest/cases/async_awt_global.rs create mode 100644 rstest/tests/resources/rstest/matrix/async_awt.rs create mode 100644 rstest/tests/resources/rstest/matrix/async_awt_global.rs create mode 100644 rstest/tests/resources/rstest/single/async_awt.rs create mode 100644 rstest/tests/resources/rstest/single/async_awt_global.rs create mode 100644 rstest_macros/src/render/apply_argumets.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 628d3b1..d92febd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## [0.17.0] Unreleased ### Add +- Add `#[awt]` and `#[timeout(awt)]` to `.await` future input + ### Changed ### Fixed diff --git a/README.md b/README.md index ce3ed22..9adcba0 100644 --- a/README.md +++ b/README.md @@ -180,6 +180,29 @@ async fn my_async_test(#[future] base: u32, #[case] expected: u32, #[future] #[c } ``` +As you noted you should `.await` all _future_ values and this some times can be really boring. +In this case you can use `#[timeout(awt)]` to _awaiting_ an input or annotating your function +with `#[awt]` attributes to globally `.await` all your _future_ inputs. Previous code can be +simplified like follow: +```rust +use rstest::*; +# #[fixture] +# async fn base() -> u32 { 42 } +#[rstest] +#[case(21, async { 2 })] +#[case(6, async { 7 })] +#[awt] +async fn global(#[future] base: u32, #[case] expected: u32, #[future] #[case] div: u32) { + assert_eq!(expected, base / div); +} +#[rstest] +#[case(21, async { 2 })] +#[case(6, async { 7 })] +async fn single(#[future] base: u32, #[case] expected: u32, #[future(awt)] #[case] div: u32) { + assert_eq!(expected, base.await / div); +} +``` + ### Test `#[timeout()]` You can define an execution timeout for your tests with `#[timeout()]` attribute. Timeouts diff --git a/playground/src/main.rs b/playground/src/main.rs index 4c7c1b7..3a6b32f 100644 --- a/playground/src/main.rs +++ b/playground/src/main.rs @@ -1,12 +1,38 @@ use rstest::*; -fn valid_user(name: &str, age: u8) -> bool { - true + +#[fixture] +#[awt] +async fn two_args_mix_fixture( + #[future] + #[default(async { 4 })] + four: u32, + #[default(2)] two: u32, +) -> u32 { + four * 10 + two } +// #[rstest] +// #[awt] +// async fn use_two_args_mix_fixture(#[future] two_args_mix_fixture: u32) { +// assert_eq!(42, two_args_mix_fixture); +// } + +// #[rstest] +// #[awt] +// async fn use_two_args_mix_fixture_inject_first( +// #[future] +// #[with(async { 5 })] +// two_args_mix_fixture: u32, +// ) { +// assert_eq!(52, two_args_mix_fixture); +// } + #[rstest] -fn should_accept_all_corner_cases( - #[values("J", "A", "A________________________________________21")] name: &str, - #[values(14, 100)] age: u8, +#[awt] +async fn use_two_args_mix_fixture_inject_both( + #[future] + #[with(async { 3 }, 1)] + two_args_mix_fixture: u32, ) { - assert!(valid_user(name, age)) + assert_eq!(31, two_args_mix_fixture); } diff --git a/rstest/Cargo.toml b/rstest/Cargo.toml index d8dffe0..df29736 100644 --- a/rstest/Cargo.toml +++ b/rstest/Cargo.toml @@ -23,7 +23,7 @@ default = ["async-timeout"] [dependencies] futures = {version = "0.3.21", optional = true} futures-timer = {version = "3.0.2", optional = true} -rstest_macros = {version = "0.16.0", path = "../rstest_macros", default-features = false} +rstest_macros = {version = "0.17.0", path = "../rstest_macros", default-features = false} [dev-dependencies] actix-rt = "2.7.0" diff --git a/rstest/tests/fixture/mod.rs b/rstest/tests/fixture/mod.rs index fcc2d21..6c769a6 100644 --- a/rstest/tests/fixture/mod.rs +++ b/rstest/tests/fixture/mod.rs @@ -93,9 +93,12 @@ mod should { } } - #[test] - fn resolve_async_fixture() { - let prj = prj("async_fixture.rs"); + #[rstest] + #[case::base("async_fixture.rs")] + #[case::use_global("await_complete_fixture.rs")] + #[case::use_selective("await_partial_fixture.rs")] + fn resolve_async_fixture(#[case] code: &str) { + let prj = prj(code); prj.add_dependency("async-std", r#"{version="*", features=["attributes"]}"#); let output = prj.run_tests().unwrap(); diff --git a/rstest/tests/resources/fixture/await_complete_fixture.rs b/rstest/tests/resources/fixture/await_complete_fixture.rs new file mode 100644 index 0000000..8962d8e --- /dev/null +++ b/rstest/tests/resources/fixture/await_complete_fixture.rs @@ -0,0 +1,105 @@ +use std::io::prelude::*; + +use rstest::*; + +#[fixture] +async fn async_u32() -> u32 { + 42 +} + +#[fixture] +#[awt] +async fn nest_fixture(#[future] async_u32: u32) -> u32 { + async_u32 +} + +#[fixture] +#[awt] +async fn nest_fixture_with_default( + #[future] + #[default(async { 42 })] + fortytwo: u32, +) -> u32 { + fortytwo +} + +#[rstest] +async fn default_is_async() { + assert_eq!(42, async_u32::default().await); +} + +#[rstest] +#[awt] +async fn use_async_nest_fixture_default(#[future] nest_fixture: u32) { + assert_eq!(42, nest_fixture); +} + +#[rstest] +#[awt] +async fn use_async_nest_fixture_injected( + #[future] + #[with(async { 24 })] + nest_fixture: u32, +) { + assert_eq!(24, nest_fixture); +} + +#[rstest] +#[awt] +async fn use_async_nest_fixture_with_default(#[future] nest_fixture_with_default: u32) { + assert_eq!(42, nest_fixture_with_default); +} + +#[rstest] +#[awt] +async fn use_async_fixture(#[future] async_u32: u32) { + assert_eq!(42, async_u32); +} + +#[fixture] +async fn async_impl_output() -> impl Read { + std::io::Cursor::new(vec![1, 2, 3, 4, 5]) +} + +#[rstest] +#[awt] +async fn use_async_impl_output(#[future] async_impl_output: T) { + let reader = async_impl_output; +} + +#[fixture] +#[awt] +async fn two_args_mix_fixture( + #[future] + #[default(async { 4 })] + four: u32, + #[default(2)] two: u32, +) -> u32 { + four * 10 + two +} + +#[rstest] +#[awt] +async fn use_two_args_mix_fixture(#[future] two_args_mix_fixture: u32) { + assert_eq!(42, two_args_mix_fixture); +} + +#[rstest] +#[awt] +async fn use_two_args_mix_fixture_inject_first( + #[future] + #[with(async { 5 })] + two_args_mix_fixture: u32, +) { + assert_eq!(52, two_args_mix_fixture); +} + +#[rstest] +#[awt] +async fn use_two_args_mix_fixture_inject_both( + #[future] + #[with(async { 3 }, 1)] + two_args_mix_fixture: u32, +) { + assert_eq!(31, two_args_mix_fixture); +} diff --git a/rstest/tests/resources/fixture/await_partial_fixture.rs b/rstest/tests/resources/fixture/await_partial_fixture.rs new file mode 100644 index 0000000..9e26f4a --- /dev/null +++ b/rstest/tests/resources/fixture/await_partial_fixture.rs @@ -0,0 +1,94 @@ +use std::io::prelude::*; + +use rstest::*; + +#[fixture] +async fn async_u32() -> u32 { + 42 +} + +#[fixture] +async fn nest_fixture(#[future(awt)] async_u32: u32) -> u32 { + async_u32 +} + +#[fixture] +async fn nest_fixture_with_default( + #[future(awt)] + #[default(async { 42 })] + fortytwo: u32, +) -> u32 { + fortytwo +} + +#[rstest] +async fn default_is_async() { + assert_eq!(42, async_u32::default().await); +} + +#[rstest] +async fn use_async_nest_fixture_default(#[future(awt)] nest_fixture: u32) { + assert_eq!(42, nest_fixture); +} + +#[rstest] +async fn use_async_nest_fixture_injected( + #[future(awt)] + #[with(async { 24 })] + nest_fixture: u32, +) { + assert_eq!(24, nest_fixture); +} + +#[rstest] +async fn use_async_nest_fixture_with_default(#[future(awt)] nest_fixture_with_default: u32) { + assert_eq!(42, nest_fixture_with_default); +} + +#[rstest] +async fn use_async_fixture(#[future(awt)] async_u32: u32) { + assert_eq!(42, async_u32); +} + +#[fixture] +async fn async_impl_output() -> impl Read { + std::io::Cursor::new(vec![1, 2, 3, 4, 5]) +} + +#[rstest] +async fn use_async_impl_output(#[future(awt)] async_impl_output: T) { + let reader = async_impl_output; +} + +#[fixture] +async fn two_args_mix_fixture( + #[future(awt)] + #[default(async { 4 })] + four: u32, + #[default(2)] two: u32, +) -> u32 { + four * 10 + two +} + +#[rstest] +async fn use_two_args_mix_fixture(#[future(awt)] two_args_mix_fixture: u32) { + assert_eq!(42, two_args_mix_fixture); +} + +#[rstest] +async fn use_two_args_mix_fixture_inject_first( + #[future(awt)] + #[with(async { 5 })] + two_args_mix_fixture: u32, +) { + assert_eq!(52, two_args_mix_fixture); +} + +#[rstest] +async fn use_two_args_mix_fixture_inject_both( + #[future(awt)] + #[with(async { 3 }, 1)] + two_args_mix_fixture: u32, +) { + assert_eq!(31, two_args_mix_fixture); +} diff --git a/rstest/tests/resources/fixture/no_warning.rs b/rstest/tests/resources/fixture/no_warning.rs index 5fe1f7e..b6aa924 100644 --- a/rstest/tests/resources/fixture/no_warning.rs +++ b/rstest/tests/resources/fixture/no_warning.rs @@ -1,10 +1,17 @@ use rstest::*; #[fixture] -fn val() -> i32 { 21 } +fn val() -> i32 { + 21 +} #[fixture] -fn fortytwo(mut val: i32) -> i32 { val *= 2; val } +fn fortytwo(mut val: i32) -> i32 { + val *= 2; + val +} #[rstest] -fn the_test(fortytwo: i32) { assert_eq!(fortytwo, 42); } \ No newline at end of file +fn the_test(fortytwo: i32) { + assert_eq!(fortytwo, 42); +} diff --git a/rstest/tests/resources/rstest/cases/async_awt.rs b/rstest/tests/resources/rstest/cases/async_awt.rs new file mode 100644 index 0000000..fcd9de4 --- /dev/null +++ b/rstest/tests/resources/rstest/cases/async_awt.rs @@ -0,0 +1,28 @@ +use rstest::*; + +#[rstest] +#[case::pass(42, async { 42 })] +#[case::fail(42, async { 41 })] +#[should_panic] +#[case::pass_panic(42, async { 41 })] +#[should_panic] +#[case::fail_panic(42, async { 42 })] +async fn my_async_test( + #[case] expected: u32, + #[case] + #[future(awt)] + value: u32, +) { + assert_eq!(expected, value); +} + +#[rstest] +#[case::pass(42, async { 42 })] +async fn my_async_test_revert( + #[case] expected: u32, + #[future(awt)] + #[case] + value: u32, +) { + assert_eq!(expected, value); +} diff --git a/rstest/tests/resources/rstest/cases/async_awt_global.rs b/rstest/tests/resources/rstest/cases/async_awt_global.rs new file mode 100644 index 0000000..ce88204 --- /dev/null +++ b/rstest/tests/resources/rstest/cases/async_awt_global.rs @@ -0,0 +1,30 @@ +use rstest::*; + +#[rstest] +#[case::pass(42, async { 42 })] +#[case::fail(42, async { 41 })] +#[should_panic] +#[case::pass_panic(42, async { 41 })] +#[should_panic] +#[case::fail_panic(42, async { 42 })] +#[awt] +async fn my_async_test( + #[case] expected: u32, + #[case] + #[future] + value: u32, +) { + assert_eq!(expected, value); +} + +#[rstest] +#[case::pass(42, async { 42 })] +#[awt] +async fn my_async_test_revert( + #[case] expected: u32, + #[future] + #[case] + value: u32, +) { + assert_eq!(expected, value); +} diff --git a/rstest/tests/resources/rstest/matrix/async_awt.rs b/rstest/tests/resources/rstest/matrix/async_awt.rs new file mode 100644 index 0000000..f9e8819 --- /dev/null +++ b/rstest/tests/resources/rstest/matrix/async_awt.rs @@ -0,0 +1,12 @@ +use rstest::*; + +#[rstest] +async fn my_async_test( + #[future(awt)] + #[values(async { 1 }, async { 2 })] + first: u32, + #[values(42, 21)] + second: u32 +) { + assert_eq!(42, first * second); +} \ No newline at end of file diff --git a/rstest/tests/resources/rstest/matrix/async_awt_global.rs b/rstest/tests/resources/rstest/matrix/async_awt_global.rs new file mode 100644 index 0000000..780d8fa --- /dev/null +++ b/rstest/tests/resources/rstest/matrix/async_awt_global.rs @@ -0,0 +1,13 @@ +use rstest::*; + +#[rstest] +#[awt] +async fn my_async_test( + #[future] + #[values(async { 1 }, async { 2 })] + first: u32, + #[values(42, 21)] + second: u32 +) { + assert_eq!(42, first * second); +} \ No newline at end of file diff --git a/rstest/tests/resources/rstest/single/async_awt.rs b/rstest/tests/resources/rstest/single/async_awt.rs new file mode 100644 index 0000000..9d39236 --- /dev/null +++ b/rstest/tests/resources/rstest/single/async_awt.rs @@ -0,0 +1,28 @@ +use rstest::*; + +#[fixture] +async fn fixture() -> u32 { + 42 +} + +#[rstest] +async fn should_pass(#[future(awt)] fixture: u32) { + assert_eq!(fixture, 42); +} + +#[rstest] +async fn should_fail(#[future(awt)] fixture: u32) { + assert_ne!(fixture, 42); +} + +#[rstest] +#[should_panic] +async fn should_panic_pass(#[future(awt)] fixture: u32) { + panic!(format!("My panic -> fixture = {}", fixture)); +} + +#[rstest] +#[should_panic] +async fn should_panic_fail(#[future(awt)] fixture: u32) { + assert_eq!(fixture, 42); +} diff --git a/rstest/tests/resources/rstest/single/async_awt_global.rs b/rstest/tests/resources/rstest/single/async_awt_global.rs new file mode 100644 index 0000000..4279461 --- /dev/null +++ b/rstest/tests/resources/rstest/single/async_awt_global.rs @@ -0,0 +1,32 @@ +use rstest::*; + +#[fixture] +async fn fixture() -> u32 { + 42 +} + +#[rstest] +#[awt] +async fn should_pass(#[future] fixture: u32) { + assert_eq!(fixture, 42); +} + +#[rstest] +#[awt] +async fn should_fail(#[future] fixture: u32) { + assert_ne!(fixture, 42); +} + +#[rstest] +#[awt] +#[should_panic] +async fn should_panic_pass(#[future] fixture: u32) { + panic!(format!("My panic -> fixture = {}", fixture)); +} + +#[rstest] +#[awt] +#[should_panic] +async fn should_panic_fail(#[future] fixture: u32) { + assert_eq!(fixture, 42); +} diff --git a/rstest/tests/resources/rstest/timeout.rs b/rstest/tests/resources/rstest/timeout.rs index c857af6..3c0aa7d 100644 --- a/rstest/tests/resources/rstest/timeout.rs +++ b/rstest/tests/resources/rstest/timeout.rs @@ -64,7 +64,7 @@ mod thread { #[case::fail_timeout(ms(80), 4)] #[case::fail_value(ms(1), 5)] #[timeout(ms(40))] - fn group_same_timeout(#[case] delay: Duration,#[case] expected: u32) { + fn group_same_timeout(#[case] delay: Duration, #[case] expected: u32) { assert_eq!(expected, delayed_sum(2, 2, delay)); } @@ -75,7 +75,7 @@ mod thread { #[case::fail_timeout(ms(70), 4)] #[timeout(ms(100))] #[case::fail_value(ms(1), 5)] - fn group_single_timeout(#[case] delay: Duration,#[case] expected: u32) { + fn group_single_timeout(#[case] delay: Duration, #[case] expected: u32) { assert_eq!(expected, delayed_sum(2, 2, delay)); } @@ -85,10 +85,10 @@ mod thread { #[case::fail_timeout(ms(60), 4)] #[case::fail_value(ms(1), 5)] #[timeout(ms(100))] - fn group_one_timeout_override(#[case] delay: Duration,#[case] expected: u32) { + fn group_one_timeout_override(#[case] delay: Duration, #[case] expected: u32) { assert_eq!(expected, delayed_sum(2, 2, delay)); } - + struct S {} #[rstest] @@ -99,20 +99,19 @@ mod thread { #[fixture] fn no_copy() -> S { - S{} + S {} } #[rstest] fn compile_with_no_copy_fixture(no_copy: S) { assert!(true); } - } mod async_std_cases { use super::*; - async fn delayed_sum(a: u32, b: u32,delay: Duration) -> u32 { + async fn delayed_sum(a: u32, b: u32, delay: Duration) -> u32 { async_std::task::sleep(delay).await; a + b } @@ -139,7 +138,7 @@ mod async_std_cases { #[timeout(ms(1000))] #[should_panic = "user message"] async fn fail_with_user_message() { - panic!{"user message"}; + panic! {"user message"}; } #[rstest] @@ -202,8 +201,8 @@ mod async_std_cases { } #[fixture] - fn no_copy() -> S{ - S{} + fn no_copy() -> S { + S {} } #[rstest] @@ -212,12 +211,17 @@ mod async_std_cases { } #[fixture] - async fn a_fix() -> S{ - S{} + async fn a_fix() -> S { + S {} } #[rstest] fn compile_with_async_fixture(#[future] a_fix: S) { assert!(true); } -} \ No newline at end of file + + #[rstest] + async fn compile_with_async_awt_fixture(#[future(awt)] a_fix: S) { + assert!(true); + } +} diff --git a/rstest/tests/rstest/mod.rs b/rstest/tests/rstest/mod.rs index d8f9499..945c18a 100644 --- a/rstest/tests/rstest/mod.rs +++ b/rstest/tests/rstest/mod.rs @@ -326,9 +326,12 @@ mod single { .assert(output); } - #[test] - fn should_run_async_function() { - let prj = prj(res("async.rs")); + #[rstest] + #[case("async.rs")] + #[case("async_awt.rs")] + #[case("async_awt_global.rs")] + fn should_run_async_function(#[case] name: &str) { + let prj = prj(res(name)); prj.add_dependency("async-std", r#"{version="*", features=["attributes"]}"#); let output = prj.run_tests().unwrap(); @@ -456,9 +459,12 @@ mod cases { .assert(output); } - #[test] - fn should_run_async_function() { - let prj = prj(res("async.rs")); + #[rstest] + #[case("async.rs")] + #[case("async_awt.rs")] + #[case("async_awt_global.rs")] + fn should_run_async_function(#[case] name: &str) { + let prj = prj(res(name)); prj.add_dependency("async-std", r#"{version="*", features=["attributes"]}"#); let output = prj.run_tests().unwrap(); @@ -775,9 +781,12 @@ mod matrix { .assert(output); } - #[test] - fn should_run_async_function() { - let prj = prj(res("async.rs")); + #[rstest] + #[case("async.rs")] + #[case("async_awt.rs")] + #[case("async_awt_global.rs")] + fn should_run_async_function(#[case] name: &str) { + let prj = prj(res(name)); prj.add_dependency("async-std", r#"{version="*", features=["attributes"]}"#); let output = prj.run_tests().unwrap(); @@ -950,6 +959,7 @@ fn timeout() { .ok("async_std_cases::compile_with_no_copy_arg::case_1") .ok("async_std_cases::compile_with_no_copy_fixture") .ok("async_std_cases::compile_with_async_fixture") + .ok("async_std_cases::compile_with_async_awt_fixture") .assert(output); } diff --git a/rstest_macros/Cargo.toml b/rstest_macros/Cargo.toml index fba39ea..137a103 100644 --- a/rstest_macros/Cargo.toml +++ b/rstest_macros/Cargo.toml @@ -11,7 +11,7 @@ keywords = ["test", "fixture"] license = "MIT/Apache-2.0" name = "rstest_macros" repository = "https://github.com/la10736/rstest" -version = "0.16.0" +version = "0.17.0" [lib] proc-macro = true diff --git a/rstest_macros/src/lib.rs b/rstest_macros/src/lib.rs index 20d8d98..5f6c5a4 100644 --- a/rstest_macros/src/lib.rs +++ b/rstest_macros/src/lib.rs @@ -17,7 +17,7 @@ mod utils; use syn::{parse_macro_input, ItemFn}; -use crate::parse::{fixture::FixtureInfo, future::ReplaceFutureAttribute, rstest::RsTestInfo}; +use crate::parse::{fixture::FixtureInfo, rstest::RsTestInfo}; use parse::ExtendWithFunctionAttrs; use quote::ToTokens; @@ -314,14 +314,10 @@ pub fn fixture( let mut info: FixtureInfo = parse_macro_input!(args as FixtureInfo); let mut fixture = parse_macro_input!(input as ItemFn); - let replace_result = ReplaceFutureAttribute::replace(&mut fixture); let extend_result = info.extend_with_function_attrs(&mut fixture); let mut errors = error::fixture(&fixture, &info); - if let Err(attrs_errors) = replace_result { - attrs_errors.to_tokens(&mut errors); - } if let Err(attrs_errors) = extend_result { attrs_errors.to_tokens(&mut errors); } @@ -731,6 +727,33 @@ pub fn fixture( /// assert_eq!(expected, base.await / div.await); /// } /// ``` +/// +/// As you noted you should `.await` all _future_ values and this some times can be really boring. +/// In this case you can use `#[timeout(awt)]` to _awaiting_ an input or annotating your function +/// with `#[awt]` attributes to globally `.await` all your _future_ inputs. Previous code can be +/// simplified like follow: +/// +/// ``` +/// use rstest::*; +/// # #[fixture] +/// # async fn base() -> u32 { 42 } +/// +/// #[rstest] +/// #[case(21, async { 2 })] +/// #[case(6, async { 7 })] +/// #[awt] +/// async fn global(#[future] base: u32, #[case] expected: u32, #[future] #[case] div: u32) { +/// assert_eq!(expected, base / div); +/// } +/// +/// #[rstest] +/// #[case(21, async { 2 })] +/// #[case(6, async { 7 })] +/// async fn single(#[future] base: u32, #[case] expected: u32, #[future(awt)] #[case] div: u32) { +/// assert_eq!(expected, base.await / div); +/// } +/// ``` +/// /// ### Test `#[timeout()]` /// /// You can define an execution timeout for your tests with `#[timeout()]` attribute. Timeouts @@ -1032,14 +1055,10 @@ pub fn rstest( let mut test = parse_macro_input!(input as ItemFn); let mut info = parse_macro_input!(args as RsTestInfo); - let replace_result = ReplaceFutureAttribute::replace(&mut test); let extend_result = info.extend_with_function_attrs(&mut test); let mut errors = error::rstest(&test, &info); - if let Err(attrs_errors) = replace_result { - attrs_errors.to_tokens(&mut errors); - } if let Err(attrs_errors) = extend_result { attrs_errors.to_tokens(&mut errors); } diff --git a/rstest_macros/src/parse/fixture.rs b/rstest_macros/src/parse/fixture.rs index 793e066..55b0df8 100644 --- a/rstest_macros/src/parse/fixture.rs +++ b/rstest_macros/src/parse/fixture.rs @@ -7,9 +7,9 @@ use syn::{ }; use super::{ - extract_argument_attrs, extract_default_return_type, extract_defaults, extract_fixtures, - extract_partials_return_type, parse_vector_trailing_till_double_comma, Attributes, - ExtendWithFunctionAttrs, Fixture, + arguments::ArgumentsInfo, extract_argument_attrs, extract_default_return_type, + extract_defaults, extract_fixtures, extract_partials_return_type, future::extract_futures, + parse_vector_trailing_till_double_comma, Attributes, ExtendWithFunctionAttrs, Fixture, }; use crate::{ error::ErrorsVec, @@ -25,6 +25,7 @@ use quote::{format_ident, ToTokens}; pub(crate) struct FixtureInfo { pub(crate) data: FixtureData, pub(crate) attributes: FixtureModifiers, + pub(crate) arguments: ArgumentsInfo, } impl Parse for FixtureModifiers { @@ -44,6 +45,7 @@ impl Parse for FixtureInfo { .parse::() .or_else(|_| Ok(Default::default())) .and_then(|_| input.parse())?, + arguments: Default::default(), } }) } @@ -59,13 +61,15 @@ impl ExtendWithFunctionAttrs for FixtureInfo { defaults, default_return_type, partials_return_type, - once + once, + futures ) = merge_errors!( extract_fixtures(item_fn), extract_defaults(item_fn), extract_default_return_type(item_fn), extract_partials_return_type(item_fn), - extract_once(item_fn) + extract_once(item_fn), + extract_futures(item_fn) )?; self.data.items.extend( fixtures @@ -82,6 +86,9 @@ impl ExtendWithFunctionAttrs for FixtureInfo { if let Some(ident) = once { self.attributes.set_once(ident) }; + let (futures, global_awt) = futures; + self.arguments.set_global_await(global_awt); + self.arguments.set_futures(futures.into_iter()); Ok(()) } } @@ -352,6 +359,7 @@ mod should { ], } .into(), + arguments: Default::default(), }; assert_eq!(expected, data); @@ -589,6 +597,20 @@ mod extend { assert!(!info.attributes.is_once()); } + #[rstest] + fn extract_future() { + let mut item_fn = "fn f(#[future] a: u32, b: u32) {}".ast(); + let expected = "fn f(a: u32, b: u32) {}".ast(); + + let mut info = FixtureInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!(item_fn, expected); + assert!(info.arguments.is_future(&ident("a"))); + assert!(!info.arguments.is_future(&ident("b"))); + } + mod raise_error { use super::{assert_eq, *}; use rstest_test::assert_in; diff --git a/rstest_macros/src/parse/future.rs b/rstest_macros/src/parse/future.rs index 6e74f5a..d23de04 100644 --- a/rstest_macros/src/parse/future.rs +++ b/rstest_macros/src/parse/future.rs @@ -1,102 +1,143 @@ use quote::{format_ident, ToTokens}; -use syn::{parse_quote, visit_mut::VisitMut, FnArg, ItemFn, Lifetime}; +use syn::{visit_mut::VisitMut, FnArg, Ident, ItemFn, PatType, Type}; -use crate::{error::ErrorsVec, refident::MaybeIdent, utils::attr_is}; +use crate::{error::ErrorsVec, refident::MaybeType, utils::attr_is}; -#[derive(Default)] -pub(crate) struct ReplaceFutureAttribute { - lifetimes: Vec, - errors: Vec, +use super::{arguments::FutureArg, extract_argument_attrs}; + +pub(crate) fn extract_futures( + item_fn: &mut ItemFn, +) -> Result<(Vec<(Ident, FutureArg)>, bool), ErrorsVec> { + let mut extractor = FutureFunctionExtractor::default(); + extractor.visit_item_fn_mut(item_fn); + extractor.take() } -fn extend_generics_with_lifetimes<'a, 'b>( - generics: impl Iterator, - lifetimes: impl Iterator, -) -> syn::Generics { - let all = lifetimes - .map(|lt| lt as &dyn ToTokens) - .chain(generics.map(|gp| gp as &dyn ToTokens)); - parse_quote! { - <#(#all),*> - } +pub(crate) trait MaybeFutureImplType { + fn as_future_impl_type(&self) -> Option<&Type>; + + fn as_mut_future_impl_type(&mut self) -> Option<&mut Type>; } -impl ReplaceFutureAttribute { - pub(crate) fn replace(item_fn: &mut ItemFn) -> Result<(), ErrorsVec> { - let mut visitor = Self::default(); - visitor.visit_item_fn_mut(item_fn); - if !visitor.lifetimes.is_empty() { - item_fn.sig.generics = extend_generics_with_lifetimes( - item_fn.sig.generics.params.iter(), - visitor.lifetimes.iter(), - ); +impl MaybeFutureImplType for FnArg { + fn as_future_impl_type(&self) -> Option<&Type> { + match self { + FnArg::Typed(PatType { ty, .. }) if can_impl_future(ty.as_ref()) => Some(ty.as_ref()), + _ => None, } - if visitor.errors.is_empty() { - Ok(()) - } else { - Err(visitor.errors.into()) + } + + fn as_mut_future_impl_type(&mut self) -> Option<&mut Type> { + match self { + FnArg::Typed(PatType { ty, .. }) if can_impl_future(ty.as_ref()) => Some(ty.as_mut()), + _ => None, } } } -fn extract_arg_attributes( - node: &mut syn::PatType, - predicate: fn(a: &syn::Attribute) -> bool, -) -> Vec { - let attrs = std::mem::take(&mut node.attrs); - let (extracted, attrs): (Vec<_>, Vec<_>) = attrs.into_iter().partition(predicate); - node.attrs = attrs; - extracted +fn can_impl_future(ty: &Type) -> bool { + use Type::*; + !matches!( + ty, + Group(_) + | ImplTrait(_) + | Infer(_) + | Macro(_) + | Never(_) + | Slice(_) + | TraitObject(_) + | Verbatim(_) + ) +} + +/// Simple struct used to visit function attributes and extract future args to +/// implement the boilerplate. +#[derive(Default)] +struct FutureFunctionExtractor { + futures: Vec<(Ident, FutureArg)>, + awt: bool, + errors: Vec, +} + +impl FutureFunctionExtractor { + pub(crate) fn take(self) -> Result<(Vec<(Ident, FutureArg)>, bool), ErrorsVec> { + if self.errors.is_empty() { + Ok((self.futures, self.awt)) + } else { + Err(self.errors.into()) + } + } } -impl VisitMut for ReplaceFutureAttribute { +impl VisitMut for FutureFunctionExtractor { + fn visit_item_fn_mut(&mut self, node: &mut ItemFn) { + let attrs = std::mem::take(&mut node.attrs); + let (awts, remain): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|a| attr_is(a, "awt")); + if awts.len() == 1 { + self.awt = true; + } else if awts.len() > 1 { + self.errors.extend(awts.into_iter().skip(1).map(|a| { + syn::Error::new_spanned( + a.into_token_stream(), + "Cannot use #[awt] more than once.".to_owned(), + ) + })) + } + node.attrs = remain; + syn::visit_mut::visit_item_fn_mut(self, node); + } + fn visit_fn_arg_mut(&mut self, node: &mut FnArg) { - let ident = node.maybe_ident().cloned(); - match node { - FnArg::Typed(t) => { - let futures = extract_arg_attributes(t, |a| attr_is(a, "future")); - if futures.is_empty() { - return; - } else if futures.len() > 1 { - self.errors.extend(futures.iter().skip(1).map(|attr| { - syn::Error::new_spanned( - attr.into_token_stream(), - "Cannot use #[future] more than once.".to_owned(), - ) - })); - return; - } - let ty = &mut t.ty; - use syn::Type::*; - match ty.as_ref() { - Group(_) | ImplTrait(_) | Infer(_) | Macro(_) | Never(_) | Slice(_) - | TraitObject(_) | Verbatim(_) => { - self.errors.push(syn::Error::new_spanned( - ty.into_token_stream(), - "This type cannot used to generete impl Future.".to_owned(), - )); - return; + if matches!(node, FnArg::Receiver(_)) { + return; + } + match extract_argument_attrs( + node, + |a| attr_is(a, "future"), + |arg, name| { + let kind = if arg.tokens.is_empty() { + FutureArg::Define + } else { + match arg.parse_args::>()? { + Some(awt) if awt == format_ident!("awt") => FutureArg::Await, + None => FutureArg::Define, + Some(invalid) => { + return Err(syn::Error::new_spanned( + arg.parse_args::>()?.into_token_stream(), + format!("Invalid '{}' #[future(...)] arg.", invalid), + )); + } } - _ => {} }; - if let Reference(tr) = ty.as_mut() { - let ident = ident.unwrap(); - if tr.lifetime.is_none() { - let lifetime = syn::Lifetime { - apostrophe: ident.span(), - ident: format_ident!("_{}", ident), - }; - self.lifetimes.push(lifetime.clone()); - tr.lifetime = lifetime.into(); + Ok((arg, name.clone(), kind)) + }, + ) + .collect::, _>>() + { + Ok(futures) => { + if futures.len() > 1 { + self.errors + .extend(futures.iter().skip(1).map(|(attr, _ident, _type)| { + syn::Error::new_spanned( + attr.into_token_stream(), + "Cannot use #[future] more than once.".to_owned(), + ) + })); + return; + } else if futures.len() == 1 { + match node.as_future_impl_type() { + Some(_) => self.futures.push((futures[0].1.clone(), futures[0].2)), + None => self.errors.push(syn::Error::new_spanned( + node.maybe_type().unwrap().into_token_stream(), + "This type cannot used to generate impl Future.".to_owned(), + )), } } - - t.ty = parse_quote! { - impl std::future::Future - } } - FnArg::Receiver(_) => {} - } + Err(e) => { + self.errors.push(e); + } + }; } } @@ -115,70 +156,102 @@ mod should { let mut item_fn: ItemFn = item_fn.ast(); let orig = item_fn.clone(); - ReplaceFutureAttribute::replace(&mut item_fn).unwrap(); + let (futures, awt) = extract_futures(&mut item_fn).unwrap(); - assert_eq!(orig, item_fn) + assert_eq!(orig, item_fn); + assert!(futures.is_empty()); + assert!(!awt); } #[rstest] - #[case::simple( - "fn f(#[future] a: u32) {}", - "fn f(a: impl std::future::Future) {}" - )] + #[case::simple("fn f(#[future] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Define)], false)] + #[case::global_awt("#[awt] fn f(a: u32) {}", "fn f(a: u32) {}", &[], true)] + #[case::simple_awaited("fn f(#[future(awt)] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Await)], false)] + #[case::simple_awaited_and_global("#[awt] fn f(#[future(awt)] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Await)], true)] #[case::more_than_one( - "fn f(#[future] a: u32, #[future] b: String, #[future] c: std::collection::HashMap) {}", - r#"fn f(a: impl std::future::Future, - b: impl std::future::Future, - c: impl std::future::Future>) {}"#, + "fn f(#[future] a: u32, #[future(awt)] b: String, #[future()] c: std::collection::HashMap) {}", + r#"fn f(a: u32, + b: String, + c: std::collection::HashMap) {}"#, + &[("a", FutureArg::Define), ("b", FutureArg::Await), ("c", FutureArg::Define)], + false, )] #[case::just_one( "fn f(a: u32, #[future] b: String) {}", - r#"fn f(a: u32, - b: impl std::future::Future) {}"# + r#"fn f(a: u32, b: String) {}"#, + &[("b", FutureArg::Define)], + false, )] - #[case::generics( - "fn f>(#[future] a: S) {}", - "fn f>(a: impl std::future::Future) {}" + #[case::just_one_awaited( + "fn f(a: u32, #[future(awt)] b: String) {}", + r#"fn f(a: u32, b: String) {}"#, + &[("b", FutureArg::Await)], + false, )] - fn replace_basic_type(#[case] item_fn: &str, #[case] expected: &str) { + fn extract( + #[case] item_fn: &str, + #[case] expected: &str, + #[case] expected_futures: &[(&str, FutureArg)], + #[case] expected_awt: bool, + ) { let mut item_fn: ItemFn = item_fn.ast(); let expected: ItemFn = expected.ast(); - ReplaceFutureAttribute::replace(&mut item_fn).unwrap(); + let (futures, awt) = extract_futures(&mut item_fn).unwrap(); - assert_eq!(expected, item_fn) + assert_eq!(expected, item_fn); + assert_eq!( + futures, + expected_futures + .into_iter() + .map(|(id, a)| (ident(id), *a)) + .collect::>() + ); + assert_eq!(expected_awt, awt); } #[rstest] - #[case::base( - "fn f(#[future] ident_name: &u32) {}", - "fn f<'_ident_name>(ident_name: impl std::future::Future) {}" - )] - #[case::lifetime_already_exists( - "fn f<'b>(#[future] a: &'b u32) {}", - "fn f<'b>(a: impl std::future::Future) {}" + #[case::base(r#"#[awt] fn f(a: u32) {}"#, r#"fn f(a: u32) {}"#)] + #[case::two( + r#" + #[awt] + #[awt] + fn f(a: u32) {} + "#, + r#"fn f(a: u32) {}"# )] - #[case::some_other_generics( - "fn f<'b, IT: Iterator>(#[future] a: &u32, it: IT) {}", - "fn f<'_a, 'b, IT: Iterator>(a: impl std::future::Future, it: IT) {}" + #[case::inner( + r#" + #[one] + #[awt] + #[two] + fn f(a: u32) {} + "#, + r#" + #[one] + #[two] + fn f(a: u32) {} + "# )] - fn replace_reference_type(#[case] item_fn: &str, #[case] expected: &str) { + fn remove_all_awt_attributes(#[case] item_fn: &str, #[case] expected: &str) { let mut item_fn: ItemFn = item_fn.ast(); let expected: ItemFn = expected.ast(); - ReplaceFutureAttribute::replace(&mut item_fn).unwrap(); + let _ = extract_futures(&mut item_fn); - assert_eq!(expected, item_fn) + assert_eq!(item_fn, expected); } #[rstest] #[case::no_more_than_one("fn f(#[future] #[future] a: u32) {}", "more than once")] - #[case::no_impl("fn f(#[future] a: impl AsRef) {}", "generete impl Future")] - #[case::no_slice("fn f(#[future] a: [i32]) {}", "generete impl Future")] + #[case::no_impl("fn f(#[future] a: impl AsRef) {}", "generate impl Future")] + #[case::no_slice("fn f(#[future] a: [i32]) {}", "generate impl Future")] + #[case::invalid_arg("fn f(#[future(other)] a: [i32]) {}", "Invalid 'other'")] + #[case::no_more_than_one_awt("#[awt] #[awt] fn f(a: u32) {}", "more than once")] fn raise_error(#[case] item_fn: &str, #[case] message: &str) { let mut item_fn: ItemFn = item_fn.ast(); - let err = ReplaceFutureAttribute::replace(&mut item_fn).unwrap_err(); + let err = extract_futures(&mut item_fn).unwrap_err(); assert_in!(format!("{:?}", err), message); } diff --git a/rstest_macros/src/parse/mod.rs b/rstest_macros/src/parse/mod.rs index b216da2..3bb70ec 100644 --- a/rstest_macros/src/parse/mod.rs +++ b/rstest_macros/src/parse/mod.rs @@ -230,7 +230,7 @@ pub(crate) fn extract_once(item_fn: &mut ItemFn) -> Result, Errors extractor.take() } -fn extract_argument_attrs<'a, B: 'a + std::fmt::Debug>( +pub(crate) fn extract_argument_attrs<'a, B: 'a + std::fmt::Debug>( node: &mut FnArg, is_valid_attr: fn(&syn::Attribute) -> bool, build: fn(syn::Attribute, &Ident) -> syn::Result, @@ -628,6 +628,139 @@ pub(crate) fn check_timeout_attrs(item_fn: &mut ItemFn) -> Result<(), ErrorsVec> checker.take() } +pub(crate) mod arguments { + use std::collections::HashMap; + + use syn::Ident; + + #[derive(PartialEq, Debug, Clone, Copy)] + #[allow(dead_code)] + pub(crate) enum FutureArg { + None, + Define, + Await, + } + + impl Default for FutureArg { + fn default() -> Self { + FutureArg::None + } + } + + #[derive(PartialEq, Default, Debug)] + pub(crate) struct ArgumentInfo { + future: FutureArg, + } + + impl ArgumentInfo { + fn future(future: FutureArg) -> Self { + Self { + future, + ..Default::default() + } + } + + fn is_future(&self) -> bool { + use FutureArg::*; + + matches!(self.future, Define | Await) + } + + fn is_future_await(&self) -> bool { + use FutureArg::*; + + matches!(self.future, Await) + } + } + + #[derive(PartialEq, Default, Debug)] + pub(crate) struct ArgumentsInfo { + args: HashMap, + is_global_await: bool, + } + + impl ArgumentsInfo { + pub(crate) fn set_future(&mut self, ident: Ident, kind: FutureArg) { + self.args + .entry(ident) + .and_modify(|v| v.future = kind) + .or_insert_with(|| ArgumentInfo::future(kind)); + } + + pub(crate) fn set_futures(&mut self, futures: impl Iterator) { + futures.for_each(|(ident, k)| self.set_future(ident, k)); + } + + pub(crate) fn set_global_await(&mut self, is_global_await: bool) { + self.is_global_await = is_global_await; + } + + #[allow(dead_code)] + pub(crate) fn add_future(&mut self, ident: Ident) { + self.set_future(ident, FutureArg::Define); + } + + pub(crate) fn is_future(&self, id: &Ident) -> bool { + self.args + .get(id) + .map(|arg| arg.is_future()) + .unwrap_or_default() + } + + pub(crate) fn is_future_await(&self, ident: &Ident) -> bool { + match self.args.get(ident) { + Some(arg) => arg.is_future_await() || (arg.is_future() && self.is_global_await()), + None => false, + } + } + + pub(crate) fn is_global_await(&self) -> bool { + self.is_global_await + } + } + + #[cfg(test)] + mod should_implement_is_future_await_logic { + use super::*; + use crate::test::*; + + #[fixture] + fn info() -> ArgumentsInfo { + let mut a = ArgumentsInfo::default(); + a.set_future(ident("simple"), FutureArg::Define); + a.set_future(ident("other_simple"), FutureArg::Define); + a.set_future(ident("awaited"), FutureArg::Await); + a.set_future(ident("other_awaited"), FutureArg::Await); + a.set_future(ident("none"), FutureArg::None); + a + } + + #[rstest] + fn no_matching_ident(info: ArgumentsInfo) { + assert!(!info.is_future_await(&ident("some"))); + assert!(!info.is_future_await(&ident("simple"))); + assert!(!info.is_future_await(&ident("none"))); + } + + #[rstest] + fn matching_ident(info: ArgumentsInfo) { + assert!(info.is_future_await(&ident("awaited"))); + assert!(info.is_future_await(&ident("other_awaited"))); + } + + #[rstest] + fn global_matching_future_ident(mut info: ArgumentsInfo) { + info.set_global_await(true); + assert!(info.is_future_await(&ident("simple"))); + assert!(info.is_future_await(&ident("other_simple"))); + assert!(info.is_future_await(&ident("awaited"))); + + assert!(!info.is_future_await(&ident("some"))); + assert!(!info.is_future_await(&ident("none"))); + } + } +} + #[cfg(test)] mod should { use super::*; diff --git a/rstest_macros/src/parse/rstest.rs b/rstest_macros/src/parse/rstest.rs index bb2051d..ef398b0 100644 --- a/rstest_macros/src/parse/rstest.rs +++ b/rstest_macros/src/parse/rstest.rs @@ -5,9 +5,10 @@ use syn::{ use super::testcase::TestCase; use super::{ - check_timeout_attrs, extract_case_args, extract_cases, extract_excluded_trace, - extract_fixtures, extract_value_list, parse_vector_trailing_till_double_comma, Attribute, - Attributes, ExtendWithFunctionAttrs, Fixture, + arguments::ArgumentsInfo, check_timeout_attrs, extract_case_args, extract_cases, + extract_excluded_trace, extract_fixtures, extract_value_list, future::extract_futures, + parse_vector_trailing_till_double_comma, Attribute, Attributes, ExtendWithFunctionAttrs, + Fixture, }; use crate::parse::vlist::ValueList; use crate::{ @@ -21,6 +22,7 @@ use quote::{format_ident, ToTokens}; pub(crate) struct RsTestInfo { pub(crate) data: RsTestData, pub(crate) attributes: RsTestAttributes, + pub(crate) arguments: ArgumentsInfo, } impl Parse for RsTestInfo { @@ -34,6 +36,7 @@ impl Parse for RsTestInfo { .parse::() .or_else(|_| Ok(Default::default())) .and_then(|_| input.parse())?, + arguments: Default::default(), } }) } @@ -41,12 +44,16 @@ impl Parse for RsTestInfo { impl ExtendWithFunctionAttrs for RsTestInfo { fn extend_with_function_attrs(&mut self, item_fn: &mut ItemFn) -> Result<(), ErrorsVec> { - let (_, (excluded, _)) = merge_errors!( + let composed_tuple!(_inner, excluded, _timeout, futures) = merge_errors!( self.data.extend_with_function_attrs(item_fn), extract_excluded_trace(item_fn), - check_timeout_attrs(item_fn) + check_timeout_attrs(item_fn), + extract_futures(item_fn) )?; + let (futures, global_awt) = futures; self.attributes.add_notraces(excluded); + self.arguments.set_global_await(global_awt); + self.arguments.set_futures(futures.into_iter()); Ok(()) } } @@ -362,6 +369,7 @@ mod test { ], } .into(), + ..Default::default() }; assert_eq!(expected, data); @@ -505,6 +513,20 @@ mod test { .unwrap(); assert_eq!(attrs("#[something_else]"), b_args); } + + #[rstest] + fn extract_future() { + let mut item_fn = "fn f(#[future] a: u32, b: u32) {}".ast(); + let expected = "fn f(a: u32, b: u32) {}".ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!(item_fn, expected); + assert!(info.arguments.is_future(&ident("a"))); + assert!(!info.arguments.is_future(&ident("b"))); + } } mod parametrize_cases { diff --git a/rstest_macros/src/render/apply_argumets.rs b/rstest_macros/src/render/apply_argumets.rs new file mode 100644 index 0000000..b32ae9c --- /dev/null +++ b/rstest_macros/src/render/apply_argumets.rs @@ -0,0 +1,249 @@ +use quote::{format_ident, ToTokens}; +use syn::{parse_quote, FnArg, Generics, Ident, ItemFn, Lifetime, Signature, Type, TypeReference}; + +use crate::{ + parse::{arguments::ArgumentsInfo, future::MaybeFutureImplType}, + refident::MaybeIdent, +}; + +pub(crate) trait ApplyArgumets { + fn apply_argumets(&mut self, arguments: &ArgumentsInfo) -> R; +} + +impl ApplyArgumets> for FnArg { + fn apply_argumets(&mut self, arguments: &ArgumentsInfo) -> Option { + if self + .maybe_ident() + .map(|id| arguments.is_future(id)) + .unwrap_or_default() + { + self.impl_future_arg() + } else { + None + } + } +} + +fn move_generic_list(data: &mut Generics, other: Generics) { + data.lt_token = data.lt_token.or(other.lt_token); + data.params = other.params; + data.gt_token = data.gt_token.or(other.gt_token); +} + +fn extend_generics_with_lifetimes<'a, 'b>( + generics: impl Iterator, + lifetimes: impl Iterator, +) -> Generics { + let all = lifetimes + .map(|lt| lt as &dyn ToTokens) + .chain(generics.map(|gp| gp as &dyn ToTokens)); + parse_quote! { + <#(#all),*> + } +} + +impl ApplyArgumets for Signature { + fn apply_argumets(&mut self, arguments: &ArgumentsInfo) { + let new_lifetimes = self + .inputs + .iter_mut() + .filter_map(|arg| arg.apply_argumets(arguments)) + .collect::>(); + if !new_lifetimes.is_empty() || !self.generics.params.is_empty() { + let new_generics = + extend_generics_with_lifetimes(self.generics.params.iter(), new_lifetimes.iter()); + move_generic_list(&mut self.generics, new_generics); + } + } +} + +impl ApplyArgumets for ItemFn { + fn apply_argumets(&mut self, arguments: &ArgumentsInfo) -> () { + let awaited_args = self + .sig + .inputs + .iter() + .filter_map(|a| a.maybe_ident()) + .filter(|&a| arguments.is_future_await(a)) + .cloned(); + let orig_block_impl = self.block.clone(); + self.block = parse_quote! { + { + #(let #awaited_args = #awaited_args.await;)* + #orig_block_impl + } + }; + self.sig.apply_argumets(arguments); + } +} + +pub(crate) trait ImplFutureArg { + fn impl_future_arg(&mut self) -> Option; +} + +impl ImplFutureArg for FnArg { + fn impl_future_arg(&mut self) -> Option { + let lifetime_id = self.maybe_ident().map(|id| format_ident!("_{}", id)); + match self.as_mut_future_impl_type() { + Some(ty) => { + let lifetime = lifetime_id.and_then(|id| update_type_with_lifetime(ty, id)); + *ty = parse_quote! { + impl std::future::Future + }; + lifetime + } + None => None, + } + } +} + +fn update_type_with_lifetime(ty: &mut Type, ident: Ident) -> Option { + if let Type::Reference(ty_ref @ TypeReference { lifetime: None, .. }) = ty { + let lifetime = Some(syn::Lifetime { + apostrophe: ident.span(), + ident, + }); + ty_ref.lifetime = lifetime.clone(); + lifetime + } else { + None + } +} + +#[cfg(test)] +mod should { + use super::*; + use crate::test::{assert_eq, *}; + use syn::ItemFn; + + #[rstest] + #[case("fn simple(a: u32) {}")] + #[case("fn more(a: u32, b: &str) {}")] + #[case("fn gen>(a: u32, b: S) {}")] + #[case("fn attr(#[case] a: u32, #[values(1,2)] b: i32) {}")] + fn no_change(#[case] item_fn: &str) { + let mut item_fn: ItemFn = item_fn.ast(); + let orig = item_fn.clone(); + + item_fn.sig.apply_argumets(&ArgumentsInfo::default()); + + assert_eq!(orig, item_fn) + } + + #[rstest] + #[case::simple( + "fn f(a: u32) {}", + &["a"], + "fn f(a: impl std::future::Future) {}" + )] + #[case::more_than_one( + "fn f(a: u32, b: String, c: std::collection::HashMap) {}", + &["a", "b", "c"], + r#"fn f(a: impl std::future::Future, + b: impl std::future::Future, + c: impl std::future::Future>) {}"#, + )] + #[case::just_one( + "fn f(a: u32, b: String) {}", + &["b"], + r#"fn f(a: u32, + b: impl std::future::Future) {}"# + )] + #[case::generics( + "fn f>(a: S) {}", + &["a"], + "fn f>(a: impl std::future::Future) {}" + )] + fn replace_future_basic_type( + #[case] item_fn: &str, + #[case] futures: &[&str], + #[case] expected: &str, + ) { + let mut item_fn: ItemFn = item_fn.ast(); + let expected: ItemFn = expected.ast(); + + let mut arguments = ArgumentsInfo::default(); + futures + .into_iter() + .for_each(|&f| arguments.add_future(ident(f))); + + item_fn.sig.apply_argumets(&arguments); + + assert_eq!(expected, item_fn) + } + + #[rstest] + #[case::base( + "fn f(ident_name: &u32) {}", + &["ident_name"], + "fn f<'_ident_name>(ident_name: impl std::future::Future) {}" + )] + #[case::lifetime_already_exists( + "fn f<'b>(a: &'b u32) {}", + &["a"], + "fn f<'b>(a: impl std::future::Future) {}" + )] + #[case::some_other_generics( + "fn f<'b, IT: Iterator>(a: &u32, it: IT) {}", + &["a"], + "fn f<'_a, 'b, IT: Iterator>(a: impl std::future::Future, it: IT) {}" + )] + fn replace_reference_type( + #[case] item_fn: &str, + #[case] futures: &[&str], + #[case] expected: &str, + ) { + let mut item_fn: ItemFn = item_fn.ast(); + let expected: ItemFn = expected.ast(); + + let mut arguments = ArgumentsInfo::default(); + futures + .into_iter() + .for_each(|&f| arguments.add_future(ident(f))); + + item_fn.sig.apply_argumets(&arguments); + + assert_eq!(expected, item_fn) + } + + mod await_future_args { + use rstest_test::{assert_in, assert_not_in}; + + use crate::parse::arguments::FutureArg; + + use super::*; + + #[test] + fn with_global_await() { + let mut item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut arguments: ArgumentsInfo = Default::default(); + arguments.set_global_await(true); + arguments.add_future(ident("a")); + arguments.add_future(ident("b")); + + item_fn.apply_argumets(&arguments); + + let code = item_fn.block.display_code(); + + assert_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + + #[test] + fn with_selective_await() { + let mut item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut arguments: ArgumentsInfo = Default::default(); + arguments.set_future(ident("a"), FutureArg::Define); + arguments.set_future(ident("b"), FutureArg::Await); + + item_fn.apply_argumets(&arguments); + + let code = item_fn.block.display_code(); + + assert_not_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + } +} diff --git a/rstest_macros/src/render/fixture.rs b/rstest_macros/src/render/fixture.rs index d369978..c1fde57 100644 --- a/rstest_macros/src/render/fixture.rs +++ b/rstest_macros/src/render/fixture.rs @@ -3,6 +3,7 @@ use syn::{parse_quote, Ident, ItemFn, ReturnType}; use quote::quote; +use super::apply_argumets::ApplyArgumets; use super::{inject, render_exec_call}; use crate::resolver::{self, Resolver}; use crate::utils::{fn_args, fn_args_idents}; @@ -32,7 +33,8 @@ fn wrap_call_impl_with_call_once_impl(call_impl: TokenStream, rt: &ReturnType) - } } -pub(crate) fn render(fixture: ItemFn, info: FixtureInfo) -> TokenStream { +pub(crate) fn render(mut fixture: ItemFn, info: FixtureInfo) -> TokenStream { + fixture.apply_argumets(&info.arguments); let name = &fixture.sig.ident; let asyncness = &fixture.sig.asyncness.clone(); let vargs = fn_args_idents(&fixture).cloned().collect::>(); @@ -60,6 +62,7 @@ pub(crate) fn render(fixture: ItemFn, info: FixtureInfo) -> TokenStream { .cloned() .collect::>(); let inject = inject::resolve_aruments(fixture.sig.inputs.iter(), &resolver, &generics_idents); + let partials = (1..=orig_args.len()).map(|n| render_partial_impl(&fixture, n, &resolver, &info)); @@ -144,12 +147,16 @@ fn render_partial_impl( #[cfg(test)] mod should { + use rstest_test::{assert_in, assert_not_in}; use syn::{ parse::{Parse, ParseStream}, parse2, parse_str, ItemFn, ItemImpl, ItemStruct, Result, }; - use crate::parse::{Attribute, Attributes}; + use crate::parse::{ + arguments::{ArgumentsInfo, FutureArg}, + Attribute, Attributes, + }; use super::*; use crate::test::{assert_eq, *}; @@ -483,4 +490,86 @@ mod should { assert_eq!(expected.sig, partial.sig); } + + #[test] + fn add_future_boilerplate_if_requested() { + let item_fn: ItemFn = + r#"async fn test(async_ref_u32: &u32, async_u32: u32,simple: u32) { }"#.ast(); + + let mut arguments = ArgumentsInfo::default(); + arguments.add_future(ident("async_ref_u32")); + arguments.add_future(ident("async_u32")); + + let tokens = render( + item_fn.clone(), + FixtureInfo { + arguments, + ..Default::default() + }, + ); + let out: FixtureOutput = parse2(tokens).unwrap(); + + let expected = parse_str::( + r#" + async fn get<'_async_ref_u32>( + async_ref_u32: impl std::future::Future, + async_u32: impl std::future::Future, + simple: u32 + ) + { } + "#, + ) + .unwrap(); + + let rendered = select_method(out.core_impl, "get").unwrap(); + + assert_eq!(expected.sig, rendered.sig); + } + + #[test] + fn use_global_await() { + let item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut arguments: ArgumentsInfo = Default::default(); + arguments.set_global_await(true); + arguments.add_future(ident("a")); + arguments.add_future(ident("b")); + + let tokens = render( + item_fn.clone(), + FixtureInfo { + arguments, + ..Default::default() + }, + ); + let out: FixtureOutput = parse2(tokens).unwrap(); + + let code = out.orig.display_code(); + + assert_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + + #[test] + fn use_selective_await() { + let item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut arguments: ArgumentsInfo = Default::default(); + arguments.set_future(ident("a"), FutureArg::Define); + arguments.set_future(ident("b"), FutureArg::Await); + + let tokens = render( + item_fn.clone(), + FixtureInfo { + arguments, + ..Default::default() + }, + ); + let out: FixtureOutput = parse2(tokens).unwrap(); + + let code = out.orig.display_code(); + + assert_not_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } } diff --git a/rstest_macros/src/render/mod.rs b/rstest_macros/src/render/mod.rs index c64cb62..9aa1d34 100644 --- a/rstest_macros/src/render/mod.rs +++ b/rstest_macros/src/render/mod.rs @@ -3,6 +3,7 @@ mod test; mod wrapper; use std::collections::HashMap; + use syn::token::Async; use proc_macro2::{Span, TokenStream}; @@ -27,9 +28,13 @@ use crate::{ use wrapper::WrapByModule; pub(crate) use fixture::render as fixture; + +use self::apply_argumets::ApplyArgumets; +pub(crate) mod apply_argumets; pub(crate) mod inject; pub(crate) fn single(mut test: ItemFn, info: RsTestInfo) -> TokenStream { + test.apply_argumets(&info.arguments); let resolver = resolver::fixtures::get(info.data.fixtures()); let args = test.sig.inputs.iter().cloned().collect::>(); let attrs = std::mem::take(&mut test.attrs); @@ -56,8 +61,13 @@ pub(crate) fn single(mut test: ItemFn, info: RsTestInfo) -> TokenStream { ) } -pub(crate) fn parametrize(test: ItemFn, info: RsTestInfo) -> TokenStream { - let RsTestInfo { data, attributes } = info; +pub(crate) fn parametrize(mut test: ItemFn, info: RsTestInfo) -> TokenStream { + let RsTestInfo { + data, + attributes, + arguments, + } = info; + test.apply_argumets(&arguments); let resolver_fixtures = resolver::fixtures::get(data.fixtures()); let rendered_cases = cases_data(&data, test.sig.ident.span()) @@ -139,10 +149,13 @@ fn _matrix_recursive<'a>( } } -pub(crate) fn matrix(test: ItemFn, info: RsTestInfo) -> TokenStream { +pub(crate) fn matrix(mut test: ItemFn, info: RsTestInfo) -> TokenStream { let RsTestInfo { - data, attributes, .. + data, + attributes, + arguments, } = info; + test.apply_argumets(&arguments); let span = test.sig.ident.span(); let cases = cases_data(&data, span).collect::>(); @@ -424,33 +437,3 @@ fn sanitize_ident(expr: &Expr) -> String { .filter(|&c| is_xid_continue(c)) .collect() } - -#[cfg(test)] -mod tests { - use crate::test::ToAst; - - use super::*; - use crate::test::{assert_eq, *}; - - #[rstest] - #[case("1", "1")] - #[case(r#""1""#, "__1__")] - #[case(r#"Some::SomeElse"#, "Some__SomeElse")] - #[case(r#""minnie".to_owned()"#, "__minnie___to_owned__")] - #[case( - r#"vec![1 , 2, - 3]"#, - "vec__1_2_3_" - )] - #[case( - r#"some_macro!("first", {second}, [third])"#, - "some_macro____first____second___third__" - )] - #[case(r#"'x'"#, "__x__")] - #[case::ops(r#"a*b+c/d-e%f^g"#, "a_b_c_d_e_f_g")] - fn sanitaze_ident_name(#[case] expression: impl AsRef, #[case] expected: impl AsRef) { - let expression: Expr = expression.as_ref().ast(); - - assert_eq!(expected.as_ref(), sanitize_ident(&expression)); - } -} diff --git a/rstest_macros/src/render/test.rs b/rstest_macros/src/render/test.rs index 0b4bdd0..cefd98a 100644 --- a/rstest_macros/src/render/test.rs +++ b/rstest_macros/src/render/test.rs @@ -33,10 +33,35 @@ fn trace_argument_code_string(arg_name: &str) -> String { statment.display_code() } +#[rstest] +#[case("1", "1")] +#[case(r#""1""#, "__1__")] +#[case(r#"Some::SomeElse"#, "Some__SomeElse")] +#[case(r#""minnie".to_owned()"#, "__minnie___to_owned__")] +#[case( + r#"vec![1 , 2, + 3]"#, + "vec__1_2_3_" +)] +#[case( + r#"some_macro!("first", {second}, [third])"#, + "some_macro____first____second___third__" +)] +#[case(r#"'x'"#, "__x__")] +#[case::ops(r#"a*b+c/d-e%f^g"#, "a_b_c_d_e_f_g")] +fn sanitaze_ident_name(#[case] expression: impl AsRef, #[case] expected: impl AsRef) { + let expression: Expr = expression.as_ref().ast(); + + assert_eq!(expected.as_ref(), sanitize_ident(&expression)); +} + mod single_test_should { use rstest_test::{assert_in, assert_not_in}; - use crate::test::{assert_eq, *}; + use crate::{ + parse::arguments::{ArgumentsInfo, FutureArg}, + test::{assert_eq, *}, + }; use super::*; @@ -49,6 +74,14 @@ mod single_test_should { assert_eq!(result.sig.output, input_fn.sig.output); } + fn extract_inner_test_function(outer: &ItemFn) -> ItemFn { + let first_stmt = outer.block.stmts.get(0).unwrap(); + + parse_quote! { + #first_stmt + } + } + #[test] fn include_given_function() { let input_fn: ItemFn = r#" @@ -61,13 +94,12 @@ mod single_test_should { "#.ast(); let result: ItemFn = single(input_fn.clone(), Default::default()).ast(); - let first_stmt = result.block.stmts.get(0).unwrap(); - let inner_fn: ItemFn = parse_quote! { - #first_stmt - }; + let inner_fn = extract_inner_test_function(&result); + let inner_fn_impl: Stmt = inner_fn.block.stmts.last().cloned().unwrap(); - assert_eq!(inner_fn, input_fn); + assert_eq!(inner_fn.sig, input_fn.sig); + assert_eq!(inner_fn_impl.display_code(), input_fn.block.display_code()); } #[rstest] @@ -120,6 +152,53 @@ mod single_test_should { assert_eq!(result.attrs, attributes); } + #[test] + fn use_global_await() { + let input_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut info: RsTestInfo = Default::default(); + info.arguments.set_global_await(true); + info.arguments.add_future(ident("a")); + info.arguments.add_future(ident("b")); + + let item_fn: ItemFn = single(input_fn.clone(), info).ast(); + + assert_in!( + item_fn.block.display_code(), + await_argument_code_string("a") + ); + assert_in!( + item_fn.block.display_code(), + await_argument_code_string("b") + ); + assert_not_in!( + item_fn.block.display_code(), + await_argument_code_string("c") + ); + } + + #[test] + fn use_selective_await() { + let input_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut info: RsTestInfo = Default::default(); + info.arguments.set_future(ident("a"), FutureArg::Define); + info.arguments.set_future(ident("b"), FutureArg::Await); + + let item_fn: ItemFn = single(input_fn.clone(), info).ast(); + + assert_not_in!( + item_fn.block.display_code(), + await_argument_code_string("a",) + ); + assert_in!( + item_fn.block.display_code(), + await_argument_code_string("b") + ); + assert_not_in!( + item_fn.block.display_code(), + await_argument_code_string("c") + ); + } + #[test] fn trace_arguments_values() { let input_fn: ItemFn = r#"#[trace]fn test(s: String, a:i32) {} "#.ast(); @@ -210,6 +289,39 @@ mod single_test_should { assert_eq!(use_await, last_stmt.is_await()); } + #[test] + fn add_future_boilerplate_if_requested() { + let item_fn: ItemFn = r#" + async fn test(async_ref_u32: &u32, async_u32: u32,simple: u32) + { } + "# + .ast(); + + let mut arguments = ArgumentsInfo::default(); + arguments.add_future(ident("async_ref_u32")); + arguments.add_future(ident("async_u32")); + + let info = RsTestInfo { + arguments, + ..Default::default() + }; + + let result: ItemFn = single(item_fn.clone(), info).ast(); + let inner_fn = extract_inner_test_function(&result); + + let expected = parse_str::( + r#"async fn test<'_async_ref_u32>( + async_ref_u32: impl std::future::Future, + async_u32: impl std::future::Future, + simple: u32 + ) + { } + "#, + ) + .unwrap(); + + assert_eq!(inner_fn.sig, expected.sig); + } } struct TestsGroup { @@ -395,6 +507,7 @@ mod cases_should { use rstest_test::{assert_in, assert_not_in}; use crate::parse::{ + arguments::{ArgumentsInfo, FutureArg}, rstest::{RsTestData, RsTestInfo, RsTestItem}, testcase::TestCase, }; @@ -484,9 +597,11 @@ mod cases_should { let tokens = parametrize(item_fn.clone(), info); let mut output = TestsGroup::from(tokens); + let test_impl: Stmt = output.requested_test.block.stmts.last().cloned().unwrap(); output.requested_test.attrs = vec![]; - assert_eq!(output.requested_test, item_fn); + assert_eq!(output.requested_test.sig, item_fn.sig); + assert_eq!(test_impl.display_code(), item_fn.block.display_code()); } #[test] @@ -761,6 +876,36 @@ mod cases_should { assert_eq!(&tests[0].attrs[1..], attributes.as_slice()); } + #[test] + fn add_future_boilerplate_if_requested() { + let (item_fn, mut info) = TestCaseBuilder::from( + r#"async fn test(async_ref_u32: &u32, async_u32: u32,simple: u32) { }"#, + ) + .take(); + + let mut arguments = ArgumentsInfo::default(); + arguments.add_future(ident("async_ref_u32")); + arguments.add_future(ident("async_u32")); + + info.arguments = arguments; + + let tokens = parametrize(item_fn.clone(), info); + let test_function = TestsGroup::from(tokens).requested_test; + + let expected = parse_str::( + r#"async fn test<'_async_ref_u32>( + async_ref_u32: impl std::future::Future, + async_u32: impl std::future::Future, + simple: u32 + ) + { } + "#, + ) + .unwrap(); + + assert_eq!(test_function.sig, expected.sig); + } + #[rstest] #[case::sync(false, false)] #[case::async_fn(true, true)] @@ -856,12 +1001,56 @@ mod cases_should { trace_argument_code_string("a_no_trace_me") ); } + + #[test] + fn use_global_await() { + let (item_fn, mut info) = TestCaseBuilder::from(r#"fn test(a: i32, b:i32, c:i32) {}"#) + .push_case(TestCase::from_iter(vec!["1", "2", "3"])) + .push_case(TestCase::from_iter(vec!["1", "2", "3"])) + .take(); + info.arguments.set_global_await(true); + info.arguments.add_future(ident("a")); + info.arguments.add_future(ident("b")); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens); + + let code = tests.requested_test.block.display_code(); + + assert_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + + #[test] + fn use_selective_await() { + let (item_fn, mut info) = TestCaseBuilder::from(r#"fn test(a: i32, b:i32, c:i32) {}"#) + .push_case(TestCase::from_iter(vec!["1", "2", "3"])) + .push_case(TestCase::from_iter(vec!["1", "2", "3"])) + .take(); + info.arguments.set_future(ident("a"), FutureArg::Define); + info.arguments.set_future(ident("b"), FutureArg::Await); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens); + + let code = tests.requested_test.block.display_code(); + + assert_not_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } } mod matrix_cases_should { use rstest_test::{assert_in, assert_not_in}; - use crate::parse::vlist::ValueList; + use crate::parse::{ + arguments::{ArgumentsInfo, FutureArg}, + vlist::ValueList, + }; /// Should test matrix tests render without take in account MatrixInfo to RsTestInfo /// transformation @@ -903,9 +1092,11 @@ mod matrix_cases_should { let tokens = matrix(item_fn.clone(), data.into()); let mut output = TestsGroup::from(tokens); + let test_impl: Stmt = output.requested_test.block.stmts.last().cloned().unwrap(); output.requested_test.attrs = vec![]; - assert_eq!(output.requested_test, item_fn); + assert_eq!(output.requested_test.sig, item_fn.sig); + assert_eq!(test_impl.display_code(), item_fn.block.display_code()); } #[test] @@ -1130,6 +1321,37 @@ mod matrix_cases_should { } } + #[test] + fn add_future_boilerplate_if_requested() { + let item_fn = r#"async fn test(async_ref_u32: &u32, async_u32: u32,simple: u32) { }"#.ast(); + + let mut arguments = ArgumentsInfo::default(); + arguments.add_future(ident("async_ref_u32")); + arguments.add_future(ident("async_u32")); + + let info = RsTestInfo { + arguments, + ..Default::default() + }; + + let tokens = matrix(item_fn, info); + + let test_function = TestsGroup::from(tokens).requested_test; + + let expected = parse_str::( + r#"async fn test<'_async_ref_u32>( + async_ref_u32: impl std::future::Future, + async_u32: impl std::future::Future, + simple: u32 + ) + { } + "#, + ) + .unwrap(); + + assert_eq!(test_function.sig, expected.sig); + } + #[rstest] fn add_allow_non_snake_case( #[values( @@ -1225,7 +1447,14 @@ mod matrix_cases_should { attributes.add_notraces(vec![ident("b_no_trace_me"), ident("c_no_trace_me")]); let item_fn: ItemFn = r#"#[trace] fn test(a_trace_me: u32, b_no_trace_me: u32, c_no_trace_me: u32, d_trace_me: u32) {}"#.ast(); - let tokens = matrix(item_fn, RsTestInfo { data, attributes }); + let tokens = matrix( + item_fn, + RsTestInfo { + data, + attributes, + ..Default::default() + }, + ); let tests = TestsGroup::from(tokens).get_all_tests(); @@ -1246,6 +1475,68 @@ mod matrix_cases_should { } } + #[test] + fn use_global_await() { + let item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {}"#.ast(); + let data = RsTestData { + items: vec![ + values_list("a", &["1"]).into(), + values_list("b", &["2"]).into(), + values_list("c", &["3"]).into(), + ] + .into(), + }; + let mut info = RsTestInfo { + data, + attributes: Default::default(), + arguments: Default::default(), + }; + info.arguments.set_global_await(true); + info.arguments.add_future(ident("a")); + info.arguments.add_future(ident("b")); + + let tokens = matrix(item_fn, info); + + let tests = TestsGroup::from(tokens); + + let code = tests.requested_test.block.display_code(); + + assert_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + + #[test] + fn use_selective_await() { + let item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {}"#.ast(); + let data = RsTestData { + items: vec![ + values_list("a", &["1"]).into(), + values_list("b", &["2"]).into(), + values_list("c", &["3"]).into(), + ] + .into(), + }; + let mut info = RsTestInfo { + data, + attributes: Default::default(), + arguments: Default::default(), + }; + + info.arguments.set_future(ident("a"), FutureArg::Define); + info.arguments.set_future(ident("b"), FutureArg::Await); + + let tokens = matrix(item_fn, info); + + let tests = TestsGroup::from(tokens); + + let code = tests.requested_test.block.display_code(); + + assert_not_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + mod two_args_should { /// Should test matrix tests render without take in account MatrixInfo to RsTestInfo /// transformation diff --git a/rstest_macros/src/test.rs b/rstest_macros/src/test.rs index 52fe9a7..c4a25ca 100644 --- a/rstest_macros/src/test.rs +++ b/rstest_macros/src/test.rs @@ -10,7 +10,7 @@ pub(crate) use pretty_assertions::assert_eq; use proc_macro2::TokenTree; use quote::quote; pub(crate) use rstest::{fixture, rstest}; -use syn::{parse::Parse, parse2, parse_str, Error, Expr, Ident, ItemFn, Stmt}; +use syn::{parse::Parse, parse2, parse_quote, parse_str, Error, Expr, Ident, ItemFn, Stmt}; use super::*; use crate::parse::{ @@ -112,7 +112,7 @@ impl ToAst for proc_macro2::TokenStream { } } -pub(crate) fn ident(s: impl AsRef) -> syn::Ident { +pub(crate) fn ident(s: impl AsRef) -> Ident { s.as_ref().ast() } @@ -257,7 +257,7 @@ impl From for RsTestInfo { fn from(data: RsTestData) -> Self { Self { data, - attributes: Default::default(), + ..Default::default() } } } @@ -318,3 +318,11 @@ impl crate::parse::fixture::FixtureModifiers { self } } + +pub(crate) fn await_argument_code_string(arg_name: &str) -> String { + let arg_name = ident(arg_name); + let statment: Stmt = parse_quote! { + let #arg_name = #arg_name.await; + }; + statment.display_code() +}