From b4b86b904ee818dc480b8b7384bd92a751a5c0ee Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 13 Oct 2023 02:17:45 -0700 Subject: [PATCH] dial.go: Use timeout on HTTPClient properly Closes #341 --- conn_test.go | 31 +++++++++++++++++++++++++++++++ dial.go | 9 +++++---- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/conn_test.go b/conn_test.go index d80acce2..59661b73 100644 --- a/conn_test.go +++ b/conn_test.go @@ -264,6 +264,37 @@ func TestConn(t *testing.T) { err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) + + t.Run("HTTPClient.Timeout", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ + HTTPClient: &http.Client{Timeout: time.Second*5}, + }, nil) + + tt.goEchoLoop(c2) + + c1.SetReadLimit(1 << 30) + + exp := xrand.String(xrand.Int(131072)) + + werr := xsync.Go(func() error { + return wsjson.Write(tt.ctx, c1, exp) + }) + + var act interface{} + err := wsjson.Read(tt.ctx, c1, &act) + assert.Success(t, err) + assert.Equal(t, "read msg", exp, act) + + select { + case err := <-werr: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + err = c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + }) } func TestWasm(t *testing.T) { diff --git a/dial.go b/dial.go index 510b94b1..0f2735da 100644 --- a/dial.go +++ b/dial.go @@ -59,12 +59,13 @@ func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context } if o.HTTPClient == nil { o.HTTPClient = http.DefaultClient - } else if opts.HTTPClient.Timeout > 0 { - ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) + } + if o.HTTPClient.Timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout) - newClient := *opts.HTTPClient + newClient := *o.HTTPClient newClient.Timeout = 0 - opts.HTTPClient = &newClient + o.HTTPClient = &newClient } if o.HTTPHeader == nil { o.HTTPHeader = http.Header{}