diff --git a/wsstomp.go b/wsstomp.go index 3f41d1b..58e5eb8 100644 --- a/wsstomp.go +++ b/wsstomp.go @@ -2,6 +2,8 @@ package wsstomp import ( "context" + "fmt" + "net/http" "time" "nhooyr.io/websocket" @@ -62,6 +64,28 @@ func (w *WebsocketSTOMP) Close() error { // The context parameter will only be used for the connection handshake, // and not for the full lifetime of the connection. func Connect(ctx context.Context, url string, options *websocket.DialOptions) (*WebsocketSTOMP, error) { + if options == nil { + options = &websocket.DialOptions{} + } + if options.HTTPClient == nil { + options.HTTPClient = &http.Client{ + // fix for https://github.com/nhooyr/websocket/issues/333 + CheckRedirect: func(req *http.Request, via []*http.Request) error { + switch req.URL.Scheme { + case "ws": + req.URL.Scheme = "http" + case "wss": + req.URL.Scheme = "https" + case "http", "https": + default: + return fmt.Errorf("unexpected url scheme: %q", req.URL.Scheme) + } + return nil + }, + // sane timeout + Timeout: time.Second * 30, + } + } con, _, err := websocket.Dial(ctx, url, options) return &WebsocketSTOMP{ connection: con,