Skip to content

Commit

Permalink
Add lifetime bounds to nested references (#38)
Browse files Browse the repository at this point in the history
* Add failing test

* Use visitor to collect references

* Include nested reference expansion output test

* Fix up a few comments
  • Loading branch information
dcchut committed Mar 17, 2024
1 parent 0b38976 commit f01f486
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 84 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ proc-macro = true
[dependencies]
proc-macro2 = { version = "1.0", default-features = false }
quote = { version = "1.0", default-features = false }
syn = { version = "2.0", features = ["full", "parsing", "printing", "proc-macro", "clone-impls"], default-features = false }
syn = { version = "2.0", features = ["full", "parsing", "printing", "proc-macro", "clone-impls", "visit-mut"], default-features = false }

[dev-dependencies]
futures-executor = "0.3"
Expand Down
155 changes: 72 additions & 83 deletions src/expand.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
parse_quote, punctuated::Punctuated, Block, FnArg, Lifetime, ReturnType, Signature, Type,
WhereClause,
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Block, Lifetime, Receiver,
ReturnType, Signature, TypeReference, WhereClause,
};

use crate::parse::{AsyncItem, RecursionArgs};
Expand Down Expand Up @@ -40,6 +40,63 @@ impl ArgLifetime {
}
}

#[derive(Default)]
struct ReferenceVisitor {
counter: usize,
lifetimes: Vec<ArgLifetime>,
self_receiver: bool,
self_receiver_new_lifetime: bool,
self_lifetime: Option<Lifetime>,
}

impl VisitMut for ReferenceVisitor {
fn visit_receiver_mut(&mut self, receiver: &mut Receiver) {
self.self_lifetime = Some(if let Some((_, lt)) = &mut receiver.reference {
self.self_receiver = true;

if let Some(lt) = lt {
lt.clone()
} else {
// Use 'life_self to avoid collisions with 'life<count> lifetimes.
let new_lifetime: Lifetime = parse_quote!('life_self);
lt.replace(new_lifetime.clone());

self.self_receiver_new_lifetime = true;

new_lifetime
}
} else {
return;
});
}

fn visit_type_reference_mut(&mut self, argument: &mut TypeReference) {
if argument.lifetime.is_none() {
// If this reference doesn't have a lifetime (e.g. &T), then give it one.
let lt = Lifetime::new(&format!("'life{}", self.counter), Span::call_site());
self.lifetimes.push(ArgLifetime::New(parse_quote!(#lt)));
argument.lifetime = Some(lt);
self.counter += 1;
} else {
// If it does (e.g. &'life T), then keep track of it.
let lt = argument.lifetime.as_ref().cloned().unwrap();

// Check that this lifetime isn't already in our vector
let ident_matches = |x: &ArgLifetime| {
if let ArgLifetime::Existing(elt) = x {
elt.ident == lt.ident
} else {
false
}
};

if !self.lifetimes.iter().any(ident_matches) {
self.lifetimes.push(ArgLifetime::Existing(lt));
}
}
}
}

// Input:
// async fn f<S, T>(x : S, y : &T) -> Ret;
//
Expand All @@ -55,67 +112,13 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
// Remove the asyncness of this function
sig.asyncness = None;

// Find all reference arguments
let mut ref_arguments = Vec::new();
let mut self_lifetime = None;

for arg in &mut sig.inputs {
if let FnArg::Typed(pt) = arg {
match pt.ty.as_mut() {
// rustc can give us a None-delimited group if this type comes from
// a macro_rules macro. I don't think this can happen for code the user has written.
Type::Group(tg) => {
if let Type::Reference(tr) = &mut *tg.elem {
ref_arguments.push(tr);
}
}
Type::Reference(tr) => {
ref_arguments.push(tr);
}
_ => {}
}
} else if let FnArg::Receiver(recv) = arg {
if let Some((_, slt)) = &mut recv.reference {
self_lifetime = Some(slt);
}
}
}

let mut counter = 0;
let mut lifetimes = Vec::new();

if !ref_arguments.is_empty() {
for ra in &mut ref_arguments {
// If this reference arg doesn't have a lifetime, give it an explicit one
if ra.lifetime.is_none() {
let lt = Lifetime::new(&format!("'life{counter}"), Span::call_site());

lifetimes.push(ArgLifetime::New(parse_quote!(#lt)));

ra.lifetime = Some(lt);
counter += 1;
} else {
let lt = ra.lifetime.as_ref().cloned().unwrap();

// Check that this lifetime isn't already in our vector
let ident_matches = |x: &ArgLifetime| {
if let ArgLifetime::Existing(elt) = x {
elt.ident == lt.ident
} else {
false
}
};

if !lifetimes.iter().any(ident_matches) {
lifetimes.push(ArgLifetime::Existing(
ra.lifetime.as_ref().cloned().unwrap(),
));
}
}
}
// Find and update any references in the input arguments
let mut v = ReferenceVisitor::default();
for input in &mut sig.inputs {
v.visit_fn_arg_mut(input);
}

// Does this expansion require `async_recursion to be added to the output
// Does this expansion require `async_recursion to be added to the output?
let mut requires_lifetime = false;
let mut where_clause_lifetimes = vec![];
let mut where_clause_generics = vec![];
Expand All @@ -127,13 +130,13 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
for param in sig.generics.type_params() {
let ident = param.ident.clone();
where_clause_generics.push(ident);

requires_lifetime = true;
}

// Add an 'a : 'async_recursion bound to any lifetimes 'a appearing in the function
if !lifetimes.is_empty() {
for alt in lifetimes {
if !v.lifetimes.is_empty() {
requires_lifetime = true;
for alt in v.lifetimes {
if let ArgLifetime::New(lt) = &alt {
// If this is a new argument,
sig.generics.params.push(parse_quote!(#lt));
Expand All @@ -143,29 +146,15 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
let lt = alt.lifetime();
where_clause_lifetimes.push(lt);
}

requires_lifetime = true;
}

// If our function accepts &self, then we modify this to the explicit lifetime &'life_self,
// and add the bound &'life_self : 'async_recursion
if let Some(slt) = self_lifetime {
let lt = {
if let Some(lt) = slt.as_mut() {
lt.clone()
} else {
// We use `life_self here to avoid any collisions with `life0, `life1 from above
let lt: Lifetime = parse_quote!('life_self);
sig.generics.params.push(parse_quote!(#lt));

// add lt to the lifetime of self
*slt = Some(lt.clone());

lt
}
};

where_clause_lifetimes.push(lt);
if v.self_receiver {
if v.self_receiver_new_lifetime {
sig.generics.params.push(parse_quote!('life_self));
}
where_clause_lifetimes.extend(v.self_lifetime);
requires_lifetime = true;
}

Expand Down
19 changes: 19 additions & 0 deletions tests/expand/lifetimes_nested_reference.expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use async_recursion::async_recursion;
#[must_use]
fn count_down<'life0, 'async_recursion>(
foo: Option<&'life0 str>,
) -> ::core::pin::Pin<
Box<
dyn ::core::future::Future<
Output = i32,
> + 'async_recursion + ::core::marker::Send,
>,
>
where
'life0: 'async_recursion,
{
Box::pin(async move {
let _ = foo;
0
})
}
7 changes: 7 additions & 0 deletions tests/expand/lifetimes_nested_reference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use async_recursion::async_recursion;

#[async_recursion]
async fn count_down(foo: Option<&str>) -> i32 {
let _ = foo;
0
}
9 changes: 9 additions & 0 deletions tests/lifetimes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ async fn contains_value_2<'a, 'b, T: PartialEq>(value: &'b T, node: &'b Node<'a,
contains_value(value, node).await
}

// The reference inside foo needs a `async_recursion bound
#[async_recursion]
async fn count_down(foo: Option<&str>) -> i32 {
let _ = foo;
0
}

#[test]
fn lifetime_expansion_works() {
block_on(async move {
Expand Down Expand Up @@ -64,5 +71,7 @@ fn lifetime_expansion_works() {
assert_eq!(contains_value_2(&17, &node).await, true);
assert_eq!(contains_value_2(&13, &node).await, true);
assert_eq!(contains_value_2(&12, &node).await, false);

count_down(None).await;
});
}

0 comments on commit f01f486

Please sign in to comment.