Skip to content

Commit 5442b6f

Browse files
authoredMay 24, 2021
feat(http2): Implement Client-side CONNECT support over HTTP/2 (#2523)
Closes #2508
1 parent be9677a commit 5442b6f

File tree

10 files changed

+833
-78
lines changed

10 files changed

+833
-78
lines changed
 

‎Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ http = "0.2"
3131
http-body = "0.4"
3232
httpdate = "1.0"
3333
httparse = "1.4"
34-
h2 = { version = "0.3", optional = true }
34+
h2 = { version = "0.3.3", optional = true }
3535
itoa = "0.4.1"
3636
tracing = { version = "0.1", default-features = false, features = ["std"] }
3737
pin-project = "1.0"

‎src/body/length.rs

+11
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@ use std::fmt;
33
#[derive(Clone, Copy, PartialEq, Eq)]
44
pub(crate) struct DecodedLength(u64);
55

6+
#[cfg(any(feature = "http1", feature = "http2"))]
7+
impl From<Option<u64>> for DecodedLength {
8+
fn from(len: Option<u64>) -> Self {
9+
len.and_then(|len| {
10+
// If the length is u64::MAX, oh well, just reported chunked.
11+
Self::checked_new(len).ok()
12+
})
13+
.unwrap_or(DecodedLength::CHUNKED)
14+
}
15+
}
16+
617
#[cfg(any(feature = "http1", feature = "http2", test))]
718
const MAX_LEN: u64 = std::u64::MAX - 2;
819

‎src/client/client.rs

+2-5
Original file line numberDiff line numberDiff line change
@@ -254,12 +254,9 @@ where
254254
absolute_form(req.uri_mut());
255255
} else {
256256
origin_form(req.uri_mut());
257-
};
257+
}
258258
} else if req.method() == Method::CONNECT {
259-
debug!("client does not support CONNECT requests over HTTP2");
260-
return Err(ClientError::Normal(
261-
crate::Error::new_user_unsupported_request_method(),
262-
));
259+
authority_form(req.uri_mut());
263260
}
264261

265262
let fut = pooled

‎src/error.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ pub(super) enum User {
9090
/// User tried to send a certain header in an unexpected context.
9191
///
9292
/// For example, sending both `content-length` and `transfer-encoding`.
93-
#[cfg(feature = "http1")]
93+
#[cfg(any(feature = "http1", feature = "http2"))]
9494
#[cfg(feature = "server")]
9595
UnexpectedHeader,
9696
/// User tried to create a Request with bad version.
@@ -290,7 +290,7 @@ impl Error {
290290
Error::new(Kind::User(user))
291291
}
292292

293-
#[cfg(feature = "http1")]
293+
#[cfg(any(feature = "http1", feature = "http2"))]
294294
#[cfg(feature = "server")]
295295
pub(super) fn new_user_header() -> Error {
296296
Error::new_user(User::UnexpectedHeader)
@@ -405,7 +405,7 @@ impl Error {
405405
Kind::User(User::MakeService) => "error from user's MakeService",
406406
#[cfg(any(feature = "http1", feature = "http2"))]
407407
Kind::User(User::Service) => "error from user's Service",
408-
#[cfg(feature = "http1")]
408+
#[cfg(any(feature = "http1", feature = "http2"))]
409409
#[cfg(feature = "server")]
410410
Kind::User(User::UnexpectedHeader) => "user sent unexpected header",
411411
#[cfg(any(feature = "http1", feature = "http2"))]

‎src/proto/h2/client.rs

+89-32
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@ use std::error::Error as StdError;
22
#[cfg(feature = "runtime")]
33
use std::time::Duration;
44

5+
use bytes::Bytes;
56
use futures_channel::{mpsc, oneshot};
67
use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _};
78
use futures_util::stream::StreamExt as _;
89
use h2::client::{Builder, SendRequest};
10+
use http::{Method, StatusCode};
911
use tokio::io::{AsyncRead, AsyncWrite};
1012

11-
use super::{decode_content_length, ping, PipeToSendStream, SendBuf};
13+
use super::{ping, H2Upgraded, PipeToSendStream, SendBuf};
1214
use crate::body::HttpBody;
1315
use crate::common::{exec::Exec, task, Future, Never, Pin, Poll};
1416
use crate::headers;
17+
use crate::proto::h2::UpgradedSendStream;
1518
use crate::proto::Dispatched;
19+
use crate::upgrade::Upgraded;
1620
use crate::{Body, Request, Response};
1721

1822
type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>;
@@ -233,8 +237,25 @@ where
233237
headers::set_content_length_if_missing(req.headers_mut(), len);
234238
}
235239
}
240+
241+
let is_connect = req.method() == Method::CONNECT;
236242
let eos = body.is_end_stream();
237-
let (fut, body_tx) = match self.h2_tx.send_request(req, eos) {
243+
let ping = self.ping.clone();
244+
245+
if is_connect {
246+
if headers::content_length_parse_all(req.headers())
247+
.map_or(false, |len| len != 0)
248+
{
249+
warn!("h2 connect request with non-zero body not supported");
250+
cb.send(Err((
251+
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
252+
None,
253+
)));
254+
continue;
255+
}
256+
}
257+
258+
let (fut, body_tx) = match self.h2_tx.send_request(req, !is_connect && eos) {
238259
Ok(ok) => ok,
239260
Err(err) => {
240261
debug!("client send request error: {}", err);
@@ -243,45 +264,81 @@ where
243264
}
244265
};
245266

246-
let ping = self.ping.clone();
247-
if !eos {
248-
let mut pipe = Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| {
249-
if let Err(e) = res {
250-
debug!("client request body error: {}", e);
251-
}
252-
});
253-
254-
// eagerly see if the body pipe is ready and
255-
// can thus skip allocating in the executor
256-
match Pin::new(&mut pipe).poll(cx) {
257-
Poll::Ready(_) => (),
258-
Poll::Pending => {
259-
let conn_drop_ref = self.conn_drop_ref.clone();
260-
// keep the ping recorder's knowledge of an
261-
// "open stream" alive while this body is
262-
// still sending...
263-
let ping = ping.clone();
264-
let pipe = pipe.map(move |x| {
265-
drop(conn_drop_ref);
266-
drop(ping);
267-
x
267+
let send_stream = if !is_connect {
268+
if !eos {
269+
let mut pipe =
270+
Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| {
271+
if let Err(e) = res {
272+
debug!("client request body error: {}", e);
273+
}
268274
});
269-
self.executor.execute(pipe);
275+
276+
// eagerly see if the body pipe is ready and
277+
// can thus skip allocating in the executor
278+
match Pin::new(&mut pipe).poll(cx) {
279+
Poll::Ready(_) => (),
280+
Poll::Pending => {
281+
let conn_drop_ref = self.conn_drop_ref.clone();
282+
// keep the ping recorder's knowledge of an
283+
// "open stream" alive while this body is
284+
// still sending...
285+
let ping = ping.clone();
286+
let pipe = pipe.map(move |x| {
287+
drop(conn_drop_ref);
288+
drop(ping);
289+
x
290+
});
291+
self.executor.execute(pipe);
292+
}
270293
}
271294
}
272-
}
295+
296+
None
297+
} else {
298+
Some(body_tx)
299+
};
273300

274301
let fut = fut.map(move |result| match result {
275302
Ok(res) => {
276303
// record that we got the response headers
277304
ping.record_non_data();
278305

279-
let content_length = decode_content_length(res.headers());
280-
let res = res.map(|stream| {
281-
let ping = ping.for_stream(&stream);
282-
crate::Body::h2(stream, content_length, ping)
283-
});
284-
Ok(res)
306+
let content_length = headers::content_length_parse_all(res.headers());
307+
if let (Some(mut send_stream), StatusCode::OK) =
308+
(send_stream, res.status())
309+
{
310+
if content_length.map_or(false, |len| len != 0) {
311+
warn!("h2 connect response with non-zero body not supported");
312+
313+
send_stream.send_reset(h2::Reason::INTERNAL_ERROR);
314+
return Err((
315+
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
316+
None,
317+
));
318+
}
319+
let (parts, recv_stream) = res.into_parts();
320+
let mut res = Response::from_parts(parts, Body::empty());
321+
322+
let (pending, on_upgrade) = crate::upgrade::pending();
323+
let io = H2Upgraded {
324+
ping,
325+
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
326+
recv_stream,
327+
buf: Bytes::new(),
328+
};
329+
let upgraded = Upgraded::new(io, Bytes::new());
330+
331+
pending.fulfill(upgraded);
332+
res.extensions_mut().insert(on_upgrade);
333+
334+
Ok(res)
335+
} else {
336+
let res = res.map(|stream| {
337+
let ping = ping.for_stream(&stream);
338+
crate::Body::h2(stream, content_length.into(), ping)
339+
});
340+
Ok(res)
341+
}
285342
}
286343
Err(err) => {
287344
ping.ensure_not_timed_out().map_err(|e| (e, None))?;

‎src/proto/h2/mod.rs

+186-22
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1-
use bytes::Buf;
2-
use h2::SendStream;
1+
use bytes::{Buf, Bytes};
2+
use h2::{RecvStream, SendStream};
33
use http::header::{
44
HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER,
55
TRANSFER_ENCODING, UPGRADE,
66
};
77
use http::HeaderMap;
88
use pin_project::pin_project;
99
use std::error::Error as StdError;
10-
use std::io::IoSlice;
10+
use std::io::{self, Cursor, IoSlice};
11+
use std::mem;
12+
use std::task::Context;
13+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
1114

12-
use crate::body::{DecodedLength, HttpBody};
15+
use crate::body::HttpBody;
1316
use crate::common::{task, Future, Pin, Poll};
14-
use crate::headers::content_length_parse_all;
17+
use crate::proto::h2::ping::Recorder;
1518

1619
pub(crate) mod ping;
1720

@@ -83,15 +86,6 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) {
8386
}
8487
}
8588

86-
fn decode_content_length(headers: &HeaderMap) -> DecodedLength {
87-
if let Some(len) = content_length_parse_all(headers) {
88-
// If the length is u64::MAX, oh well, just reported chunked.
89-
DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED)
90-
} else {
91-
DecodedLength::CHUNKED
92-
}
93-
}
94-
9589
// body adapters used by both Client and Server
9690

9791
#[pin_project]
@@ -172,7 +166,7 @@ where
172166
is_eos,
173167
);
174168

175-
let buf = SendBuf(Some(chunk));
169+
let buf = SendBuf::Buf(chunk);
176170
me.body_tx
177171
.send_data(buf, is_eos)
178172
.map_err(crate::Error::new_body_write)?;
@@ -243,32 +237,202 @@ impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> {
243237

244238
fn send_eos_frame(&mut self) -> crate::Result<()> {
245239
trace!("send body eos");
246-
self.send_data(SendBuf(None), true)
240+
self.send_data(SendBuf::None, true)
247241
.map_err(crate::Error::new_body_write)
248242
}
249243
}
250244

251-
struct SendBuf<B>(Option<B>);
245+
#[repr(usize)]
246+
enum SendBuf<B> {
247+
Buf(B),
248+
Cursor(Cursor<Box<[u8]>>),
249+
None,
250+
}
252251

253252
impl<B: Buf> Buf for SendBuf<B> {
254253
#[inline]
255254
fn remaining(&self) -> usize {
256-
self.0.as_ref().map(|b| b.remaining()).unwrap_or(0)
255+
match *self {
256+
Self::Buf(ref b) => b.remaining(),
257+
Self::Cursor(ref c) => c.remaining(),
258+
Self::None => 0,
259+
}
257260
}
258261

259262
#[inline]
260263
fn chunk(&self) -> &[u8] {
261-
self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[])
264+
match *self {
265+
Self::Buf(ref b) => b.chunk(),
266+
Self::Cursor(ref c) => c.chunk(),
267+
Self::None => &[],
268+
}
262269
}
263270

264271
#[inline]
265272
fn advance(&mut self, cnt: usize) {
266-
if let Some(b) = self.0.as_mut() {
267-
b.advance(cnt)
273+
match *self {
274+
Self::Buf(ref mut b) => b.advance(cnt),
275+
Self::Cursor(ref mut c) => c.advance(cnt),
276+
Self::None => {}
268277
}
269278
}
270279

271280
fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
272-
self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0)
281+
match *self {
282+
Self::Buf(ref b) => b.chunks_vectored(dst),
283+
Self::Cursor(ref c) => c.chunks_vectored(dst),
284+
Self::None => 0,
285+
}
286+
}
287+
}
288+
289+
struct H2Upgraded<B>
290+
where
291+
B: Buf,
292+
{
293+
ping: Recorder,
294+
send_stream: UpgradedSendStream<B>,
295+
recv_stream: RecvStream,
296+
buf: Bytes,
297+
}
298+
299+
impl<B> AsyncRead for H2Upgraded<B>
300+
where
301+
B: Buf,
302+
{
303+
fn poll_read(
304+
mut self: Pin<&mut Self>,
305+
cx: &mut Context<'_>,
306+
read_buf: &mut ReadBuf<'_>,
307+
) -> Poll<Result<(), io::Error>> {
308+
if self.buf.is_empty() {
309+
self.buf = loop {
310+
match ready!(self.recv_stream.poll_data(cx)) {
311+
None => return Poll::Ready(Ok(())),
312+
Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => {
313+
continue
314+
}
315+
Some(Ok(buf)) => {
316+
self.ping.record_data(buf.len());
317+
break buf;
318+
}
319+
Some(Err(e)) => {
320+
return Poll::Ready(Err(h2_to_io_error(e)));
321+
}
322+
}
323+
};
324+
}
325+
let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
326+
read_buf.put_slice(&self.buf[..cnt]);
327+
self.buf.advance(cnt);
328+
let _ = self.recv_stream.flow_control().release_capacity(cnt);
329+
Poll::Ready(Ok(()))
330+
}
331+
}
332+
333+
impl<B> AsyncWrite for H2Upgraded<B>
334+
where
335+
B: Buf,
336+
{
337+
fn poll_write(
338+
mut self: Pin<&mut Self>,
339+
cx: &mut Context<'_>,
340+
buf: &[u8],
341+
) -> Poll<Result<usize, io::Error>> {
342+
if let Poll::Ready(reset) = self.send_stream.poll_reset(cx) {
343+
return Poll::Ready(Err(h2_to_io_error(match reset {
344+
Ok(reason) => reason.into(),
345+
Err(e) => e,
346+
})));
347+
}
348+
if buf.is_empty() {
349+
return Poll::Ready(Ok(0));
350+
}
351+
self.send_stream.reserve_capacity(buf.len());
352+
Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) {
353+
None => Ok(0),
354+
Some(Ok(cnt)) => self.send_stream.write(&buf[..cnt], false).map(|()| cnt),
355+
Some(Err(e)) => Err(h2_to_io_error(e)),
356+
})
357+
}
358+
359+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
360+
Poll::Ready(Ok(()))
361+
}
362+
363+
fn poll_shutdown(
364+
mut self: Pin<&mut Self>,
365+
_cx: &mut Context<'_>,
366+
) -> Poll<Result<(), io::Error>> {
367+
Poll::Ready(self.send_stream.write(&[], true))
368+
}
369+
}
370+
371+
fn h2_to_io_error(e: h2::Error) -> io::Error {
372+
if e.is_io() {
373+
e.into_io().unwrap()
374+
} else {
375+
io::Error::new(io::ErrorKind::Other, e)
376+
}
377+
}
378+
379+
struct UpgradedSendStream<B>(SendStream<SendBuf<Neutered<B>>>);
380+
381+
impl<B> UpgradedSendStream<B>
382+
where
383+
B: Buf,
384+
{
385+
unsafe fn new(inner: SendStream<SendBuf<B>>) -> Self {
386+
assert_eq!(mem::size_of::<B>(), mem::size_of::<Neutered<B>>());
387+
Self(mem::transmute(inner))
388+
}
389+
390+
fn reserve_capacity(&mut self, cnt: usize) {
391+
unsafe { self.as_inner_unchecked().reserve_capacity(cnt) }
392+
}
393+
394+
fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<usize, h2::Error>>> {
395+
unsafe { self.as_inner_unchecked().poll_capacity(cx) }
396+
}
397+
398+
fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll<Result<h2::Reason, h2::Error>> {
399+
unsafe { self.as_inner_unchecked().poll_reset(cx) }
400+
}
401+
402+
fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> {
403+
let send_buf = SendBuf::Cursor(Cursor::new(buf.into()));
404+
unsafe {
405+
self.as_inner_unchecked()
406+
.send_data(send_buf, end_of_stream)
407+
.map_err(h2_to_io_error)
408+
}
409+
}
410+
411+
unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream<SendBuf<B>> {
412+
&mut *(&mut self.0 as *mut _ as *mut _)
413+
}
414+
}
415+
416+
#[repr(transparent)]
417+
struct Neutered<B> {
418+
_inner: B,
419+
impossible: Impossible,
420+
}
421+
422+
enum Impossible {}
423+
424+
unsafe impl<B> Send for Neutered<B> {}
425+
426+
impl<B> Buf for Neutered<B> {
427+
fn remaining(&self) -> usize {
428+
match self.impossible {}
429+
}
430+
431+
fn chunk(&self) -> &[u8] {
432+
match self.impossible {}
433+
}
434+
435+
fn advance(&mut self, _cnt: usize) {
436+
match self.impossible {}
273437
}
274438
}

‎src/proto/h2/server.rs

+78-10
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,24 @@ use std::marker::Unpin;
33
#[cfg(feature = "runtime")]
44
use std::time::Duration;
55

6+
use bytes::Bytes;
67
use h2::server::{Connection, Handshake, SendResponse};
7-
use h2::Reason;
8+
use h2::{Reason, RecvStream};
9+
use http::{Method, Request};
810
use pin_project::pin_project;
911
use tokio::io::{AsyncRead, AsyncWrite};
1012

11-
use super::{decode_content_length, ping, PipeToSendStream, SendBuf};
13+
use super::{ping, PipeToSendStream, SendBuf};
1214
use crate::body::HttpBody;
1315
use crate::common::exec::ConnStreamExec;
1416
use crate::common::{date, task, Future, Pin, Poll};
1517
use crate::headers;
18+
use crate::proto::h2::ping::Recorder;
19+
use crate::proto::h2::{H2Upgraded, UpgradedSendStream};
1620
use crate::proto::Dispatched;
1721
use crate::service::HttpService;
1822

23+
use crate::upgrade::{OnUpgrade, Pending, Upgraded};
1924
use crate::{Body, Response};
2025

2126
// Our defaults are chosen for the "majority" case, which usually are not
@@ -255,9 +260,9 @@ where
255260

256261
// When the service is ready, accepts an incoming request.
257262
match ready!(self.conn.poll_accept(cx)) {
258-
Some(Ok((req, respond))) => {
263+
Some(Ok((req, mut respond))) => {
259264
trace!("incoming request");
260-
let content_length = decode_content_length(req.headers());
265+
let content_length = headers::content_length_parse_all(req.headers());
261266
let ping = self
262267
.ping
263268
.as_ref()
@@ -267,8 +272,36 @@ where
267272
// Record the headers received
268273
ping.record_non_data();
269274

270-
let req = req.map(|stream| crate::Body::h2(stream, content_length, ping));
271-
let fut = H2Stream::new(service.call(req), respond);
275+
let is_connect = req.method() == Method::CONNECT;
276+
let (mut parts, stream) = req.into_parts();
277+
let (req, connect_parts) = if !is_connect {
278+
(
279+
Request::from_parts(
280+
parts,
281+
crate::Body::h2(stream, content_length.into(), ping),
282+
),
283+
None,
284+
)
285+
} else {
286+
if content_length.map_or(false, |len| len != 0) {
287+
warn!("h2 connect request with non-zero body not supported");
288+
respond.send_reset(h2::Reason::INTERNAL_ERROR);
289+
return Poll::Ready(Ok(()));
290+
}
291+
let (pending, upgrade) = crate::upgrade::pending();
292+
debug_assert!(parts.extensions.get::<OnUpgrade>().is_none());
293+
parts.extensions.insert(upgrade);
294+
(
295+
Request::from_parts(parts, crate::Body::empty()),
296+
Some(ConnectParts {
297+
pending,
298+
ping,
299+
recv_stream: stream,
300+
}),
301+
)
302+
};
303+
304+
let fut = H2Stream::new(service.call(req), connect_parts, respond);
272305
exec.execute_h2stream(fut);
273306
}
274307
Some(Err(e)) => {
@@ -331,18 +364,28 @@ enum H2StreamState<F, B>
331364
where
332365
B: HttpBody,
333366
{
334-
Service(#[pin] F),
367+
Service(#[pin] F, Option<ConnectParts>),
335368
Body(#[pin] PipeToSendStream<B>),
336369
}
337370

371+
struct ConnectParts {
372+
pending: Pending,
373+
ping: Recorder,
374+
recv_stream: RecvStream,
375+
}
376+
338377
impl<F, B> H2Stream<F, B>
339378
where
340379
B: HttpBody,
341380
{
342-
fn new(fut: F, respond: SendResponse<SendBuf<B::Data>>) -> H2Stream<F, B> {
381+
fn new(
382+
fut: F,
383+
connect_parts: Option<ConnectParts>,
384+
respond: SendResponse<SendBuf<B::Data>>,
385+
) -> H2Stream<F, B> {
343386
H2Stream {
344387
reply: respond,
345-
state: H2StreamState::Service(fut),
388+
state: H2StreamState::Service(fut, connect_parts),
346389
}
347390
}
348391
}
@@ -364,14 +407,15 @@ impl<F, B, E> H2Stream<F, B>
364407
where
365408
F: Future<Output = Result<Response<B>, E>>,
366409
B: HttpBody,
410+
B::Data: 'static,
367411
B::Error: Into<Box<dyn StdError + Send + Sync>>,
368412
E: Into<Box<dyn StdError + Send + Sync>>,
369413
{
370414
fn poll2(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
371415
let mut me = self.project();
372416
loop {
373417
let next = match me.state.as_mut().project() {
374-
H2StreamStateProj::Service(h) => {
418+
H2StreamStateProj::Service(h, connect_parts) => {
375419
let res = match h.poll(cx) {
376420
Poll::Ready(Ok(r)) => r,
377421
Poll::Pending => {
@@ -402,6 +446,29 @@ where
402446
.entry(::http::header::DATE)
403447
.or_insert_with(date::update_and_header_value);
404448

449+
if let Some(connect_parts) = connect_parts.take() {
450+
if res.status().is_success() {
451+
if headers::content_length_parse_all(res.headers())
452+
.map_or(false, |len| len != 0)
453+
{
454+
warn!("h2 successful response to CONNECT request with body not supported");
455+
me.reply.send_reset(h2::Reason::INTERNAL_ERROR);
456+
return Poll::Ready(Err(crate::Error::new_user_header()));
457+
}
458+
let send_stream = reply!(me, res, false);
459+
connect_parts.pending.fulfill(Upgraded::new(
460+
H2Upgraded {
461+
ping: connect_parts.ping,
462+
recv_stream: connect_parts.recv_stream,
463+
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
464+
buf: Bytes::new(),
465+
},
466+
Bytes::new(),
467+
));
468+
return Poll::Ready(Ok(()));
469+
}
470+
}
471+
405472
// automatically set Content-Length from body...
406473
if let Some(len) = body.size_hint().exact() {
407474
headers::set_content_length_if_missing(res.headers_mut(), len);
@@ -428,6 +495,7 @@ impl<F, B, E> Future for H2Stream<F, B>
428495
where
429496
F: Future<Output = Result<Response<B>, E>>,
430497
B: HttpBody,
498+
B::Data: 'static,
431499
B::Error: Into<Box<dyn StdError + Send + Sync>>,
432500
E: Into<Box<dyn StdError + Send + Sync>>,
433501
{

‎src/upgrade.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
6262
msg.on_upgrade()
6363
}
6464

65-
#[cfg(feature = "http1")]
65+
#[cfg(any(feature = "http1", feature = "http2"))]
6666
pub(super) struct Pending {
6767
tx: oneshot::Sender<crate::Result<Upgraded>>,
6868
}
6969

70-
#[cfg(feature = "http1")]
70+
#[cfg(any(feature = "http1", feature = "http2"))]
7171
pub(super) fn pending() -> (Pending, OnUpgrade) {
7272
let (tx, rx) = oneshot::channel();
7373
(Pending { tx }, OnUpgrade { rx: Some(rx) })
@@ -76,7 +76,7 @@ pub(super) fn pending() -> (Pending, OnUpgrade) {
7676
// ===== impl Upgraded =====
7777

7878
impl Upgraded {
79-
#[cfg(any(feature = "http1", test))]
79+
#[cfg(any(feature = "http1", feature = "http2", test))]
8080
pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
8181
where
8282
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
@@ -187,13 +187,14 @@ impl fmt::Debug for OnUpgrade {
187187

188188
// ===== impl Pending =====
189189

190-
#[cfg(feature = "http1")]
190+
#[cfg(any(feature = "http1", feature = "http2"))]
191191
impl Pending {
192192
pub(super) fn fulfill(self, upgraded: Upgraded) {
193193
trace!("pending upgrade fulfill");
194194
let _ = self.tx.send(Ok(upgraded));
195195
}
196196

197+
#[cfg(feature = "http1")]
197198
/// Don't fulfill the pending Upgrade, but instead signal that
198199
/// upgrades are handled manually.
199200
pub(super) fn manual(self) {

‎tests/client.rs

+122-1
Original file line numberDiff line numberDiff line change
@@ -2261,14 +2261,16 @@ mod conn {
22612261
use std::thread;
22622262
use std::time::Duration;
22632263

2264+
use bytes::Buf;
22642265
use futures_channel::oneshot;
22652266
use futures_util::future::{self, poll_fn, FutureExt, TryFutureExt};
22662267
use futures_util::StreamExt;
2268+
use hyper::upgrade::OnUpgrade;
22672269
use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, ReadBuf};
22682270
use tokio::net::{TcpListener as TkTcpListener, TcpStream};
22692271

22702272
use hyper::client::conn;
2271-
use hyper::{self, Body, Method, Request};
2273+
use hyper::{self, Body, Method, Request, Response, StatusCode};
22722274

22732275
use super::{concat, s, support, tcp_connect, FutureHyperExt};
22742276

@@ -2984,6 +2986,125 @@ mod conn {
29842986
.expect("client should be open");
29852987
}
29862988

2989+
#[tokio::test]
2990+
async fn h2_connect() {
2991+
let _ = pretty_env_logger::try_init();
2992+
2993+
let listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
2994+
.await
2995+
.unwrap();
2996+
let addr = listener.local_addr().unwrap();
2997+
2998+
// Spawn an HTTP2 server that asks for bread and responds with baguette.
2999+
tokio::spawn(async move {
3000+
let sock = listener.accept().await.unwrap().0;
3001+
let mut h2 = h2::server::handshake(sock).await.unwrap();
3002+
3003+
let (req, mut respond) = h2.accept().await.unwrap().unwrap();
3004+
tokio::spawn(async move {
3005+
poll_fn(|cx| h2.poll_closed(cx)).await.unwrap();
3006+
});
3007+
assert_eq!(req.method(), Method::CONNECT);
3008+
3009+
let mut body = req.into_body();
3010+
3011+
let mut send_stream = respond.send_response(Response::default(), false).unwrap();
3012+
3013+
send_stream.send_data("Bread?".into(), true).unwrap();
3014+
3015+
let bytes = body.data().await.unwrap().unwrap();
3016+
assert_eq!(&bytes[..], b"Baguette!");
3017+
let _ = body.flow_control().release_capacity(bytes.len());
3018+
3019+
assert!(body.data().await.is_none());
3020+
});
3021+
3022+
let io = tcp_connect(&addr).await.expect("tcp connect");
3023+
let (mut client, conn) = conn::Builder::new()
3024+
.http2_only(true)
3025+
.handshake::<_, Body>(io)
3026+
.await
3027+
.expect("http handshake");
3028+
3029+
tokio::spawn(async move {
3030+
conn.await.expect("client conn shouldn't error");
3031+
});
3032+
3033+
let req = Request::connect("localhost")
3034+
.body(hyper::Body::empty())
3035+
.unwrap();
3036+
let res = client.send_request(req).await.expect("send_request");
3037+
assert_eq!(res.status(), StatusCode::OK);
3038+
3039+
let mut upgraded = hyper::upgrade::on(res).await.unwrap();
3040+
3041+
let mut vec = vec![];
3042+
upgraded.read_to_end(&mut vec).await.unwrap();
3043+
assert_eq!(s(&vec), "Bread?");
3044+
3045+
upgraded.write_all(b"Baguette!").await.unwrap();
3046+
3047+
upgraded.shutdown().await.unwrap();
3048+
}
3049+
3050+
#[tokio::test]
3051+
async fn h2_connect_rejected() {
3052+
let _ = pretty_env_logger::try_init();
3053+
3054+
let listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
3055+
.await
3056+
.unwrap();
3057+
let addr = listener.local_addr().unwrap();
3058+
let (done_tx, done_rx) = oneshot::channel();
3059+
3060+
tokio::spawn(async move {
3061+
let sock = listener.accept().await.unwrap().0;
3062+
let mut h2 = h2::server::handshake(sock).await.unwrap();
3063+
3064+
let (req, mut respond) = h2.accept().await.unwrap().unwrap();
3065+
tokio::spawn(async move {
3066+
poll_fn(|cx| h2.poll_closed(cx)).await.unwrap();
3067+
});
3068+
assert_eq!(req.method(), Method::CONNECT);
3069+
3070+
let res = Response::builder().status(400).body(()).unwrap();
3071+
let mut send_stream = respond.send_response(res, false).unwrap();
3072+
send_stream
3073+
.send_data("No bread for you!".into(), true)
3074+
.unwrap();
3075+
done_rx.await.unwrap();
3076+
});
3077+
3078+
let io = tcp_connect(&addr).await.expect("tcp connect");
3079+
let (mut client, conn) = conn::Builder::new()
3080+
.http2_only(true)
3081+
.handshake::<_, Body>(io)
3082+
.await
3083+
.expect("http handshake");
3084+
3085+
tokio::spawn(async move {
3086+
conn.await.expect("client conn shouldn't error");
3087+
});
3088+
3089+
let req = Request::connect("localhost")
3090+
.body(hyper::Body::empty())
3091+
.unwrap();
3092+
let res = client.send_request(req).await.expect("send_request");
3093+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
3094+
assert!(res.extensions().get::<OnUpgrade>().is_none());
3095+
3096+
let mut body = String::new();
3097+
hyper::body::aggregate(res.into_body())
3098+
.await
3099+
.unwrap()
3100+
.reader()
3101+
.read_to_string(&mut body)
3102+
.unwrap();
3103+
assert_eq!(body, "No bread for you!");
3104+
3105+
done_tx.send(()).unwrap();
3106+
}
3107+
29873108
async fn drain_til_eof<T: AsyncRead + Unpin>(mut sock: T) -> io::Result<()> {
29883109
let mut buf = [0u8; 1024];
29893110
loop {

‎tests/server.rs

+336
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@ use std::task::{Context, Poll};
1313
use std::thread;
1414
use std::time::Duration;
1515

16+
use bytes::Bytes;
1617
use futures_channel::oneshot;
1718
use futures_util::future::{self, Either, FutureExt, TryFutureExt};
1819
#[cfg(feature = "stream")]
1920
use futures_util::stream::StreamExt as _;
21+
use h2::client::SendRequest;
22+
use h2::{RecvStream, SendStream};
2023
use http::header::{HeaderName, HeaderValue};
2124
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
2225
use tokio::net::{TcpListener, TcpStream as TkTcpStream};
@@ -1482,6 +1485,339 @@ async fn http_connect_new() {
14821485
assert_eq!(s(&vec), "bar=foo");
14831486
}
14841487

1488+
#[tokio::test]
1489+
async fn h2_connect() {
1490+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
1491+
1492+
let _ = pretty_env_logger::try_init();
1493+
let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
1494+
let addr = listener.local_addr().unwrap();
1495+
let conn = connect_async(addr).await;
1496+
1497+
let (h2, connection) = h2::client::handshake(conn).await.unwrap();
1498+
tokio::spawn(async move {
1499+
connection.await.unwrap();
1500+
});
1501+
let mut h2 = h2.ready().await.unwrap();
1502+
1503+
async fn connect_and_recv_bread(
1504+
h2: &mut SendRequest<Bytes>,
1505+
) -> (RecvStream, SendStream<Bytes>) {
1506+
let request = Request::connect("localhost").body(()).unwrap();
1507+
let (response, send_stream) = h2.send_request(request, false).unwrap();
1508+
let response = response.await.unwrap();
1509+
assert_eq!(response.status(), StatusCode::OK);
1510+
1511+
let mut body = response.into_body();
1512+
let bytes = body.data().await.unwrap().unwrap();
1513+
assert_eq!(&bytes[..], b"Bread?");
1514+
let _ = body.flow_control().release_capacity(bytes.len());
1515+
1516+
(body, send_stream)
1517+
}
1518+
1519+
tokio::spawn(async move {
1520+
let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await;
1521+
1522+
send_stream.send_data("Baguette!".into(), true).unwrap();
1523+
1524+
assert!(recv_stream.data().await.unwrap().unwrap().is_empty());
1525+
});
1526+
1527+
let svc = service_fn(move |req: Request<Body>| {
1528+
let on_upgrade = hyper::upgrade::on(req);
1529+
1530+
tokio::spawn(async move {
1531+
let mut upgraded = on_upgrade.await.expect("on_upgrade");
1532+
upgraded.write_all(b"Bread?").await.unwrap();
1533+
1534+
let mut vec = vec![];
1535+
upgraded.read_to_end(&mut vec).await.unwrap();
1536+
assert_eq!(s(&vec), "Baguette!");
1537+
1538+
upgraded.shutdown().await.unwrap();
1539+
});
1540+
1541+
future::ok::<_, hyper::Error>(
1542+
Response::builder()
1543+
.status(200)
1544+
.body(hyper::Body::empty())
1545+
.unwrap(),
1546+
)
1547+
});
1548+
1549+
let (socket, _) = listener.accept().await.unwrap();
1550+
Http::new()
1551+
.http2_only(true)
1552+
.serve_connection(socket, svc)
1553+
.with_upgrades()
1554+
.await
1555+
.unwrap();
1556+
}
1557+
1558+
#[tokio::test]
1559+
async fn h2_connect_multiplex() {
1560+
use futures_util::stream::FuturesUnordered;
1561+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
1562+
1563+
let _ = pretty_env_logger::try_init();
1564+
let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
1565+
let addr = listener.local_addr().unwrap();
1566+
let conn = connect_async(addr).await;
1567+
1568+
let (h2, connection) = h2::client::handshake(conn).await.unwrap();
1569+
tokio::spawn(async move {
1570+
connection.await.unwrap();
1571+
});
1572+
let mut h2 = h2.ready().await.unwrap();
1573+
1574+
tokio::spawn(async move {
1575+
let mut streams = vec![];
1576+
for i in 0..80 {
1577+
let request = Request::connect(format!("localhost_{}", i % 4))
1578+
.body(())
1579+
.unwrap();
1580+
let (response, send_stream) = h2.send_request(request, false).unwrap();
1581+
streams.push((i, response, send_stream));
1582+
}
1583+
1584+
let futures = streams
1585+
.into_iter()
1586+
.map(|(i, response, mut send_stream)| async move {
1587+
if i % 4 == 0 {
1588+
return;
1589+
}
1590+
1591+
let response = response.await.unwrap();
1592+
assert_eq!(response.status(), StatusCode::OK);
1593+
1594+
if i % 4 == 1 {
1595+
return;
1596+
}
1597+
1598+
let mut body = response.into_body();
1599+
let bytes = body.data().await.unwrap().unwrap();
1600+
assert_eq!(&bytes[..], b"Bread?");
1601+
let _ = body.flow_control().release_capacity(bytes.len());
1602+
1603+
if i % 4 == 2 {
1604+
return;
1605+
}
1606+
1607+
send_stream.send_data("Baguette!".into(), true).unwrap();
1608+
1609+
assert!(body.data().await.unwrap().unwrap().is_empty());
1610+
})
1611+
.collect::<FuturesUnordered<_>>();
1612+
1613+
futures.for_each(future::ready).await;
1614+
});
1615+
1616+
let svc = service_fn(move |req: Request<Body>| {
1617+
let authority = req.uri().authority().unwrap().to_string();
1618+
let on_upgrade = hyper::upgrade::on(req);
1619+
1620+
tokio::spawn(async move {
1621+
let upgrade_res = on_upgrade.await;
1622+
if authority == "localhost_0" {
1623+
assert!(upgrade_res.expect_err("upgrade cancelled").is_canceled());
1624+
return;
1625+
}
1626+
let mut upgraded = upgrade_res.expect("upgrade successful");
1627+
1628+
upgraded.write_all(b"Bread?").await.unwrap();
1629+
1630+
let mut vec = vec![];
1631+
let read_res = upgraded.read_to_end(&mut vec).await;
1632+
1633+
if authority == "localhost_1" || authority == "localhost_2" {
1634+
let err = read_res.expect_err("read failed");
1635+
assert_eq!(err.kind(), io::ErrorKind::Other);
1636+
assert_eq!(
1637+
err.get_ref()
1638+
.unwrap()
1639+
.downcast_ref::<h2::Error>()
1640+
.unwrap()
1641+
.reason(),
1642+
Some(h2::Reason::CANCEL),
1643+
);
1644+
return;
1645+
}
1646+
1647+
read_res.unwrap();
1648+
assert_eq!(s(&vec), "Baguette!");
1649+
1650+
upgraded.shutdown().await.unwrap();
1651+
});
1652+
1653+
future::ok::<_, hyper::Error>(
1654+
Response::builder()
1655+
.status(200)
1656+
.body(hyper::Body::empty())
1657+
.unwrap(),
1658+
)
1659+
});
1660+
1661+
let (socket, _) = listener.accept().await.unwrap();
1662+
Http::new()
1663+
.http2_only(true)
1664+
.serve_connection(socket, svc)
1665+
.with_upgrades()
1666+
.await
1667+
.unwrap();
1668+
}
1669+
1670+
#[tokio::test]
1671+
async fn h2_connect_large_body() {
1672+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
1673+
1674+
let _ = pretty_env_logger::try_init();
1675+
let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
1676+
let addr = listener.local_addr().unwrap();
1677+
let conn = connect_async(addr).await;
1678+
1679+
let (h2, connection) = h2::client::handshake(conn).await.unwrap();
1680+
tokio::spawn(async move {
1681+
connection.await.unwrap();
1682+
});
1683+
let mut h2 = h2.ready().await.unwrap();
1684+
1685+
const NO_BREAD: &str = "All work and no bread makes nox a dull boy.\n";
1686+
1687+
async fn connect_and_recv_bread(
1688+
h2: &mut SendRequest<Bytes>,
1689+
) -> (RecvStream, SendStream<Bytes>) {
1690+
let request = Request::connect("localhost").body(()).unwrap();
1691+
let (response, send_stream) = h2.send_request(request, false).unwrap();
1692+
let response = response.await.unwrap();
1693+
assert_eq!(response.status(), StatusCode::OK);
1694+
1695+
let mut body = response.into_body();
1696+
let bytes = body.data().await.unwrap().unwrap();
1697+
assert_eq!(&bytes[..], b"Bread?");
1698+
let _ = body.flow_control().release_capacity(bytes.len());
1699+
1700+
(body, send_stream)
1701+
}
1702+
1703+
tokio::spawn(async move {
1704+
let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await;
1705+
1706+
let large_body = Bytes::from(NO_BREAD.repeat(9000));
1707+
1708+
send_stream.send_data(large_body.clone(), false).unwrap();
1709+
send_stream.send_data(large_body, true).unwrap();
1710+
1711+
assert!(recv_stream.data().await.unwrap().unwrap().is_empty());
1712+
});
1713+
1714+
let svc = service_fn(move |req: Request<Body>| {
1715+
let on_upgrade = hyper::upgrade::on(req);
1716+
1717+
tokio::spawn(async move {
1718+
let mut upgraded = on_upgrade.await.expect("on_upgrade");
1719+
upgraded.write_all(b"Bread?").await.unwrap();
1720+
1721+
let mut vec = vec![];
1722+
if upgraded.read_to_end(&mut vec).await.is_err() {
1723+
return;
1724+
}
1725+
assert_eq!(vec.len(), NO_BREAD.len() * 9000 * 2);
1726+
1727+
upgraded.shutdown().await.unwrap();
1728+
});
1729+
1730+
future::ok::<_, hyper::Error>(
1731+
Response::builder()
1732+
.status(200)
1733+
.body(hyper::Body::empty())
1734+
.unwrap(),
1735+
)
1736+
});
1737+
1738+
let (socket, _) = listener.accept().await.unwrap();
1739+
Http::new()
1740+
.http2_only(true)
1741+
.serve_connection(socket, svc)
1742+
.with_upgrades()
1743+
.await
1744+
.unwrap();
1745+
}
1746+
1747+
#[tokio::test]
1748+
async fn h2_connect_empty_frames() {
1749+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
1750+
1751+
let _ = pretty_env_logger::try_init();
1752+
let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
1753+
let addr = listener.local_addr().unwrap();
1754+
let conn = connect_async(addr).await;
1755+
1756+
let (h2, connection) = h2::client::handshake(conn).await.unwrap();
1757+
tokio::spawn(async move {
1758+
connection.await.unwrap();
1759+
});
1760+
let mut h2 = h2.ready().await.unwrap();
1761+
1762+
async fn connect_and_recv_bread(
1763+
h2: &mut SendRequest<Bytes>,
1764+
) -> (RecvStream, SendStream<Bytes>) {
1765+
let request = Request::connect("localhost").body(()).unwrap();
1766+
let (response, send_stream) = h2.send_request(request, false).unwrap();
1767+
let response = response.await.unwrap();
1768+
assert_eq!(response.status(), StatusCode::OK);
1769+
1770+
let mut body = response.into_body();
1771+
let bytes = body.data().await.unwrap().unwrap();
1772+
assert_eq!(&bytes[..], b"Bread?");
1773+
let _ = body.flow_control().release_capacity(bytes.len());
1774+
1775+
(body, send_stream)
1776+
}
1777+
1778+
tokio::spawn(async move {
1779+
let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await;
1780+
1781+
send_stream.send_data("".into(), false).unwrap();
1782+
send_stream.send_data("".into(), false).unwrap();
1783+
send_stream.send_data("".into(), false).unwrap();
1784+
send_stream.send_data("Baguette!".into(), false).unwrap();
1785+
send_stream.send_data("".into(), true).unwrap();
1786+
1787+
assert!(recv_stream.data().await.unwrap().unwrap().is_empty());
1788+
});
1789+
1790+
let svc = service_fn(move |req: Request<Body>| {
1791+
let on_upgrade = hyper::upgrade::on(req);
1792+
1793+
tokio::spawn(async move {
1794+
let mut upgraded = on_upgrade.await.expect("on_upgrade");
1795+
upgraded.write_all(b"Bread?").await.unwrap();
1796+
1797+
let mut vec = vec![];
1798+
upgraded.read_to_end(&mut vec).await.unwrap();
1799+
assert_eq!(s(&vec), "Baguette!");
1800+
1801+
upgraded.shutdown().await.unwrap();
1802+
});
1803+
1804+
future::ok::<_, hyper::Error>(
1805+
Response::builder()
1806+
.status(200)
1807+
.body(hyper::Body::empty())
1808+
.unwrap(),
1809+
)
1810+
});
1811+
1812+
let (socket, _) = listener.accept().await.unwrap();
1813+
Http::new()
1814+
.http2_only(true)
1815+
.serve_connection(socket, svc)
1816+
.with_upgrades()
1817+
.await
1818+
.unwrap();
1819+
}
1820+
14851821
#[tokio::test]
14861822
async fn parse_errors_send_4xx_response() {
14871823
let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();

0 commit comments

Comments
 (0)
Please sign in to comment.