Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lifetime bounds to nested references #38

Merged
merged 4 commits into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ proc-macro = true
[dependencies]
proc-macro2 = { version = "1.0", default-features = false }
quote = { version = "1.0", default-features = false }
syn = { version = "2.0", features = ["full", "parsing", "printing", "proc-macro", "clone-impls"], default-features = false }
syn = { version = "2.0", features = ["full", "parsing", "printing", "proc-macro", "clone-impls", "visit-mut"], default-features = false }

[dev-dependencies]
futures-executor = "0.3"
Expand Down
155 changes: 72 additions & 83 deletions src/expand.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
parse_quote, punctuated::Punctuated, Block, FnArg, Lifetime, ReturnType, Signature, Type,
WhereClause,
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Block, Lifetime, Receiver,
ReturnType, Signature, TypeReference, WhereClause,
};

use crate::parse::{AsyncItem, RecursionArgs};
Expand Down Expand Up @@ -40,6 +40,63 @@ impl ArgLifetime {
}
}

#[derive(Default)]
struct ReferenceVisitor {
counter: usize,
lifetimes: Vec<ArgLifetime>,
self_receiver: bool,
self_receiver_new_lifetime: bool,
self_lifetime: Option<Lifetime>,
}

impl VisitMut for ReferenceVisitor {
fn visit_receiver_mut(&mut self, receiver: &mut Receiver) {
self.self_lifetime = Some(if let Some((_, lt)) = &mut receiver.reference {
self.self_receiver = true;

if let Some(lt) = lt {
lt.clone()
} else {
// Use 'life_self to avoid collisions with 'life<count> lifetimes.
let new_lifetime: Lifetime = parse_quote!('life_self);
lt.replace(new_lifetime.clone());

self.self_receiver_new_lifetime = true;

new_lifetime
}
} else {
return;
});
}

fn visit_type_reference_mut(&mut self, argument: &mut TypeReference) {
if argument.lifetime.is_none() {
// If this reference doesn't have a lifetime (e.g. &T), then give it one.
let lt = Lifetime::new(&format!("'life{}", self.counter), Span::call_site());
self.lifetimes.push(ArgLifetime::New(parse_quote!(#lt)));
argument.lifetime = Some(lt);
self.counter += 1;
} else {
// If it does (e.g. &'life T), then keep track of it.
let lt = argument.lifetime.as_ref().cloned().unwrap();

// Check that this lifetime isn't already in our vector
let ident_matches = |x: &ArgLifetime| {
if let ArgLifetime::Existing(elt) = x {
elt.ident == lt.ident
} else {
false
}
};

if !self.lifetimes.iter().any(ident_matches) {
self.lifetimes.push(ArgLifetime::Existing(lt));
}
}
}
}

// Input:
// async fn f<S, T>(x : S, y : &T) -> Ret;
//
Expand All @@ -55,67 +112,13 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
// Remove the asyncness of this function
sig.asyncness = None;

// Find all reference arguments
let mut ref_arguments = Vec::new();
let mut self_lifetime = None;

for arg in &mut sig.inputs {
if let FnArg::Typed(pt) = arg {
match pt.ty.as_mut() {
// rustc can give us a None-delimited group if this type comes from
// a macro_rules macro. I don't think this can happen for code the user has written.
Type::Group(tg) => {
if let Type::Reference(tr) = &mut *tg.elem {
ref_arguments.push(tr);
}
}
Type::Reference(tr) => {
ref_arguments.push(tr);
}
_ => {}
}
} else if let FnArg::Receiver(recv) = arg {
if let Some((_, slt)) = &mut recv.reference {
self_lifetime = Some(slt);
}
}
}

let mut counter = 0;
let mut lifetimes = Vec::new();

if !ref_arguments.is_empty() {
for ra in &mut ref_arguments {
// If this reference arg doesn't have a lifetime, give it an explicit one
if ra.lifetime.is_none() {
let lt = Lifetime::new(&format!("'life{counter}"), Span::call_site());

lifetimes.push(ArgLifetime::New(parse_quote!(#lt)));

ra.lifetime = Some(lt);
counter += 1;
} else {
let lt = ra.lifetime.as_ref().cloned().unwrap();

// Check that this lifetime isn't already in our vector
let ident_matches = |x: &ArgLifetime| {
if let ArgLifetime::Existing(elt) = x {
elt.ident == lt.ident
} else {
false
}
};

if !lifetimes.iter().any(ident_matches) {
lifetimes.push(ArgLifetime::Existing(
ra.lifetime.as_ref().cloned().unwrap(),
));
}
}
}
// Find and update any references in the input arguments
let mut v = ReferenceVisitor::default();
for input in &mut sig.inputs {
v.visit_fn_arg_mut(input);
}

// Does this expansion require `async_recursion to be added to the output
// Does this expansion require `async_recursion to be added to the output?
let mut requires_lifetime = false;
let mut where_clause_lifetimes = vec![];
let mut where_clause_generics = vec![];
Expand All @@ -127,13 +130,13 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
for param in sig.generics.type_params() {
let ident = param.ident.clone();
where_clause_generics.push(ident);

requires_lifetime = true;
}

// Add an 'a : 'async_recursion bound to any lifetimes 'a appearing in the function
if !lifetimes.is_empty() {
for alt in lifetimes {
if !v.lifetimes.is_empty() {
requires_lifetime = true;
for alt in v.lifetimes {
if let ArgLifetime::New(lt) = &alt {
// If this is a new argument,
sig.generics.params.push(parse_quote!(#lt));
Expand All @@ -143,29 +146,15 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
let lt = alt.lifetime();
where_clause_lifetimes.push(lt);
}

requires_lifetime = true;
}

// If our function accepts &self, then we modify this to the explicit lifetime &'life_self,
// and add the bound &'life_self : 'async_recursion
if let Some(slt) = self_lifetime {
let lt = {
if let Some(lt) = slt.as_mut() {
lt.clone()
} else {
// We use `life_self here to avoid any collisions with `life0, `life1 from above
let lt: Lifetime = parse_quote!('life_self);
sig.generics.params.push(parse_quote!(#lt));

// add lt to the lifetime of self
*slt = Some(lt.clone());

lt
}
};

where_clause_lifetimes.push(lt);
if v.self_receiver {
if v.self_receiver_new_lifetime {
sig.generics.params.push(parse_quote!('life_self));
}
where_clause_lifetimes.extend(v.self_lifetime);
requires_lifetime = true;
}

Expand Down
19 changes: 19 additions & 0 deletions tests/expand/lifetimes_nested_reference.expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use async_recursion::async_recursion;
#[must_use]
fn count_down<'life0, 'async_recursion>(
foo: Option<&'life0 str>,
) -> ::core::pin::Pin<
Box<
dyn ::core::future::Future<
Output = i32,
> + 'async_recursion + ::core::marker::Send,
>,
>
where
'life0: 'async_recursion,
{
Box::pin(async move {
let _ = foo;
0
})
}
7 changes: 7 additions & 0 deletions tests/expand/lifetimes_nested_reference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use async_recursion::async_recursion;

#[async_recursion]
async fn count_down(foo: Option<&str>) -> i32 {
let _ = foo;
0
}
9 changes: 9 additions & 0 deletions tests/lifetimes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ async fn contains_value_2<'a, 'b, T: PartialEq>(value: &'b T, node: &'b Node<'a,
contains_value(value, node).await
}

// The reference inside foo needs a `async_recursion bound
#[async_recursion]
async fn count_down(foo: Option<&str>) -> i32 {
let _ = foo;
0
}

#[test]
fn lifetime_expansion_works() {
block_on(async move {
Expand Down Expand Up @@ -64,5 +71,7 @@ fn lifetime_expansion_works() {
assert_eq!(contains_value_2(&17, &node).await, true);
assert_eq!(contains_value_2(&13, &node).await, true);
assert_eq!(contains_value_2(&12, &node).await, false);

count_down(None).await;
});
}