@@ -47,6 +47,7 @@ enum Kind {
47
47
48
48
#[ derive( Debug , PartialEq , Clone , Copy ) ]
49
49
enum ChunkedState {
50
+ Start ,
50
51
Size ,
51
52
SizeLws ,
52
53
Extension ,
@@ -72,7 +73,7 @@ impl Decoder {
72
73
73
74
pub ( crate ) fn chunked ( ) -> Decoder {
74
75
Decoder {
75
- kind : Kind :: Chunked ( ChunkedState :: Size , 0 ) ,
76
+ kind : Kind :: Chunked ( ChunkedState :: new ( ) , 0 ) ,
76
77
}
77
78
}
78
79
@@ -180,7 +181,22 @@ macro_rules! byte (
180
181
} )
181
182
) ;
182
183
184
+ macro_rules! or_overflow {
185
+ ( $e: expr) => (
186
+ match $e {
187
+ Some ( val) => val,
188
+ None => return Poll :: Ready ( Err ( io:: Error :: new(
189
+ io:: ErrorKind :: InvalidData ,
190
+ "invalid chunk size: overflow" ,
191
+ ) ) ) ,
192
+ }
193
+ )
194
+ }
195
+
183
196
impl ChunkedState {
197
+ fn new ( ) -> ChunkedState {
198
+ ChunkedState :: Start
199
+ }
184
200
fn step < R : MemRead > (
185
201
& self ,
186
202
cx : & mut Context < ' _ > ,
@@ -190,6 +206,7 @@ impl ChunkedState {
190
206
) -> Poll < Result < ChunkedState , io:: Error > > {
191
207
use self :: ChunkedState :: * ;
192
208
match * self {
209
+ Start => ChunkedState :: read_start ( cx, body, size) ,
193
210
Size => ChunkedState :: read_size ( cx, body, size) ,
194
211
SizeLws => ChunkedState :: read_size_lws ( cx, body) ,
195
212
Extension => ChunkedState :: read_extension ( cx, body) ,
@@ -204,25 +221,46 @@ impl ChunkedState {
204
221
End => Poll :: Ready ( Ok ( ChunkedState :: End ) ) ,
205
222
}
206
223
}
207
- fn read_size < R : MemRead > (
224
+
225
+ fn read_start < R : MemRead > (
208
226
cx : & mut Context < ' _ > ,
209
227
rdr : & mut R ,
210
228
size : & mut u64 ,
211
229
) -> Poll < Result < ChunkedState , io:: Error > > {
212
- trace ! ( "Read chunk hex size " ) ;
230
+ trace ! ( "Read chunk start " ) ;
213
231
214
- macro_rules! or_overflow {
215
- ( $e: expr) => (
216
- match $e {
217
- Some ( val) => val,
218
- None => return Poll :: Ready ( Err ( io:: Error :: new(
219
- io:: ErrorKind :: InvalidData ,
220
- "invalid chunk size: overflow" ,
221
- ) ) ) ,
222
- }
223
- )
232
+ let radix = 16 ;
233
+ match byte ! ( rdr, cx) {
234
+ b @ b'0' ..=b'9' => {
235
+ * size = or_overflow ! ( size. checked_mul( radix) ) ;
236
+ * size = or_overflow ! ( size. checked_add( ( b - b'0' ) as u64 ) ) ;
237
+ }
238
+ b @ b'a' ..=b'f' => {
239
+ * size = or_overflow ! ( size. checked_mul( radix) ) ;
240
+ * size = or_overflow ! ( size. checked_add( ( b + 10 - b'a' ) as u64 ) ) ;
241
+ }
242
+ b @ b'A' ..=b'F' => {
243
+ * size = or_overflow ! ( size. checked_mul( radix) ) ;
244
+ * size = or_overflow ! ( size. checked_add( ( b + 10 - b'A' ) as u64 ) ) ;
245
+ }
246
+ _ => {
247
+ return Poll :: Ready ( Err ( io:: Error :: new (
248
+ io:: ErrorKind :: InvalidInput ,
249
+ "Invalid chunk size line: missing size digit" ,
250
+ ) ) ) ;
251
+ }
224
252
}
225
253
254
+ Poll :: Ready ( Ok ( ChunkedState :: Size ) )
255
+ }
256
+
257
+ fn read_size < R : MemRead > (
258
+ cx : & mut Context < ' _ > ,
259
+ rdr : & mut R ,
260
+ size : & mut u64 ,
261
+ ) -> Poll < Result < ChunkedState , io:: Error > > {
262
+ trace ! ( "Read chunk hex size" ) ;
263
+
226
264
let radix = 16 ;
227
265
match byte ! ( rdr, cx) {
228
266
b @ b'0' ..=b'9' => {
@@ -478,7 +516,7 @@ mod tests {
478
516
use std:: io:: ErrorKind :: { InvalidData , InvalidInput , UnexpectedEof } ;
479
517
480
518
async fn read ( s : & str ) -> u64 {
481
- let mut state = ChunkedState :: Size ;
519
+ let mut state = ChunkedState :: new ( ) ;
482
520
let rdr = & mut s. as_bytes ( ) ;
483
521
let mut size = 0 ;
484
522
loop {
@@ -495,7 +533,7 @@ mod tests {
495
533
}
496
534
497
535
async fn read_err ( s : & str , expected_err : io:: ErrorKind ) {
498
- let mut state = ChunkedState :: Size ;
536
+ let mut state = ChunkedState :: new ( ) ;
499
537
let rdr = & mut s. as_bytes ( ) ;
500
538
let mut size = 0 ;
501
539
loop {
@@ -532,6 +570,9 @@ mod tests {
532
570
// Missing LF or CRLF
533
571
read_err ( "F\r F" , InvalidInput ) . await ;
534
572
read_err ( "F" , UnexpectedEof ) . await ;
573
+ // Missing digit
574
+ read_err ( "\r \n \r \n " , InvalidInput ) . await ;
575
+ read_err ( "\r \n " , InvalidInput ) . await ;
535
576
// Invalid hex digit
536
577
read_err ( "X\r \n " , InvalidInput ) . await ;
537
578
read_err ( "1X\r \n " , InvalidInput ) . await ;
0 commit comments