Skip to content

Commit

Permalink
Make pooling more robust
Browse files Browse the repository at this point in the history
Add PoolConnection to listen for errors and hold a list of PoolClients that
can be reused.

Signed-off-by: Miguel Guarniz <mi9uel9@gmail.com>
  • Loading branch information
kckeiks committed Aug 12, 2022
1 parent 7931f0c commit f069009
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 50 deletions.
10 changes: 9 additions & 1 deletion src/async_impl/client.rs
Expand Up @@ -1938,7 +1938,15 @@ impl PendingRequest {
}

fn is_retryable_error(err: &(dyn std::error::Error + 'static)) -> bool {
// TODO: Does the h3 API provide a way to determine this same type of case?
#[cfg(feature = "http3")]
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<h3::Error>() {
debug!("determining if HTTP/3 error {} can be retried", err);
// TODO: Does h3 provide an API for checking the error?
return err.to_string().as_str() == "timeout";
}
}

if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<h2::Error>() {
// They sent us a graceful shutdown, try with a new connection!
Expand Down
18 changes: 7 additions & 11 deletions src/async_impl/h3_client/connect.rs
@@ -1,14 +1,16 @@
use crate::async_impl::h3_client::dns::Resolver;
use crate::error::BoxError;
use bytes::Bytes;
use futures_util::future;
use h3::client::Connection as H3Conn;
use h3::client::SendRequest;
use h3_quinn::{Connection, OpenStreams};
use http::Uri;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;

type H3Connection = (H3Conn<Connection, Bytes>, SendRequest<OpenStreams, Bytes>);

#[derive(Clone)]
pub(crate) struct H3Connector {
resolver: Resolver,
Expand All @@ -35,10 +37,7 @@ impl H3Connector {
Self { resolver, endpoint }
}

pub async fn connect(
&mut self,
dest: Uri,
) -> Result<SendRequest<OpenStreams, Bytes>, BoxError> {
pub async fn connect(&mut self, dest: Uri) -> Result<H3Connection, BoxError> {
let host = dest.host().ok_or("destination must have a host")?;
let port = dest.port_u16().unwrap_or(443);

Expand All @@ -61,17 +60,14 @@ impl H3Connector {
&mut self,
addrs: Vec<SocketAddr>,
server_name: &str,
) -> Result<SendRequest<OpenStreams, Bytes>, BoxError> {
) -> Result<H3Connection, BoxError> {
let mut err = None;
for addr in addrs {
match self.endpoint.connect(addr, server_name)?.await {
Ok(new_conn) => {
let quinn_conn = Connection::new(new_conn);
let (mut driver, tx) = h3::client::new(quinn_conn).await?;
tokio::spawn(async move {
future::poll_fn(|cx| driver.poll_close(cx)).await.unwrap();
});
return Ok(tx);
let (driver, tx) = h3::client::new(quinn_conn).await?;
return Ok((driver, tx));
}
Err(e) => err = Some(e),
}
Expand Down
11 changes: 7 additions & 4 deletions src/async_impl/h3_client/mod.rs
Expand Up @@ -61,11 +61,14 @@ impl H3Client {
return Ok(client);
}

debug!(
"unable to find connection {:?} in pool so connecting...",
key
);

let dest = pool::domain_as_uri(key.clone());
let tx = self.connector.connect(dest).await?;
let client = PoolClient::new(tx);
self.pool.put(key, client.clone());
Ok(client)
let (driver, tx) = self.connector.connect(dest).await?;
Ok(self.pool.new_connection(key, driver, tx))
}

async fn send_request(
Expand Down
129 changes: 95 additions & 34 deletions src/async_impl/h3_client/pool.rs
@@ -1,13 +1,17 @@
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::mpsc::{Receiver, TryRecvError};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::Instant;

use crate::error::{BoxError, Error, Kind};
use crate::Body;
use bytes::Buf;
use futures_util::future;
use h3::client::Connection as H3Conn;
use h3::client::SendRequest;
use h3_quinn::{Connection, OpenStreams};
use http::uri::{Authority, Scheme};
use http::{Request, Response, Uri};
use hyper::Body as HyperBody;
Expand All @@ -24,72 +28,92 @@ impl Pool {
pub fn new(max_idle_per_host: usize, timeout: Option<Duration>) -> Self {
Self {
inner: Arc::new(Mutex::new(PoolInner {
idle: HashMap::new(),
idle_conns: HashMap::new(),
max_idle_per_host,
timeout,
})),
}
}

pub fn put(&self, key: Key, client: PoolClient) {
let mut inner = self.inner.lock().unwrap();
inner.put(key, client)
}

pub fn try_pool(&self, key: &Key) -> Option<PoolClient> {
let mut inner = self.inner.lock().unwrap();
let timeout = inner.timeout;
inner.idle.get_mut(&key).and_then(|list| match list.pop() {
Some(idle) => {
if let Some(duration) = timeout {
if Instant::now().saturating_duration_since(idle.idle_at) > duration {
debug!("pooled client expired");
return None;
if let Some(conn) = inner.idle_conns.get(&key) {
// We check first if the connection still valid
// and if not, we remove it from the pool.
if conn.is_invalid() {
debug!("pooled HTTP/3 connection is invalid so removing it...");
inner.idle_conns.remove(&key);
}
}

inner
.idle_conns
.get_mut(&key)
.and_then(|conn| match conn.idle.pop() {
Some(idle) => {
if let Some(duration) = timeout {
if Instant::now().saturating_duration_since(idle.idle_at) > duration {
debug!("pooled client expired");
return None;
}
}

conn.push_client(idle.value.clone());
Some(idle.value)
}
Some(idle.value)
None => None,
})
}

pub fn new_connection(
&mut self,
key: Key,
mut driver: H3Conn<Connection, Bytes>,
tx: SendRequest<OpenStreams, Bytes>,
) -> PoolClient {
let (send, receive) = std::sync::mpsc::channel();
tokio::spawn(async move {
if let Err(e) = future::poll_fn(|cx| driver.poll_close(cx)).await {
debug!("poll_close returned error {:?}", e);
send.send(e).ok();
}
None => None,
})
});

let mut inner = self.inner.lock().unwrap();

let mut conn = PoolConnection::new(receive, inner.max_idle_per_host);
let client = PoolClient::new(tx);
conn.push_client(client.clone());
inner.insert(key, conn);

client
}
}

struct PoolInner {
// These are internal Conns sitting in the event loop in the KeepAlive
// state, waiting to receive a new Request to send on the socket.
idle: HashMap<Key, Vec<Idle>>,
idle_conns: HashMap<Key, PoolConnection>,
max_idle_per_host: usize,
timeout: Option<Duration>,
}

impl PoolInner {
fn put(&mut self, key: Key, client: PoolClient) {
if self.idle.contains_key(&key) {
fn insert(&mut self, key: Key, conn: PoolConnection) {
if self.idle_conns.contains_key(&key) {
debug!("connection already exists for key {:?}", key);
return;
}

let idle_list = self.idle.entry(key.clone()).or_default();

if idle_list.len() >= self.max_idle_per_host {
debug!("max idle per host for {:?}, dropping connection", key);
return;
}

idle_list.push(Idle {
idle_at: Instant::now(),
value: client,
});
self.idle_conns.insert(key, conn);
}
}

#[derive(Clone)]
pub struct PoolClient {
tx: SendRequest<h3_quinn::OpenStreams, Bytes>,
tx: SendRequest<OpenStreams, Bytes>,
}

impl PoolClient {
pub fn new(tx: SendRequest<h3_quinn::OpenStreams, Bytes>) -> Self {
pub fn new(tx: SendRequest<OpenStreams, Bytes>) -> Self {
Self { tx }
}

Expand Down Expand Up @@ -121,6 +145,43 @@ impl PoolClient {
}
}

pub struct PoolConnection {
// This receives errors from polling h3 driver.
rx: Receiver<h3::Error>,
idle: Vec<Idle>,
max_idle_per_host: usize,
}

impl PoolConnection {
pub fn new(rx: Receiver<h3::Error>, max_idle_per_host: usize) -> Self {
Self {
rx,
idle: Vec::new(),
max_idle_per_host,
}
}

pub fn push_client(&mut self, client: PoolClient) {
if self.idle.len() >= self.max_idle_per_host {
debug!("max idle per host reached for HTTP/3 pool connection");
return;
}

self.idle.push(Idle {
idle_at: Instant::now(),
value: client,
});
}

pub fn is_invalid(&self) -> bool {
match self.rx.try_recv() {
Err(TryRecvError::Empty) => false,
Err(TryRecvError::Disconnected) => true,
Ok(_) => true,
}
}
}

struct Idle {
idle_at: Instant,
value: PoolClient,
Expand Down

0 comments on commit f069009

Please sign in to comment.