Skip to content

Commit 3d1497a

Browse files
committedFeb 7, 2025·
fix: fix early cancel when RequestTimeout is provided for streaming requests (#3904)
1 parent 846fa4f commit 3d1497a

File tree

1 file changed

+52
-7
lines changed

1 file changed

+52
-7
lines changed
 

‎internal/requestconfig/requestconfig.go

+52-7
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,41 @@ func parseRetryAfterHeader(resp *http.Response) (time.Duration, bool) {
295295
return 0, false
296296
}
297297

298+
// isBeforeContextDeadline reports whether the non-zero Time t is
299+
// before ctx's deadline. If ctx does not have a deadline, it
300+
// always reports true (the deadline is considered infinite).
301+
func isBeforeContextDeadline(t time.Time, ctx context.Context) bool {
302+
d, ok := ctx.Deadline()
303+
if !ok {
304+
return true
305+
}
306+
return t.Before(d)
307+
}
308+
309+
// bodyWithTimeout is an io.ReadCloser which can observe a context's cancel func
310+
// to handle timeouts etc. It wraps an existing io.ReadCloser.
311+
type bodyWithTimeout struct {
312+
stop func() // stops the time.Timer waiting to cancel the request
313+
rc io.ReadCloser
314+
}
315+
316+
func (b *bodyWithTimeout) Read(p []byte) (n int, err error) {
317+
n, err = b.rc.Read(p)
318+
if err == nil {
319+
return n, nil
320+
}
321+
if err == io.EOF {
322+
return n, err
323+
}
324+
return n, err
325+
}
326+
327+
func (b *bodyWithTimeout) Close() error {
328+
err := b.rc.Close()
329+
b.stop()
330+
return err
331+
}
332+
298333
func retryDelay(res *http.Response, retryCount int) time.Duration {
299334
// If the API asks us to wait a certain amount of time (and it's a reasonable amount),
300335
// just do what it says.
@@ -356,12 +391,17 @@ func (cfg *RequestConfig) Execute() (err error) {
356391
shouldSendRetryCount := cfg.Request.Header.Get("X-Stainless-Retry-Count") == "0"
357392

358393
var res *http.Response
394+
var cancel context.CancelFunc
359395
for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 {
360396
ctx := cfg.Request.Context()
361-
if cfg.RequestTimeout != time.Duration(0) {
362-
var cancel context.CancelFunc
397+
if cfg.RequestTimeout != time.Duration(0) && isBeforeContextDeadline(time.Now().Add(cfg.RequestTimeout), ctx) {
363398
ctx, cancel = context.WithTimeout(ctx, cfg.RequestTimeout)
364-
defer cancel()
399+
defer func() {
400+
// The cancel function is nil if it was handed off to be handled in a different scope.
401+
if cancel != nil {
402+
cancel()
403+
}
404+
}()
365405
}
366406

367407
req := cfg.Request.Clone(ctx)
@@ -429,10 +469,15 @@ func (cfg *RequestConfig) Execute() (err error) {
429469
return &aerr
430470
}
431471

432-
if cfg.ResponseBodyInto == nil {
433-
return nil
434-
}
435-
if _, ok := cfg.ResponseBodyInto.(**http.Response); ok {
472+
_, intoCustomResponseBody := cfg.ResponseBodyInto.(**http.Response)
473+
if cfg.ResponseBodyInto == nil || intoCustomResponseBody {
474+
// We aren't reading the response body in this scope, but whoever is will need the
475+
// cancel func from the context to observe request timeouts.
476+
// Put the cancel function in the response body so it can be handled elsewhere.
477+
if cancel != nil {
478+
res.Body = &bodyWithTimeout{rc: res.Body, stop: cancel}
479+
cancel = nil
480+
}
436481
return nil
437482
}
438483

0 commit comments

Comments
 (0)
Please sign in to comment.