Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize TrustDnsResolver #1967

Merged
merged 4 commits into from Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/async_impl/client.rs
Expand Up @@ -266,7 +266,7 @@ impl ClientBuilder {
let mut resolver: Arc<dyn Resolve> = match config.trust_dns {
false => Arc::new(GaiResolver::new()),
#[cfg(feature = "trust-dns")]
true => Arc::new(TrustDnsResolver::new().map_err(crate::error::builder)?),
true => Arc::new(TrustDnsResolver::default()),
#[cfg(not(feature = "trust-dns"))]
true => unreachable!("trust-dns shouldn't be enabled unless the feature is"),
};
Expand Down
69 changes: 13 additions & 56 deletions src/dns/trust_dns.rs
@@ -1,8 +1,7 @@
//! DNS resolution via the [trust_dns_resolver](https://github.com/bluejekyll/trust-dns) crate

use hyper::client::connect::dns::Name;
use once_cell::sync::Lazy;
use tokio::sync::Mutex;
use once_cell::sync::OnceCell;
pub use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
use trust_dns_resolver::{lookup_ip::LookupIpIntoIter, system_conf, TokioAsyncResolver};

Expand All @@ -12,62 +11,24 @@ use std::sync::Arc;

use super::{Addrs, Resolve, Resolving};

type SharedResolver = Arc<TokioAsyncResolver>;

static SYSTEM_CONF: Lazy<io::Result<(ResolverConfig, ResolverOpts)>> =
Lazy::new(|| system_conf::read_system_conf().map_err(io::Error::from));

/// Wrapper around an `AsyncResolver`, which implements the `Resolve` trait.
#[derive(Debug, Clone)]
#[derive(Debug, Default, Clone)]
pub(crate) struct TrustDnsResolver {
state: Arc<Mutex<State>>,
/// Since we might not have been called in the context of a
/// Tokio Runtime in initialization, so we must delay the actual
/// construction of the resolver.
state: Arc<OnceCell<TokioAsyncResolver>>,
}

struct SocketAddrs {
iter: LookupIpIntoIter,
}

#[derive(Debug)]
enum State {
Init,
Ready(SharedResolver),
}

impl TrustDnsResolver {
/// Create a new resolver with the default configuration,
/// which reads from `/etc/resolve.conf`.
pub fn new() -> io::Result<Self> {
SYSTEM_CONF.as_ref().map_err(|e| {
io::Error::new(e.kind(), format!("error reading DNS system conf: {}", e))
})?;

// At this stage, we might not have been called in the context of a
// Tokio Runtime, so we must delay the actual construction of the
// resolver.
Ok(TrustDnsResolver {
state: Arc::new(Mutex::new(State::Init)),
})
}
}

impl Resolve for TrustDnsResolver {
fn resolve(&self, name: Name) -> Resolving {
let resolver = self.clone();
Box::pin(async move {
let mut lock = resolver.state.lock().await;

let resolver = match &*lock {
State::Init => {
let resolver = new_resolver().await;
*lock = State::Ready(resolver.clone());
resolver
}
State::Ready(resolver) => resolver.clone(),
};

// Don't keep lock once the resolver is constructed, otherwise
// only one lookup could be done at a time.
drop(lock);
let resolver = resolver.state.get_or_try_init(new_resolver)?;

let lookup = resolver.lookup_ip(name.as_str()).await?;
let addrs: Addrs = Box::new(SocketAddrs {
Expand All @@ -86,14 +47,10 @@ impl Iterator for SocketAddrs {
}
}

async fn new_resolver() -> SharedResolver {
let (config, opts) = SYSTEM_CONF
.as_ref()
.expect("can't construct TrustDnsResolver if SYSTEM_CONF is error")
.clone();
new_resolver_with_config(config, opts)
}

fn new_resolver_with_config(config: ResolverConfig, opts: ResolverOpts) -> SharedResolver {
Arc::new(TokioAsyncResolver::tokio(config, opts))
/// Create a new resolver with the default configuration,
/// which reads from `/etc/resolve.conf`.
fn new_resolver() -> io::Result<TokioAsyncResolver> {
let (config, opts) = system_conf::read_system_conf()
.map_err(|e| io::Error::new(e.kind(), format!("error reading DNS system conf: {}", e)))?;
Ok(TokioAsyncResolver::tokio(config, opts))
}