Skip to content

Commit 31b4180

Browse files
authoredDec 15, 2023
feat(http1): Add support for sending HTTP/1.1 Chunked Trailer Fields (#3375)
Closes #2719
1 parent 0f2929b commit 31b4180

File tree

8 files changed

+611
-31
lines changed

8 files changed

+611
-31
lines changed
 

‎src/proto/h1/conn.rs

+36-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use std::time::Duration;
88

99
use crate::rt::{Read, Write};
1010
use bytes::{Buf, Bytes};
11-
use http::header::{HeaderValue, CONNECTION};
11+
use http::header::{HeaderValue, CONNECTION, TE};
1212
use http::{HeaderMap, Method, Version};
1313
use httparse::ParserConfig;
1414

@@ -75,6 +75,7 @@ where
7575
// We assume a modern world where the remote speaks HTTP/1.1.
7676
// If they tell us otherwise, we'll downgrade in `read_head`.
7777
version: Version::HTTP_11,
78+
allow_trailer_fields: false,
7879
},
7980
_marker: PhantomData,
8081
}
@@ -264,6 +265,13 @@ where
264265
self.state.reading = Reading::Body(Decoder::new(msg.decode));
265266
}
266267

268+
self.state.allow_trailer_fields = msg
269+
.head
270+
.headers
271+
.get(TE)
272+
.map(|te_header| te_header == "trailers")
273+
.unwrap_or(false);
274+
267275
Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
268276
}
269277

@@ -640,6 +648,31 @@ where
640648
self.state.writing = state;
641649
}
642650

651+
pub(crate) fn write_trailers(&mut self, trailers: HeaderMap) {
652+
if T::is_server() && self.state.allow_trailer_fields == false {
653+
debug!("trailers not allowed to be sent");
654+
return;
655+
}
656+
debug_assert!(self.can_write_body() && self.can_buffer_body());
657+
658+
match self.state.writing {
659+
Writing::Body(ref encoder) => {
660+
if let Some(enc_buf) =
661+
encoder.encode_trailers(trailers, self.state.title_case_headers)
662+
{
663+
self.io.buffer(enc_buf);
664+
665+
self.state.writing = if encoder.is_last() || encoder.is_close_delimited() {
666+
Writing::Closed
667+
} else {
668+
Writing::KeepAlive
669+
};
670+
}
671+
}
672+
_ => unreachable!("write_trailers invalid state: {:?}", self.state.writing),
673+
}
674+
}
675+
643676
pub(crate) fn write_body_and_end(&mut self, chunk: B) {
644677
debug_assert!(self.can_write_body() && self.can_buffer_body());
645678
// empty chunks should be discarded at Dispatcher level
@@ -842,6 +875,8 @@ struct State {
842875
upgrade: Option<crate::upgrade::Pending>,
843876
/// Either HTTP/1.0 or 1.1 connection
844877
version: Version,
878+
/// Flag to track if trailer fields are allowed to be sent
879+
allow_trailer_fields: bool,
845880
}
846881

847882
#[derive(Debug)]

‎src/proto/h1/dispatch.rs

+24-18
Original file line numberDiff line numberDiff line change
@@ -351,27 +351,33 @@ where
351351
*clear_body = true;
352352
crate::Error::new_user_body(e)
353353
})?;
354-
let chunk = if let Ok(data) = frame.into_data() {
355-
data
356-
} else {
357-
trace!("discarding non-data frame");
358-
continue;
359-
};
360-
let eos = body.is_end_stream();
361-
if eos {
362-
*clear_body = true;
363-
if chunk.remaining() == 0 {
364-
trace!("discarding empty chunk");
365-
self.conn.end_body()?;
354+
355+
if frame.is_data() {
356+
let chunk = frame.into_data().unwrap_or_else(|_| unreachable!());
357+
let eos = body.is_end_stream();
358+
if eos {
359+
*clear_body = true;
360+
if chunk.remaining() == 0 {
361+
trace!("discarding empty chunk");
362+
self.conn.end_body()?;
363+
} else {
364+
self.conn.write_body_and_end(chunk);
365+
}
366366
} else {
367-
self.conn.write_body_and_end(chunk);
367+
if chunk.remaining() == 0 {
368+
trace!("discarding empty chunk");
369+
continue;
370+
}
371+
self.conn.write_body(chunk);
368372
}
373+
} else if frame.is_trailers() {
374+
*clear_body = true;
375+
self.conn.write_trailers(
376+
frame.into_trailers().unwrap_or_else(|_| unreachable!()),
377+
);
369378
} else {
370-
if chunk.remaining() == 0 {
371-
trace!("discarding empty chunk");
372-
continue;
373-
}
374-
self.conn.write_body(chunk);
379+
trace!("discarding unknown frame");
380+
continue;
375381
}
376382
} else {
377383
*clear_body = true;

‎src/proto/h1/encode.rs

+268-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
use std::collections::HashMap;
12
use std::fmt;
23
use std::io::IoSlice;
34

45
use bytes::buf::{Chain, Take};
5-
use bytes::Buf;
6+
use bytes::{Buf, Bytes};
7+
use http::{
8+
header::{
9+
AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
10+
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
11+
},
12+
HeaderMap, HeaderName, HeaderValue,
13+
};
614

715
use super::io::WriteBuf;
16+
use super::role::{write_headers, write_headers_title_case};
817

918
type StaticBuf = &'static [u8];
1019

@@ -26,7 +35,7 @@ pub(crate) struct NotEof(u64);
2635
#[derive(Debug, PartialEq, Clone)]
2736
enum Kind {
2837
/// An Encoder for when Transfer-Encoding includes `chunked`.
29-
Chunked,
38+
Chunked(Option<Vec<HeaderValue>>),
3039
/// An Encoder for when Content-Length is set.
3140
///
3241
/// Enforces that the body is not longer than the Content-Length header.
@@ -45,6 +54,7 @@ enum BufKind<B> {
4554
Limited(Take<B>),
4655
Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
4756
ChunkedEnd(StaticBuf),
57+
Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
4858
}
4959

5060
impl Encoder {
@@ -55,7 +65,7 @@ impl Encoder {
5565
}
5666
}
5767
pub(crate) fn chunked() -> Encoder {
58-
Encoder::new(Kind::Chunked)
68+
Encoder::new(Kind::Chunked(None))
5969
}
6070

6171
pub(crate) fn length(len: u64) -> Encoder {
@@ -67,6 +77,16 @@ impl Encoder {
6777
Encoder::new(Kind::CloseDelimited)
6878
}
6979

80+
pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderValue>) -> Encoder {
81+
match self.kind {
82+
Kind::Chunked(_) => Encoder {
83+
kind: Kind::Chunked(Some(trailers)),
84+
is_last: self.is_last,
85+
},
86+
_ => self,
87+
}
88+
}
89+
7090
pub(crate) fn is_eof(&self) -> bool {
7191
matches!(self.kind, Kind::Length(0))
7292
}
@@ -89,10 +109,17 @@ impl Encoder {
89109
}
90110
}
91111

112+
pub(crate) fn is_chunked(&self) -> bool {
113+
match self.kind {
114+
Kind::Chunked(_) => true,
115+
_ => false,
116+
}
117+
}
118+
92119
pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
93120
match self.kind {
94121
Kind::Length(0) => Ok(None),
95-
Kind::Chunked => Ok(Some(EncodedBuf {
122+
Kind::Chunked(_) => Ok(Some(EncodedBuf {
96123
kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
97124
})),
98125
#[cfg(feature = "server")]
@@ -109,7 +136,7 @@ impl Encoder {
109136
debug_assert!(len > 0, "encode() called with empty buf");
110137

111138
let kind = match self.kind {
112-
Kind::Chunked => {
139+
Kind::Chunked(_) => {
113140
trace!("encoding chunked {}B", len);
114141
let buf = ChunkSize::new(len)
115142
.chain(msg)
@@ -136,6 +163,53 @@ impl Encoder {
136163
EncodedBuf { kind }
137164
}
138165

166+
pub(crate) fn encode_trailers<B>(
167+
&self,
168+
trailers: HeaderMap,
169+
title_case_headers: bool,
170+
) -> Option<EncodedBuf<B>> {
171+
match &self.kind {
172+
Kind::Chunked(Some(ref allowed_trailer_fields)) => {
173+
let allowed_trailer_field_map = allowed_trailer_field_map(&allowed_trailer_fields);
174+
175+
let mut cur_name = None;
176+
let mut allowed_trailers = HeaderMap::new();
177+
178+
for (opt_name, value) in trailers {
179+
if let Some(n) = opt_name {
180+
cur_name = Some(n);
181+
}
182+
let name = cur_name.as_ref().expect("current header name");
183+
184+
if allowed_trailer_field_map.contains_key(name.as_str())
185+
&& valid_trailer_field(name)
186+
{
187+
allowed_trailers.insert(name, value);
188+
}
189+
}
190+
191+
let mut buf = Vec::new();
192+
if title_case_headers {
193+
write_headers_title_case(&allowed_trailers, &mut buf);
194+
} else {
195+
write_headers(&allowed_trailers, &mut buf);
196+
}
197+
198+
if buf.is_empty() {
199+
return None;
200+
}
201+
202+
Some(EncodedBuf {
203+
kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
204+
})
205+
}
206+
_ => {
207+
debug!("attempted to encode trailers for non-chunked response");
208+
None
209+
}
210+
}
211+
}
212+
139213
pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
140214
where
141215
B: Buf,
@@ -144,7 +218,7 @@ impl Encoder {
144218
debug_assert!(len > 0, "encode() called with empty buf");
145219

146220
match self.kind {
147-
Kind::Chunked => {
221+
Kind::Chunked(_) => {
148222
trace!("encoding chunked {}B", len);
149223
let buf = ChunkSize::new(len)
150224
.chain(msg)
@@ -181,6 +255,40 @@ impl Encoder {
181255
}
182256
}
183257

258+
fn valid_trailer_field(name: &HeaderName) -> bool {
259+
match name {
260+
&AUTHORIZATION => false,
261+
&CACHE_CONTROL => false,
262+
&CONTENT_ENCODING => false,
263+
&CONTENT_LENGTH => false,
264+
&CONTENT_RANGE => false,
265+
&CONTENT_TYPE => false,
266+
&HOST => false,
267+
&MAX_FORWARDS => false,
268+
&SET_COOKIE => false,
269+
&TRAILER => false,
270+
&TRANSFER_ENCODING => false,
271+
&TE => false,
272+
_ => true,
273+
}
274+
}
275+
276+
fn allowed_trailer_field_map(allowed_trailer_fields: &Vec<HeaderValue>) -> HashMap<String, ()> {
277+
let mut trailer_map = HashMap::new();
278+
279+
for header_value in allowed_trailer_fields {
280+
if let Ok(header_str) = header_value.to_str() {
281+
let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect();
282+
283+
for item in items {
284+
trailer_map.entry(item.to_string()).or_insert(());
285+
}
286+
}
287+
}
288+
289+
trailer_map
290+
}
291+
184292
impl<B> Buf for EncodedBuf<B>
185293
where
186294
B: Buf,
@@ -192,6 +300,7 @@ where
192300
BufKind::Limited(ref b) => b.remaining(),
193301
BufKind::Chunked(ref b) => b.remaining(),
194302
BufKind::ChunkedEnd(ref b) => b.remaining(),
303+
BufKind::Trailers(ref b) => b.remaining(),
195304
}
196305
}
197306

@@ -202,6 +311,7 @@ where
202311
BufKind::Limited(ref b) => b.chunk(),
203312
BufKind::Chunked(ref b) => b.chunk(),
204313
BufKind::ChunkedEnd(ref b) => b.chunk(),
314+
BufKind::Trailers(ref b) => b.chunk(),
205315
}
206316
}
207317

@@ -212,6 +322,7 @@ where
212322
BufKind::Limited(ref mut b) => b.advance(cnt),
213323
BufKind::Chunked(ref mut b) => b.advance(cnt),
214324
BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
325+
BufKind::Trailers(ref mut b) => b.advance(cnt),
215326
}
216327
}
217328

@@ -222,6 +333,7 @@ where
222333
BufKind::Limited(ref b) => b.chunks_vectored(dst),
223334
BufKind::Chunked(ref b) => b.chunks_vectored(dst),
224335
BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
336+
BufKind::Trailers(ref b) => b.chunks_vectored(dst),
225337
}
226338
}
227339
}
@@ -327,7 +439,16 @@ impl std::error::Error for NotEof {}
327439

328440
#[cfg(test)]
329441
mod tests {
442+
use std::iter::FromIterator;
443+
330444
use bytes::BufMut;
445+
use http::{
446+
header::{
447+
AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
448+
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
449+
},
450+
HeaderMap, HeaderName, HeaderValue,
451+
};
331452

332453
use super::super::io::Cursor;
333454
use super::Encoder;
@@ -402,4 +523,145 @@ mod tests {
402523
assert!(!encoder.is_eof());
403524
encoder.end::<()>().unwrap();
404525
}
526+
527+
#[test]
528+
fn chunked_with_valid_trailers() {
529+
let encoder = Encoder::chunked();
530+
let trailers = vec![HeaderValue::from_static("chunky-trailer")];
531+
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
532+
533+
let headers = HeaderMap::from_iter(
534+
vec![
535+
(
536+
HeaderName::from_static("chunky-trailer"),
537+
HeaderValue::from_static("header data"),
538+
),
539+
(
540+
HeaderName::from_static("should-not-be-included"),
541+
HeaderValue::from_static("oops"),
542+
),
543+
]
544+
.into_iter(),
545+
);
546+
547+
let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();
548+
549+
let mut dst = Vec::new();
550+
dst.put(buf1);
551+
assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n");
552+
}
553+
554+
#[test]
555+
fn chunked_with_multiple_trailer_headers() {
556+
let encoder = Encoder::chunked();
557+
let trailers = vec![
558+
HeaderValue::from_static("chunky-trailer"),
559+
HeaderValue::from_static("chunky-trailer-2"),
560+
];
561+
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
562+
563+
let headers = HeaderMap::from_iter(
564+
vec![
565+
(
566+
HeaderName::from_static("chunky-trailer"),
567+
HeaderValue::from_static("header data"),
568+
),
569+
(
570+
HeaderName::from_static("chunky-trailer-2"),
571+
HeaderValue::from_static("more header data"),
572+
),
573+
]
574+
.into_iter(),
575+
);
576+
577+
let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();
578+
579+
let mut dst = Vec::new();
580+
dst.put(buf1);
581+
assert_eq!(
582+
dst,
583+
b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n"
584+
);
585+
}
586+
587+
#[test]
588+
fn chunked_with_no_trailer_header() {
589+
let encoder = Encoder::chunked();
590+
591+
let headers = HeaderMap::from_iter(
592+
vec![(
593+
HeaderName::from_static("chunky-trailer"),
594+
HeaderValue::from_static("header data"),
595+
)]
596+
.into_iter(),
597+
);
598+
599+
assert!(encoder
600+
.encode_trailers::<&[u8]>(headers.clone(), false)
601+
.is_none());
602+
603+
let trailers = vec![];
604+
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
605+
606+
assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none());
607+
}
608+
609+
#[test]
610+
fn chunked_with_invalid_trailers() {
611+
let encoder = Encoder::chunked();
612+
613+
let trailers = format!(
614+
"{},{},{},{},{},{},{},{},{},{},{},{}",
615+
AUTHORIZATION,
616+
CACHE_CONTROL,
617+
CONTENT_ENCODING,
618+
CONTENT_LENGTH,
619+
CONTENT_RANGE,
620+
CONTENT_TYPE,
621+
HOST,
622+
MAX_FORWARDS,
623+
SET_COOKIE,
624+
TRAILER,
625+
TRANSFER_ENCODING,
626+
TE,
627+
);
628+
let trailers = vec![HeaderValue::from_str(&trailers).unwrap()];
629+
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
630+
631+
let mut headers = HeaderMap::new();
632+
headers.insert(AUTHORIZATION, HeaderValue::from_static("header data"));
633+
headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data"));
634+
headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data"));
635+
headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data"));
636+
headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data"));
637+
headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data"));
638+
headers.insert(HOST, HeaderValue::from_static("header data"));
639+
headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data"));
640+
headers.insert(SET_COOKIE, HeaderValue::from_static("header data"));
641+
headers.insert(TRAILER, HeaderValue::from_static("header data"));
642+
headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data"));
643+
headers.insert(TE, HeaderValue::from_static("header data"));
644+
645+
assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none());
646+
}
647+
648+
#[test]
649+
fn chunked_with_title_case_headers() {
650+
let encoder = Encoder::chunked();
651+
let trailers = vec![HeaderValue::from_static("chunky-trailer")];
652+
let encoder = encoder.into_chunked_with_trailing_fields(trailers);
653+
654+
let headers = HeaderMap::from_iter(
655+
vec![(
656+
HeaderName::from_static("chunky-trailer"),
657+
HeaderValue::from_static("header data"),
658+
)]
659+
.into_iter(),
660+
);
661+
let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap();
662+
663+
let mut dst = Vec::new();
664+
dst.put(buf1);
665+
assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n");
666+
}
405667
}

‎src/proto/h1/role.rs

+54-4
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@ impl Server {
625625
};
626626

627627
let mut encoder = Encoder::length(0);
628+
let mut allowed_trailer_fields: Option<Vec<HeaderValue>> = None;
628629
let mut wrote_date = false;
629630
let mut cur_name = None;
630631
let mut is_name_written = false;
@@ -811,6 +812,38 @@ impl Server {
811812
header::DATE => {
812813
wrote_date = true;
813814
}
815+
header::TRAILER => {
816+
// check that we actually can send a chunked body...
817+
if msg.head.version == Version::HTTP_10
818+
|| !Server::can_chunked(msg.req_method, msg.head.subject)
819+
{
820+
continue;
821+
}
822+
823+
if !is_name_written {
824+
is_name_written = true;
825+
header_name_writer.write_header_name_with_colon(
826+
dst,
827+
"trailer: ",
828+
header::TRAILER,
829+
);
830+
extend(dst, value.as_bytes());
831+
} else {
832+
extend(dst, b", ");
833+
extend(dst, value.as_bytes());
834+
}
835+
836+
match allowed_trailer_fields {
837+
Some(ref mut allowed_trailer_fields) => {
838+
allowed_trailer_fields.push(value);
839+
}
840+
None => {
841+
allowed_trailer_fields = Some(vec![value]);
842+
}
843+
}
844+
845+
continue 'headers;
846+
}
814847
_ => (),
815848
}
816849
//TODO: this should perhaps instead combine them into
@@ -895,6 +928,12 @@ impl Server {
895928
extend(dst, b"\r\n");
896929
}
897930

931+
if encoder.is_chunked() {
932+
if let Some(allowed_trailer_fields) = allowed_trailer_fields {
933+
encoder = encoder.into_chunked_with_trailing_fields(allowed_trailer_fields);
934+
}
935+
}
936+
898937
Ok(encoder.set_last(is_last))
899938
}
900939
}
@@ -1302,6 +1341,19 @@ impl Client {
13021341
}
13031342
};
13041343

1344+
let encoder = encoder.map(|enc| {
1345+
if enc.is_chunked() {
1346+
let allowed_trailer_fields: Vec<HeaderValue> =
1347+
headers.get_all(header::TRAILER).iter().cloned().collect();
1348+
1349+
if !allowed_trailer_fields.is_empty() {
1350+
return enc.into_chunked_with_trailing_fields(allowed_trailer_fields);
1351+
}
1352+
}
1353+
1354+
enc
1355+
});
1356+
13051357
// This is because we need a second mutable borrow to remove
13061358
// content-length header.
13071359
if let Some(encoder) = encoder {
@@ -1464,8 +1516,7 @@ fn title_case(dst: &mut Vec<u8>, name: &[u8]) {
14641516
}
14651517
}
14661518

1467-
#[cfg(feature = "client")]
1468-
fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) {
1519+
pub(crate) fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) {
14691520
for (name, value) in headers {
14701521
title_case(dst, name.as_str().as_bytes());
14711522
extend(dst, b": ");
@@ -1474,8 +1525,7 @@ fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec<u8>) {
14741525
}
14751526
}
14761527

1477-
#[cfg(feature = "client")]
1478-
fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) {
1528+
pub(crate) fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) {
14791529
for (name, value) in headers {
14801530
extend(dst, name.as_str().as_bytes());
14811531
extend(dst, b": ");

‎tests/client.rs

+49-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::convert::Infallible;
55
use std::fmt;
66
use std::future::Future;
77
use std::io::{Read, Write};
8+
use std::iter::FromIterator;
89
use std::net::{SocketAddr, TcpListener};
910
use std::pin::Pin;
1011
use std::thread;
@@ -13,7 +14,7 @@ use std::time::Duration;
1314
use http::uri::PathAndQuery;
1415
use http_body_util::{BodyExt, StreamBody};
1516
use hyper::body::Frame;
16-
use hyper::header::HeaderValue;
17+
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
1718
use hyper::{Method, Request, StatusCode, Uri, Version};
1819

1920
use bytes::Bytes;
@@ -409,6 +410,15 @@ macro_rules! __client_req_prop {
409410
Frame::data,
410411
)));
411412
}};
413+
414+
($req_builder:ident, $body:ident, $addr:ident, body_stream_with_trailers: $body_e:expr) => {{
415+
use support::trailers::StreamBodyWithTrailers;
416+
let (body, trailers) = $body_e;
417+
$body = BodyExt::boxed(StreamBodyWithTrailers::with_trailers(
418+
futures_util::TryStreamExt::map_ok(body, Frame::data),
419+
trailers,
420+
));
421+
}};
412422
}
413423

414424
macro_rules! __client_req_header {
@@ -632,6 +642,44 @@ test! {
632642
body: &b"hello"[..],
633643
}
634644

645+
test! {
646+
name: client_post_req_body_chunked_with_trailer,
647+
648+
server:
649+
expected: "\
650+
POST / HTTP/1.1\r\n\
651+
trailer: chunky-trailer\r\n\
652+
host: {addr}\r\n\
653+
transfer-encoding: chunked\r\n\
654+
\r\n\
655+
5\r\n\
656+
hello\r\n\
657+
0\r\n\
658+
chunky-trailer: header data\r\n\
659+
\r\n\
660+
",
661+
reply: REPLY_OK,
662+
663+
client:
664+
request: {
665+
method: POST,
666+
url: "http://{addr}/",
667+
headers: {
668+
"trailer" => "chunky-trailer",
669+
},
670+
body_stream_with_trailers: (
671+
(futures_util::stream::once(async { Ok::<_, Infallible>(Bytes::from("hello"))})),
672+
HeaderMap::from_iter(vec![(
673+
HeaderName::from_static("chunky-trailer"),
674+
HeaderValue::from_static("header data")
675+
)].into_iter())),
676+
},
677+
response:
678+
status: OK,
679+
headers: {},
680+
body: None,
681+
}
682+
635683
test! {
636684
name: client_get_req_body_sized,
637685

‎tests/server.rs

+102-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use futures_channel::oneshot;
1919
use futures_util::future::{self, Either, FutureExt};
2020
use h2::client::SendRequest;
2121
use h2::{RecvStream, SendStream};
22-
use http::header::{HeaderName, HeaderValue};
22+
use http::header::{HeaderMap, HeaderName, HeaderValue};
2323
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
2424
use hyper::rt::Timer;
2525
use hyper::rt::{Read as AsyncRead, Write as AsyncWrite};
@@ -2595,6 +2595,94 @@ async fn http2_keep_alive_count_server_pings() {
25952595
.expect("timed out waiting for pings");
25962596
}
25972597

2598+
#[test]
2599+
fn http1_trailer_fields() {
2600+
let body = futures_util::stream::once(async move { Ok("hello".into()) });
2601+
let mut headers = HeaderMap::new();
2602+
headers.insert("chunky-trailer", "header data".parse().unwrap());
2603+
// Invalid trailer field that should not be sent
2604+
headers.insert("Host", "www.example.com".parse().unwrap());
2605+
// Not specified in Trailer header, so should not be sent
2606+
headers.insert("foo", "bar".parse().unwrap());
2607+
2608+
let server = serve();
2609+
server
2610+
.reply()
2611+
.header("transfer-encoding", "chunked")
2612+
.header("trailer", "chunky-trailer")
2613+
.body_stream_with_trailers(body, headers);
2614+
let mut req = connect(server.addr());
2615+
req.write_all(
2616+
b"\
2617+
GET / HTTP/1.1\r\n\
2618+
Host: example.domain\r\n\
2619+
Connection: keep-alive\r\n\
2620+
TE: trailers\r\n\
2621+
\r\n\
2622+
",
2623+
)
2624+
.expect("writing");
2625+
2626+
let chunky_trailer_chunk = b"\r\nchunky-trailer: header data\r\n\r\n";
2627+
let res = read_until(&mut req, |buf| buf.ends_with(chunky_trailer_chunk)).expect("reading");
2628+
let sres = s(&res);
2629+
2630+
let expected_head =
2631+
"HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n";
2632+
assert_eq!(&sres[..expected_head.len()], expected_head);
2633+
2634+
// skip the date header
2635+
let date_fragment = "GMT\r\n\r\n";
2636+
let pos = sres.find(date_fragment).expect("find GMT");
2637+
let body = &sres[pos + date_fragment.len()..];
2638+
2639+
let expected_body = "5\r\nhello\r\n0\r\nchunky-trailer: header data\r\n\r\n";
2640+
assert_eq!(body, expected_body);
2641+
}
2642+
2643+
#[test]
2644+
fn http1_trailer_fields_not_allowed() {
2645+
let body = futures_util::stream::once(async move { Ok("hello".into()) });
2646+
let mut headers = HeaderMap::new();
2647+
headers.insert("chunky-trailer", "header data".parse().unwrap());
2648+
2649+
let server = serve();
2650+
server
2651+
.reply()
2652+
.header("transfer-encoding", "chunked")
2653+
.header("trailer", "chunky-trailer")
2654+
.body_stream_with_trailers(body, headers);
2655+
let mut req = connect(server.addr());
2656+
2657+
// TE: trailers is not specified in request headers
2658+
req.write_all(
2659+
b"\
2660+
GET / HTTP/1.1\r\n\
2661+
Host: example.domain\r\n\
2662+
Connection: keep-alive\r\n\
2663+
\r\n\
2664+
",
2665+
)
2666+
.expect("writing");
2667+
2668+
let last_chunk = b"\r\n0\r\n\r\n";
2669+
let res = read_until(&mut req, |buf| buf.ends_with(last_chunk)).expect("reading");
2670+
let sres = s(&res);
2671+
2672+
let expected_head =
2673+
"HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n";
2674+
assert_eq!(&sres[..expected_head.len()], expected_head);
2675+
2676+
// skip the date header
2677+
let date_fragment = "GMT\r\n\r\n";
2678+
let pos = sres.find(date_fragment).expect("find GMT");
2679+
let body = &sres[pos + date_fragment.len()..];
2680+
2681+
// no trailer fields should be sent because TE: trailers was not in request headers
2682+
let expected_body = "5\r\nhello\r\n0\r\n\r\n";
2683+
assert_eq!(body, expected_body);
2684+
}
2685+
25982686
// -------------------------------------------------
25992687
// the Server that is used to run all the tests with
26002688
// -------------------------------------------------
@@ -2700,6 +2788,19 @@ impl<'a> ReplyBuilder<'a> {
27002788
self.tx.lock().unwrap().send(Reply::Body(body)).unwrap();
27012789
}
27022790

2791+
fn body_stream_with_trailers<S>(self, stream: S, trailers: HeaderMap)
2792+
where
2793+
S: futures_util::Stream<Item = Result<Bytes, BoxError>> + Send + Sync + 'static,
2794+
{
2795+
use futures_util::TryStreamExt;
2796+
use hyper::body::Frame;
2797+
use support::trailers::StreamBodyWithTrailers;
2798+
let mut stream_body = StreamBodyWithTrailers::new(stream.map_ok(Frame::data));
2799+
stream_body.set_trailers(trailers);
2800+
let body = BodyExt::boxed(stream_body);
2801+
self.tx.lock().unwrap().send(Reply::Body(body)).unwrap();
2802+
}
2803+
27032804
#[allow(dead_code)]
27042805
fn error<E: Into<BoxError>>(self, err: E) {
27052806
self.tx

‎tests/support/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ mod tokiort;
2424
#[allow(unused)]
2525
pub use tokiort::{TokioExecutor, TokioIo, TokioTimer};
2626

27+
pub mod trailers;
28+
2729
#[allow(unused_macros)]
2830
macro_rules! t {
2931
(

‎tests/support/trailers.rs

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use bytes::Buf;
2+
use futures_util::stream::Stream;
3+
use http::header::HeaderMap;
4+
use http_body::{Body, Frame};
5+
use pin_project_lite::pin_project;
6+
use std::{
7+
pin::Pin,
8+
task::{Context, Poll},
9+
};
10+
11+
pin_project! {
12+
/// A body created from a [`Stream`].
13+
#[derive(Clone, Debug)]
14+
pub struct StreamBodyWithTrailers<S> {
15+
#[pin]
16+
stream: S,
17+
trailers: Option<HeaderMap>,
18+
}
19+
}
20+
21+
impl<S> StreamBodyWithTrailers<S> {
22+
/// Create a new `StreamBodyWithTrailers`.
23+
pub fn new(stream: S) -> Self {
24+
Self {
25+
stream,
26+
trailers: None,
27+
}
28+
}
29+
30+
pub fn with_trailers(stream: S, trailers: HeaderMap) -> Self {
31+
Self {
32+
stream,
33+
trailers: Some(trailers),
34+
}
35+
}
36+
37+
pub fn set_trailers(&mut self, trailers: HeaderMap) {
38+
self.trailers = Some(trailers);
39+
}
40+
}
41+
42+
impl<S, D, E> Body for StreamBodyWithTrailers<S>
43+
where
44+
S: Stream<Item = Result<Frame<D>, E>>,
45+
D: Buf,
46+
{
47+
type Data = D;
48+
type Error = E;
49+
50+
fn poll_frame(
51+
self: Pin<&mut Self>,
52+
cx: &mut Context<'_>,
53+
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
54+
let project = self.project();
55+
match project.stream.poll_next(cx) {
56+
Poll::Ready(Some(result)) => Poll::Ready(Some(result)),
57+
Poll::Ready(None) => match project.trailers.take() {
58+
Some(trailers) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
59+
None => Poll::Ready(None),
60+
},
61+
Poll::Pending => Poll::Pending,
62+
}
63+
}
64+
}
65+
66+
impl<S: Stream> Stream for StreamBodyWithTrailers<S> {
67+
type Item = S::Item;
68+
69+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
70+
self.project().stream.poll_next(cx)
71+
}
72+
73+
fn size_hint(&self) -> (usize, Option<usize>) {
74+
self.stream.size_hint()
75+
}
76+
}

0 commit comments

Comments
 (0)
Please sign in to comment.