diff --git a/conn_test.go b/conn_test.go index 5f78cad5..c814ca28 100644 --- a/conn_test.go +++ b/conn_test.go @@ -535,6 +535,7 @@ func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) { }) var act interface{} + c.SetReadLimit(1 << 30) err := wsjson.Read(ctx, c, &act) assert.Success(tb, err) assert.Equal(tb, "read msg", exp, act) diff --git a/dial_test.go b/dial_test.go index 7a84436d..63cb4be6 100644 --- a/dial_test.go +++ b/dial_test.go @@ -361,21 +361,29 @@ func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.Header().Set("PROXIED", "true") w.WriteHeader(resp.StatusCode) - errc1 := xsync.Go(func() error { - _, err := io.Copy(w, resp.Body) - return err - }) - var errc2 <-chan error - if bodyw, ok := resp.Body.(io.Writer); ok { - errc2 = xsync.Go(func() error { - _, err := io.Copy(bodyw, r.Body) + if resprw, ok := resp.Body.(io.ReadWriter); ok { + c, brw, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + brw.Flush() + + errc1 := xsync.Go(func() error { + _, err := io.Copy(c, resprw) return err }) - } - select { - case <-errc1: - case <-errc2: - case <-r.Context().Done(): + errc2 := xsync.Go(func() error { + _, err := io.Copy(resprw, c) + return err + }) + select { + case <-errc1: + case <-errc2: + case <-r.Context().Done(): + } + } else { + io.Copy(w, resp.Body) } }