Skip to content

Commit

Permalink
Add cfg to HTTP version enum
Browse files Browse the repository at this point in the history
Signed-off-by: Miguel Guarniz <mi9uel9@gmail.com>
  • Loading branch information
kckeiks committed Aug 6, 2022
1 parent eb66d19 commit 1b68cfc
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 74 deletions.
58 changes: 32 additions & 26 deletions src/async_impl/client.rs
Expand Up @@ -22,11 +22,12 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::time::Sleep;

use log::{debug, trace};
use super::decoder::Accepts;
use super::request::{Request, RequestBuilder};
use super::response::Response;
use super::Body;
#[cfg(feature = "http3")]
use crate::async_impl::h3_client::{H3Builder, H3Client, H3ResponseFuture};
use crate::connect::{Connector, HttpConnector};
#[cfg(feature = "cookies")]
use crate::cookie;
Expand All @@ -40,8 +41,7 @@ use crate::Certificate;
#[cfg(any(feature = "native-tls", feature = "__rustls"))]
use crate::Identity;
use crate::{IntoUrl, Method, Proxy, StatusCode, Url};
#[cfg(feature = "http3")]
use crate::async_impl::h3_client::{H3Builder, H3Client, H3ResponseFuture};
use log::{debug, trace};

/// An asynchronous `Client` to make Requests with.
///
Expand Down Expand Up @@ -70,6 +70,7 @@ pub struct ClientBuilder {
enum HttpVersionPref {
Http1,
Http2,
#[cfg(feature = "http3")]
Http3,
All,
}
Expand Down Expand Up @@ -124,7 +125,7 @@ struct Config {
https_only: bool,
dns_overrides: HashMap<String, SocketAddr>,
#[cfg(feature = "http3")]
tls_enable_early_data: bool
tls_enable_early_data: bool,
}

impl Default for ClientBuilder {
Expand Down Expand Up @@ -193,7 +194,7 @@ impl ClientBuilder {
https_only: false,
dns_overrides: HashMap::new(),
#[cfg(feature = "http3")]
tls_enable_early_data: false
tls_enable_early_data: false,
},
}
}
Expand Down Expand Up @@ -249,7 +250,7 @@ impl ClientBuilder {
TlsBackend::Default => {
let mut tls = TlsConnector::builder();

#[cfg(feature = "native-tls-alpn")]
#[cfg(all(feature = "native-tls-alpn", not(feature = "http3")))]
{
match config.http_version_pref {
HttpVersionPref::Http1 => {
Expand All @@ -258,9 +259,6 @@ impl ClientBuilder {
HttpVersionPref::Http2 => {
tls.request_alpns(&["h2"]);
}
HttpVersionPref::Http3 => {
unreachable!("HTTP/3 shouldn't be enabled unless the feature is")
},
HttpVersionPref::All => {
tls.request_alpns(&["h2", "http/1.1"]);
}
Expand Down Expand Up @@ -443,6 +441,7 @@ impl ClientBuilder {
HttpVersionPref::Http2 => {
tls.alpn_protocols = vec!["h2".into()];
}
#[cfg(feature = "http3")]
HttpVersionPref::Http3 => {
tls.alpn_protocols = vec!["h3".into()];
}
Expand Down Expand Up @@ -941,6 +940,7 @@ impl ClientBuilder {
}

/// Only use HTTP/3.
#[cfg(feature = "http3")]
pub fn http3_prior_knowledge(mut self) -> ClientBuilder {
self.config.http_version_pref = HttpVersionPref::Http3;
self
Expand Down Expand Up @@ -1535,12 +1535,14 @@ impl Client {
let mut req = builder.body(()).expect("valid request parts");
*req.headers_mut() = headers.clone();
ResponseFuture::H3(self.inner.h3_client.request(req))
},
}
_ => {
let mut req = builder.body(body.into_stream()).expect("valid request parts");
let mut req = builder
.body(body.into_stream())
.expect("valid request parts");
*req.headers_mut() = headers.clone();
ResponseFuture::Default(self.inner.hyper.request(req))
},
}
};

let timeout = timeout
Expand Down Expand Up @@ -1875,7 +1877,8 @@ impl PendingRequest {

*req.headers_mut() = self.headers.clone();

*self.as_mut().in_flight().get_mut() = ResponseFuture::Default(self.client.hyper.request(req));
*self.as_mut().in_flight().get_mut() =
ResponseFuture::Default(self.client.hyper.request(req));

true
}
Expand Down Expand Up @@ -1933,29 +1936,31 @@ impl Future for PendingRequest {

loop {
let res = match self.as_mut().in_flight().get_mut() {
ResponseFuture::Default(r) => {
match Pin::new(r).poll(cx) {
Poll::Ready(Err(e)) => {
if self.as_mut().retry_error(&e) {
continue;
}
return Poll::Ready(Err(crate::error::request(e).with_url(self.url.clone())));
ResponseFuture::Default(r) => match Pin::new(r).poll(cx) {
Poll::Ready(Err(e)) => {
if self.as_mut().retry_error(&e) {
continue;
}
Poll::Ready(Ok(res)) => res,
Poll::Pending => return Poll::Pending,
return Poll::Ready(Err(
crate::error::request(e).with_url(self.url.clone())
));
}
}
Poll::Ready(Ok(res)) => res,
Poll::Pending => return Poll::Pending,
},
#[cfg(feature = "http3")]
ResponseFuture::H3(r) => match Pin::new(r).poll(cx) {
Poll::Ready(Err(e)) => {
if self.as_mut().retry_error(&e) {
continue;
}
return Poll::Ready(Err(crate::error::request(e).with_url(self.url.clone())));
return Poll::Ready(Err(
crate::error::request(e).with_url(self.url.clone())
));
}
Poll::Ready(Ok(res)) => res,
Poll::Pending => return Poll::Pending,
}
},
};

#[cfg(feature = "cookies")]
Expand Down Expand Up @@ -2071,7 +2076,8 @@ impl Future for PendingRequest {

*req.headers_mut() = headers.clone();
std::mem::swap(self.as_mut().headers(), &mut headers);
*self.as_mut().in_flight().get_mut() = ResponseFuture::Default(self.client.hyper.request(req));
*self.as_mut().in_flight().get_mut() =
ResponseFuture::Default(self.client.hyper.request(req));
continue;
}
redirect::ActionKind::Stop => {
Expand Down
34 changes: 19 additions & 15 deletions src/async_impl/h3_client/mod.rs
Expand Up @@ -2,20 +2,20 @@

mod pool;

use crate::async_impl::h3_client::pool::{Key, Pool, PoolClient};
use crate::error;
use crate::error::{BoxError, Error, Kind};
use futures_util::future;
use h3_quinn::Connection;
use http::{Request, Response};
use hyper::Body;
use log::debug;
use std::future::Future;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use http::{Request, Response};
use crate::error::{BoxError, Error, Kind};
use hyper::Body;
use futures_util::future;
use h3_quinn::Connection;
use log::debug;
use crate::async_impl::h3_client::pool::{Key, Pool, PoolClient};
use crate::error;

pub struct H3Builder {
pool_idle_timeout: Option<Duration>,
Expand All @@ -41,8 +41,8 @@ impl H3Builder {
None => "[::]:0".parse::<SocketAddr>().unwrap(),
};

let mut endpoint = quinn::Endpoint::client(socket_addr)
.expect("unable to create QUIC endpoint");
let mut endpoint =
quinn::Endpoint::client(socket_addr).expect("unable to create QUIC endpoint");
endpoint.set_default_client_config(config);

H3Client {
Expand Down Expand Up @@ -91,9 +91,7 @@ impl H3Client {
.next()
.ok_or("dns found no addresses")?;

let quinn_conn = Connection::new(
self.endpoint.connect(addr, auth.host())?.await?
);
let quinn_conn = Connection::new(self.endpoint.connect(addr, auth.host())?.await?);
let (mut driver, tx) = h3::client::new(quinn_conn).await?;

// TODO: What does poll_close() do?
Expand All @@ -120,9 +118,15 @@ impl H3Client {
pub fn request(&self, mut req: Request<()>) -> H3ResponseFuture {
let pool_key = match pool::extract_domain(req.uri_mut()) {
Ok(s) => s,
Err(e) => return H3ResponseFuture{inner: Box::pin(future::err(e))},
Err(e) => {
return H3ResponseFuture {
inner: Box::pin(future::err(e)),
}
}
};
H3ResponseFuture{inner: Box::pin(self.clone().send_request(pool_key, req))}
H3ResponseFuture {
inner: Box::pin(self.clone().send_request(pool_key, req)),
}
}
}

Expand Down
46 changes: 20 additions & 26 deletions src/async_impl/h3_client/pool.rs
@@ -1,22 +1,22 @@
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use bytes::Bytes;
use tokio::time::Instant;

use crate::error::{BoxError, Error, Kind};
use bytes::Buf;
use h3::client::SendRequest;
use http::{Request, Response, Uri};
use http::uri::{Authority, Scheme};
use http::{Request, Response, Uri};
use hyper::Body;
use crate::error::{BoxError, Error, Kind};
use bytes::Buf;
use log::debug;

pub(super) type Key = (Scheme, Authority);

#[derive(Clone)]
pub struct Pool {
inner: Arc<Mutex<PoolInner>>
inner: Arc<Mutex<PoolInner>>,
}

impl Pool {
Expand All @@ -26,7 +26,7 @@ impl Pool {
idle: HashMap::new(),
max_idle_per_host,
timeout,
}))
})),
}
}

Expand All @@ -38,19 +38,17 @@ impl Pool {
pub fn try_pool(&self, key: &Key) -> Option<PoolClient> {
let mut inner = self.inner.lock().unwrap();
let timeout = inner.timeout;
inner.idle.get_mut(&key).and_then(|list| {
match list.pop() {
Some(idle) => {
if let Some(duration) = timeout {
if Instant::now().saturating_duration_since(idle.idle_at) > duration {
debug!("pooled client expired");
return None;
}
inner.idle.get_mut(&key).and_then(|list| match list.pop() {
Some(idle) => {
if let Some(duration) = timeout {
if Instant::now().saturating_duration_since(idle.idle_at) > duration {
debug!("pooled client expired");
return None;
}
Some(idle.value)
},
None => None,
}
Some(idle.value)
}
None => None,
})
}
}
Expand Down Expand Up @@ -79,21 +77,19 @@ impl PoolInner {

idle_list.push(Idle {
idle_at: Instant::now(),
value: client
value: client,
});
}
}

#[derive(Clone)]
pub struct PoolClient {
tx: SendRequest<h3_quinn::OpenStreams, Bytes>
tx: SendRequest<h3_quinn::OpenStreams, Bytes>,
}

impl PoolClient {
pub fn new(tx: SendRequest<h3_quinn::OpenStreams, Bytes>) -> Self {
Self {
tx
}
Self { tx }
}

// TODO: add support for sending data.
Expand All @@ -108,9 +104,7 @@ impl PoolClient {
body.extend(chunk.chunk())
}

Ok(resp.map(|_| {
Body::from(body)
}))
Ok(resp.map(|_| Body::from(body)))
}
}

Expand All @@ -134,4 +128,4 @@ pub(crate) fn domain_as_uri((scheme, auth): Key) -> Uri {
.path_and_query("/")
.build()
.expect("domain is valid Uri")
}
}
2 changes: 1 addition & 1 deletion src/async_impl/mod.rs
Expand Up @@ -9,9 +9,9 @@ pub(crate) use self::decoder::Decoder;
pub mod body;
pub mod client;
pub mod decoder;
pub mod h3_client;
#[cfg(feature = "multipart")]
pub mod multipart;
pub(crate) mod request;
mod response;
mod upgrade;
pub mod h3_client;
7 changes: 2 additions & 5 deletions src/connect.rs
Expand Up @@ -31,7 +31,6 @@ use crate::dns::TrustDnsResolver;
use crate::error::BoxError;
use crate::proxy::{Proxy, ProxyScheme};


#[derive(Clone)]
pub(crate) enum HttpConnector {
Gai(hyper::client::HttpConnector),
Expand Down Expand Up @@ -525,11 +524,9 @@ impl Connector {
#[cfg(feature = "http3")]
pub fn deep_clone_tls(&self) -> rustls::ClientConfig {
match &self.inner {
Inner::RustlsTls { tls, .. } => {
(*(*tls)).clone()
}
Inner::RustlsTls { tls, .. } => (*(*tls)).clone(),
#[cfg(feature = "default-tls")]
_ => unreachable!("HTTP/3 should only be enabled with Rustls")
_ => unreachable!("HTTP/3 should only be enabled with Rustls"),
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/tls.rs
Expand Up @@ -380,7 +380,10 @@ impl Default for TlsBackend {
TlsBackend::Default
}

#[cfg(any(all(feature = "__rustls", not(feature = "default-tls")), feature = "http3"))]
#[cfg(any(
all(feature = "__rustls", not(feature = "default-tls")),
feature = "http3"
))]
{
TlsBackend::Rustls
}
Expand Down

0 comments on commit 1b68cfc

Please sign in to comment.