Skip to content

Commit

Permalink
Add Sync support
Browse files Browse the repository at this point in the history
  • Loading branch information
dcchut committed Mar 17, 2024
1 parent f01f486 commit bec0ae7
Show file tree
Hide file tree
Showing 17 changed files with 306 additions and 28 deletions.
26 changes: 18 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,35 @@ async fn fib(n : u32) -> u32 {
}
```

## ?Send Option
## ?Send option

The returned future has a `Send` bound to make sure it can be sent between threads.
The returned `Future` has a `Send` bound to make sure it can be sent between threads.
If this is undesirable you can mark that the bound should be left out like so:

```rust
#[async_recursion(?Send)]
async fn example() {
async fn returned_future_is_not_send() {
// ...
}
```

In detail:
## Sync option

The returned `Future` doesn't have a `Sync` bound as it is usually not required.
You can include a `Sync` bound as follows:

```rust
#[async_recursion(Sync)]
async fn returned_future_is_sync() {
// ...
}
```

- `#[async_recursion]` modifies your function to return a [`BoxFuture`], and
- `#[async_recursion(?Send)]` modifies your function to return a [`LocalBoxFuture`].
In detail:

[`BoxFuture`]: https://docs.rs/futures/0.3.19/futures/future/type.BoxFuture.html
[`LocalBoxFuture`]: https://docs.rs/futures/0.3.19/futures/future/type.LocalBoxFuture.html
- `#[async_recursion]` modifies your function to return a boxed `Future` with a `Send` bound.
- `#[async_recursion(?Send)]` modifies your function to return a boxed `Future` _without_ a `Send` bound.
- `#[async_recursion(Sync)]` modifies your function to return a boxed `Future` with a `Send` and `Sync` bound.

### License

Expand Down
8 changes: 7 additions & 1 deletion src/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
quote!()
};

let sync_bound: TokenStream = if args.sync_bound {
quote!(+ ::core::marker::Sync)
} else {
quote!()
};

let where_clause = sig
.generics
.where_clause
Expand All @@ -196,6 +202,6 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
// Modify the return type
sig.output = parse_quote! {
-> ::core::pin::Pin<Box<
dyn ::core::future::Future<Output = #ret> #box_lifetime #send_bound >>
dyn ::core::future::Future<Output = #ret> #box_lifetime #send_bound #sync_bound>>
};
}
27 changes: 20 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,40 @@
//! }
//! ```
//!
//! ## ?Send Option
//! ## ?Send option
//!
//! The returned future has a [`Send`] bound to make sure it can be sent between threads.
//! The returned [`Future`] has a [`Send`] bound to make sure it can be sent between threads.
//! If this is undesirable you can mark that the bound should be left out like so:
//!
//! ```rust
//! # use async_recursion::async_recursion;
//!
//! #[async_recursion(?Send)]
//! async fn example() {
//! async fn returned_future_is_not_send() {
//! // ...
//! }
//! ```
//!
//! ## Sync option
//!
//! The returned [`Future`] doesn't have a [`Sync`] bound as it is usually not required.
//! You can include a [`Sync`] bound as follows:
//!
//! ```rust
//! # use async_recursion::async_recursion;
//!
//! #[async_recursion(Sync)]
//! async fn returned_future_is_send_and_sync() {
//! // ...
//! }
//! ```
//!
//! In detail:
//!
//! - `#[async_recursion]` modifies your function to return a [`BoxFuture`], and
//! - `#[async_recursion(?Send)]` modifies your function to return a [`LocalBoxFuture`].
//!
//! [`BoxFuture`]: https://docs.rs/futures/0.3.19/futures/future/type.BoxFuture.html
//! [`LocalBoxFuture`]: https://docs.rs/futures/0.3.19/futures/future/type.LocalBoxFuture.html
//! - `#[async_recursion]` modifies your function to return a boxed [`Future`] with a [`Send`] bound.
//! - `#[async_recursion(?Send)]` modifies your function to return a boxed [`Future`] _without_ a [`Send`] bound.
//! - `#[async_recursion(Sync)]` modifies your function to return a boxed [`Future`] with [`Send`] and [`Sync`] bounds.
//!
//! ### License
//!
Expand Down
68 changes: 56 additions & 12 deletions src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,74 @@ impl Parse for AsyncItem {

pub struct RecursionArgs {
pub send_bound: bool,
}

impl Default for RecursionArgs {
fn default() -> Self {
RecursionArgs { send_bound: true }
}
pub sync_bound: bool,
}

/// Custom keywords for parser
mod kw {
syn::custom_keyword!(Send);
syn::custom_keyword!(Sync);
}

impl Parse for RecursionArgs {
#[derive(Debug, PartialEq, Eq)]
enum Arg {
NotSend,
Sync,
}

impl std::fmt::Display for Arg {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NotSend => write!(f, "?Send"),
Self::Sync => write!(f, "Sync"),
}
}
}

impl Parse for Arg {
fn parse(input: ParseStream) -> Result<Self> {
// Check for the `?Send` option
if input.peek(Token![?]) {
input.parse::<Question>()?;
input.parse::<kw::Send>()?;
Ok(Self { send_bound: false })
} else if !input.is_empty() {
Err(input.error("expected `?Send` or empty"))
Ok(Arg::NotSend)
} else {
Ok(Self::default())
input.parse::<kw::Sync>()?;
Ok(Arg::Sync)
}
}
}

impl Parse for RecursionArgs {
fn parse(input: ParseStream) -> Result<Self> {
let mut send_bound: bool = true;
let mut sync_bound: bool = false;

let args_parsed: Vec<Arg> =
syn::punctuated::Punctuated::<Arg, syn::Token![,]>::parse_terminated(input)
.map_err(|e| input.error(format!("failed to parse macro arguments: {e}")))?
.into_iter()
.collect();

// Avoid sloppy input
if args_parsed.len() > 2 {
return Err(Error::new(Span::call_site(), "received too many arguments"));
} else if args_parsed.len() == 2 && args_parsed[0] == args_parsed[1] {
return Err(Error::new(
Span::call_site(),
format!("received duplicate argument: `{}`", args_parsed[0]),
));
}

for arg in args_parsed {
match arg {
Arg::NotSend => send_bound = false,
Arg::Sync => sync_bound = true,
}
}

Ok(Self {
send_bound,
sync_bound,
})
}
}
11 changes: 11 additions & 0 deletions tests/args_sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use async_recursion::async_recursion;

#[async_recursion(Sync)]
async fn send_and_sync() {}

fn assert_is_send_and_sync(_: impl Send + Sync) {}

#[test]
fn test_sync_argument() {
assert_is_send_and_sync(send_and_sync());
}
5 changes: 5 additions & 0 deletions tests/expand/args_not_send.expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use async_recursion::async_recursion;
#[must_use]
fn no_send_bound() -> ::core::pin::Pin<Box<dyn ::core::future::Future<Output = ()>>> {
Box::pin(async move {})
}
4 changes: 4 additions & 0 deletions tests/expand/args_not_send.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
use async_recursion::async_recursion;

#[async_recursion(?Send)]
async fn no_send_bound() {}
25 changes: 25 additions & 0 deletions tests/expand/args_punctuated.expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use async_recursion::async_recursion;
#[must_use]
fn not_send_sync_1() -> ::core::pin::Pin<
Box<dyn ::core::future::Future<Output = ()> + ::core::marker::Sync>,
> {
Box::pin(async move {})
}
#[must_use]
fn not_send_sync_2() -> ::core::pin::Pin<
Box<dyn ::core::future::Future<Output = ()> + ::core::marker::Sync>,
> {
Box::pin(async move {})
}
#[must_use]
fn sync_not_send_1() -> ::core::pin::Pin<
Box<dyn ::core::future::Future<Output = ()> + ::core::marker::Sync>,
> {
Box::pin(async move {})
}
#[must_use]
fn sync_not_send_2() -> ::core::pin::Pin<
Box<dyn ::core::future::Future<Output = ()> + ::core::marker::Sync>,
> {
Box::pin(async move {})
}
13 changes: 13 additions & 0 deletions tests/expand/args_punctuated.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use async_recursion::async_recursion;

#[async_recursion(?Send, Sync)]
async fn not_send_sync_1() {}

#[async_recursion(?Send,Sync)]
async fn not_send_sync_2() {}

#[async_recursion(Sync, ?Send)]
async fn sync_not_send_1() {}

#[async_recursion(Sync,?Send)]
async fn sync_not_send_2() {}
11 changes: 11 additions & 0 deletions tests/expand/args_sync.expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use async_recursion::async_recursion;
#[must_use]
fn sync() -> ::core::pin::Pin<
Box<
dyn ::core::future::Future<
Output = (),
> + ::core::marker::Send + ::core::marker::Sync,
>,
> {
Box::pin(async move {})
}
4 changes: 4 additions & 0 deletions tests/expand/args_sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
use async_recursion::async_recursion;

#[async_recursion(Sync)]
async fn sync() {}
15 changes: 15 additions & 0 deletions tests/ui/arg_not_sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use async_recursion::async_recursion;

fn assert_is_sync(_: impl Sync) {}


#[async_recursion]
async fn send_not_sync() {}

#[async_recursion(?Send)]
async fn not_send_not_sync() {}

fn main() {
assert_is_sync(send_not_sync());
assert_is_sync(not_send_not_sync());
}
51 changes: 51 additions & 0 deletions tests/ui/arg_not_sync.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
error[E0277]: `dyn Future<Output = ()> + Send` cannot be shared between threads safely
--> tests/ui/arg_not_sync.rs:13:20
|
13 | assert_is_sync(send_not_sync());
| -------------- ^^^^^^^^^^^^^^^ `dyn Future<Output = ()> + Send` cannot be shared between threads safely
| |
| required by a bound introduced by this call
|
= help: the trait `Sync` is not implemented for `dyn Future<Output = ()> + Send`
= note: required for `Unique<dyn Future<Output = ()> + Send>` to implement `Sync`
note: required because it appears within the type `Box<dyn Future<Output = ()> + Send>`
--> $RUST/alloc/src/boxed.rs
|
| pub struct Box<
| ^^^
note: required because it appears within the type `Pin<Box<dyn Future<Output = ()> + Send>>`
--> $RUST/core/src/pin.rs
|
| pub struct Pin<P> {
| ^^^
note: required by a bound in `assert_is_sync`
--> tests/ui/arg_not_sync.rs:3:27
|
3 | fn assert_is_sync(_: impl Sync) {}
| ^^^^ required by this bound in `assert_is_sync`

error[E0277]: `dyn Future<Output = ()>` cannot be shared between threads safely
--> tests/ui/arg_not_sync.rs:14:20
|
14 | assert_is_sync(not_send_not_sync());
| -------------- ^^^^^^^^^^^^^^^^^^^ `dyn Future<Output = ()>` cannot be shared between threads safely
| |
| required by a bound introduced by this call
|
= help: the trait `Sync` is not implemented for `dyn Future<Output = ()>`
= note: required for `Unique<dyn Future<Output = ()>>` to implement `Sync`
note: required because it appears within the type `Box<dyn Future<Output = ()>>`
--> $RUST/alloc/src/boxed.rs
|
| pub struct Box<
| ^^^
note: required because it appears within the type `Pin<Box<dyn Future<Output = ()>>>`
--> $RUST/core/src/pin.rs
|
| pub struct Pin<P> {
| ^^^
note: required by a bound in `assert_is_sync`
--> tests/ui/arg_not_sync.rs:3:27
|
3 | fn assert_is_sync(_: impl Sync) {}
| ^^^^ required by this bound in `assert_is_sync`
14 changes: 14 additions & 0 deletions tests/ui/args_invalid.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use async_recursion::async_recursion;

#[async_recursion(?Sync)]
async fn not_sync() {}

#[async_recursion(Sync Sync)]
async fn not_punctuated() {}

#[async_recursion(Sync?Send)]
async fn what_even_is_this() {}



fn main() {}
17 changes: 17 additions & 0 deletions tests/ui/args_invalid.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
error: failed to parse macro arguments: expected `Send`
--> tests/ui/args_invalid.rs:3:20
|
3 | #[async_recursion(?Sync)]
| ^^^^

error: failed to parse macro arguments: expected `,`
--> tests/ui/args_invalid.rs:6:24
|
6 | #[async_recursion(Sync Sync)]
| ^^^^

error: failed to parse macro arguments: expected `,`
--> tests/ui/args_invalid.rs:9:23
|
9 | #[async_recursion(Sync?Send)]
| ^
12 changes: 12 additions & 0 deletions tests/ui/args_repeated.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use async_recursion::async_recursion;

#[async_recursion(?Send, ?Send)]
async fn repeated_args_1() {}

#[async_recursion(?Send, Sync, ?Send)]
async fn repeated_args_2() {}

#[async_recursion(Sync, ?Send, Sync, ?Send)]
async fn repeated_args_3() {}

fn main() {}

0 comments on commit bec0ae7

Please sign in to comment.