Skip to content

Commit

Permalink
dial_test: Add TestDialViaProxy
Browse files Browse the repository at this point in the history
For #395

Somehow currently reproduces #391...

Debugging still.
  • Loading branch information
nhooyr committed Oct 19, 2023
1 parent e314da6 commit 249edb2
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 15 deletions.
26 changes: 26 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,3 +526,29 @@ func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOp
err = wstest.EchoLoop(r.Context(), c)
return assertCloseStatus(websocket.StatusNormalClosure, err)
}

func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) {
exp := xrand.String(xrand.Int(131072))

werr := xsync.Go(func() error {
return wsjson.Write(ctx, c, exp)
})

var act interface{}
err := wsjson.Read(ctx, c, &act)
assert.Success(tb, err)
assert.Equal(tb, "read msg", exp, act)

select {
case err := <-werr:
assert.Success(tb, err)
case <-ctx.Done():
tb.Fatal(ctx.Err())
}
}

func assertClose(tb testing.TB, c *websocket.Conn) {
tb.Helper()
err := c.Close(websocket.StatusNormalClosure, "")
assert.Success(tb, err)
}
110 changes: 95 additions & 15 deletions dial_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//go:build !js
// +build !js

package websocket
package websocket_test

import (
"bytes"
Expand All @@ -10,12 +10,15 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/util"
"nhooyr.io/websocket/internal/xsync"
)

func TestBadDials(t *testing.T) {
Expand All @@ -27,7 +30,7 @@ func TestBadDials(t *testing.T) {
testCases := []struct {
name string
url string
opts *DialOptions
opts *websocket.DialOptions
rand util.ReaderFunc
nilCtx bool
}{
Expand Down Expand Up @@ -72,7 +75,7 @@ func TestBadDials(t *testing.T) {
tc.rand = rand.Reader.Read
}

_, _, err := dial(ctx, tc.url, tc.opts, tc.rand)
_, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand)
assert.Error(t, err)
})
}
Expand All @@ -84,7 +87,7 @@ func TestBadDials(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(strings.NewReader("hi")),
Expand All @@ -104,7 +107,7 @@ func TestBadDials(t *testing.T) {
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))

return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Expand All @@ -113,7 +116,7 @@ func TestBadDials(t *testing.T) {
}, nil
}

_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
})
assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
Expand Down Expand Up @@ -152,7 +155,7 @@ func Test_verifyHostOverride(t *testing.T) {
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))

return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Expand All @@ -161,7 +164,7 @@ func Test_verifyHostOverride(t *testing.T) {
}, nil
}

_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
Host: tc.host,
})
Expand Down Expand Up @@ -272,18 +275,18 @@ func Test_verifyServerHandshake(t *testing.T) {
resp := w.Result()

r := httptest.NewRequest("GET", "/", nil)
key, err := secWebSocketKey(rand.Reader)
key, err := websocket.SecWebSocketKey(rand.Reader)
assert.Success(t, err)
r.Header.Set("Sec-WebSocket-Key", key)

if resp.Header.Get("Sec-WebSocket-Accept") == "" {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key))
}

opts := &DialOptions{
opts := &websocket.DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
}
_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
_, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp)
if tc.success {
assert.Success(t, err)
} else {
Expand Down Expand Up @@ -311,7 +314,7 @@ func TestDialRedirect(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) {
resp := &http.Response{
Header: http.Header{},
Expand All @@ -321,11 +324,88 @@ func TestDialRedirect(t *testing.T) {
resp.StatusCode = http.StatusFound
return resp, nil
}
resp.Header.Set("Connection", "Upgrade")
resp.Header.Set("Upgrade", "meow")
resp.Header.Set("Connection", "Upgrade")
resp.Header.Set("Upgrade", "meow")
resp.StatusCode = http.StatusSwitchingProtocols
return resp, nil
}),
})
assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket")
}

type forwardProxy struct {
hc *http.Client
}

func newForwardProxy() *forwardProxy {
return &forwardProxy{
hc: &http.Client{},
}
}

func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
defer cancel()

r = r.WithContext(ctx)
r.RequestURI = ""
resp, err := fc.hc.Do(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer resp.Body.Close()

for k, v := range resp.Header {
w.Header()[k] = v
}
w.Header().Set("PROXIED", "true")
w.WriteHeader(resp.StatusCode)
errc1 := xsync.Go(func() error {
_, err := io.Copy(w, resp.Body)
return err
})
var errc2 <-chan error
if bodyw, ok := resp.Body.(io.Writer); ok {
errc2 = xsync.Go(func() error {
_, err := io.Copy(bodyw, r.Body)
return err
})
}
select {
case <-errc1:
case <-errc2:
case <-r.Context().Done():
}
}

func TestDialViaProxy(t *testing.T) {
t.Parallel()

ps := httptest.NewServer(newForwardProxy())
defer ps.Close()

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r, nil)
assert.Success(t, err)
}))
defer s.Close()

psu, err := url.Parse(ps.URL)
assert.Success(t, err)
proxyTransport := http.DefaultTransport.(*http.Transport).Clone()
proxyTransport.Proxy = http.ProxyURL(psu)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{
HTTPClient: &http.Client{
Transport: proxyTransport,
},
})
assert.Success(t, err)
assert.Equal(t, "", "true", resp.Header.Get("PROXIED"))

assertEcho(t, ctx, c)
assertClose(t, c)
}
7 changes: 7 additions & 0 deletions export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,10 @@ func (c *Conn) RecordBytesRead() *int {
}

var ErrClosed = errClosed

var ExportedDial = dial
var SecWebSocketAccept = secWebSocketAccept
var SecWebSocketKey = secWebSocketKey
var VerifyServerResponse = verifyServerResponse

var CompressionModeOpts = CompressionMode.opts

0 comments on commit 249edb2

Please sign in to comment.