Skip to content

Commit

Permalink
use jwt to prevent redirect_uri changes
Browse files Browse the repository at this point in the history
  • Loading branch information
goenning committed May 4, 2024
1 parent 0b6dae9 commit 149101e
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 18 deletions.
13 changes: 8 additions & 5 deletions app/handlers/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,17 @@ func OAuthCallback() web.HandlerFunc {

provider := c.Param("provider")
state := c.QueryParam("state")
parts := strings.Split(state, "|")
claims, err := jwt.DecodeOAuthStateClaims(state)
if err != nil {
return c.Forbidden()
}

if parts[0] == "" {
if claims.Redirect == "" {
log.Warnf(c, "Missing redirect URL in OAuth callback state for provider @{Provider}.", dto.Props{"Provider": provider})
return c.NotFound()
}

redirectURL, err := url.ParseRequestURI(parts[0])
redirectURL, err := url.ParseRequestURI(claims.Redirect)
if err != nil {
return c.Failure(err)
}
Expand All @@ -184,7 +187,7 @@ func OAuthCallback() web.HandlerFunc {
if redirectURL.Path == fmt.Sprintf("/oauth/%s/echo", provider) {
var query = redirectURL.Query()
query.Set("code", code)
query.Set("identifier", parts[1])
query.Set("identifier", claims.Identifier)
redirectURL.RawQuery = query.Encode()
return c.Redirect(redirectURL.String())
}
Expand Down Expand Up @@ -221,7 +224,7 @@ func OAuthCallback() web.HandlerFunc {
var query = redirectURL.Query()
query.Set("code", code)
query.Set("redirect", redirectURL.RequestURI())
query.Set("identifier", parts[1])
query.Set("identifier", claims.Identifier)
redirectURL.RawQuery = query.Encode()
redirectURL.Path = fmt.Sprintf("/oauth/%s/token", provider)
return c.Redirect(redirectURL.String())
Expand Down
43 changes: 36 additions & 7 deletions app/handlers/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ func TestSignInByOAuthHandler_InvalidURL(t *testing.T) {
RegisterT(t)
bus.Init(&oauth.Service{})

state, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://avengers.test.fider.io",
Identifier: "MY_SESSION_ID",
})

server := mock.NewServer()
code, response := server.
AddParam("provider", app.FacebookProvider).
Expand All @@ -69,7 +74,7 @@ func TestSignInByOAuthHandler_InvalidURL(t *testing.T) {
Execute(handlers.SignInByOAuth())

Expect(code).Equals(http.StatusTemporaryRedirect)
Expect(response.Header().Get("Location")).Equals("https://www.facebook.com/v3.2/dialog/oauth?client_id=FB_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%2Foauth%2Ffacebook%2Fcallback&response_type=code&scope=public_profile+email&state=http%3A%2F%2Favengers.test.fider.io%7CMY_SESSION_ID")
Expect(response.Header().Get("Location")).Equals("https://www.facebook.com/v3.2/dialog/oauth?client_id=FB_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%2Foauth%2Ffacebook%2Fcallback&response_type=code&scope=public_profile+email&state="+state)
}

func TestSignInByOAuthHandler_AuthenticatedUser(t *testing.T) {
Expand Down Expand Up @@ -102,8 +107,13 @@ func TestSignInByOAuthHandler_AuthenticatedUser_UsingEcho(t *testing.T) {
Use(middlewares.Session()).
Execute(handlers.SignInByOAuth())

state, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://avengers.test.fider.io/oauth/facebook/echo",
Identifier: "MY_SESSION_ID",
})

Expect(code).Equals(http.StatusTemporaryRedirect)
Expect(response.Header().Get("Location")).Equals("https://www.facebook.com/v3.2/dialog/oauth?client_id=FB_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%2Foauth%2Ffacebook%2Fcallback&response_type=code&scope=public_profile+email&state=http%3A%2F%2Favengers.test.fider.io%2Foauth%2Ffacebook%2Fecho%7CMY_SESSION_ID")
Expect(response.Header().Get("Location")).Equals("https://www.facebook.com/v3.2/dialog/oauth?client_id=FB_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%2Foauth%2Ffacebook%2Fcallback&response_type=code&scope=public_profile+email&state="+state)
}

func TestCallbackHandler_InvalidState(t *testing.T) {
Expand All @@ -115,16 +125,20 @@ func TestCallbackHandler_InvalidState(t *testing.T) {
AddParam("provider", app.FacebookProvider).
Execute(handlers.OAuthCallback())

Expect(code).Equals(http.StatusInternalServerError)
Expect(code).Equals(http.StatusForbidden)
}

func TestCallbackHandler_InvalidCode(t *testing.T) {
RegisterT(t)

server := mock.NewServer()
state, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://avengers.test.fider.io",
Identifier: "",
})

code, response := server.
WithURL("http://login.test.fider.io/oauth/callback?state=http://avengers.test.fider.io").
WithURL("http://login.test.fider.io/oauth/callback?state="+state).
AddParam("provider", app.FacebookProvider).
Execute(handlers.OAuthCallback())

Expand All @@ -135,9 +149,14 @@ func TestCallbackHandler_InvalidCode(t *testing.T) {
func TestCallbackHandler_SignIn(t *testing.T) {
RegisterT(t)

state, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://avengers.test.fider.io",
Identifier: "888",
})

server := mock.NewServer()
code, response := server.
WithURL("http://login.test.fider.io/oauth/callback?state=http://avengers.test.fider.io|888&code=123").
WithURL("http://login.test.fider.io/oauth/callback?state="+state+"&code=123").
AddParam("provider", app.FacebookProvider).
Execute(handlers.OAuthCallback())

Expand All @@ -149,8 +168,13 @@ func TestCallbackHandler_SignIn_WithPath(t *testing.T) {
RegisterT(t)
server := mock.NewServer()

state, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://avengers.test.fider.io/some-page",
Identifier: "888",
})

code, response := server.
WithURL("http://login.test.fider.io/oauth/callback?state=http://avengers.test.fider.io/some-page|888&code=123").
WithURL("http://login.test.fider.io/oauth/callback?state="+state+"&code=123").
AddParam("provider", app.FacebookProvider).
Execute(handlers.OAuthCallback())

Expand All @@ -175,9 +199,14 @@ func TestCallbackHandler_SignUp(t *testing.T) {
return app.ErrNotFound
})

state, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://demo.test.fider.io/signup",
Identifier: "",
})

server := mock.NewServer()
code, response := server.
WithURL("http://login.test.fider.io/oauth/callback?state=http://demo.test.fider.io/signup&code=123").
WithURL("http://login.test.fider.io/oauth/callback?state="+state+"&code=123").
AddParam("provider", app.FacebookProvider).
Execute(handlers.OAuthCallback())
Expect(code).Equals(http.StatusTemporaryRedirect)
Expand Down
17 changes: 17 additions & 0 deletions app/pkg/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ type OAuthClaims struct {
Metadata
}

// OAuthStateClaims represents what goes into JWT tokens used for OAuth state parameter
type OAuthStateClaims struct {
Redirect string `json:"oauthstate/redirect"`
Identifier string `json:"oauthstate/identifier"`
Metadata
}

// Encode creates new JWT token with given claims
func Encode(claims jwtgo.Claims) (string, error) {
jwtToken := jwtgo.NewWithClaims(jwtgo.GetSigningMethod("HS256"), claims)
Expand Down Expand Up @@ -73,6 +80,16 @@ func DecodeOAuthClaims(token string) (*OAuthClaims, error) {
return claims, nil
}

// DecodeOAuthStateClaims extract OAuthClaims from given JWT token
func DecodeOAuthStateClaims(token string) (*OAuthStateClaims, error) {
claims := &OAuthStateClaims{}
err := decode(token, claims)
if err != nil {
return nil, errors.Wrap(err, "failed to decode OAuthState claims")
}
return claims, nil
}

func decode(token string, claims jwtgo.Claims) error {
jwtToken, err := jwtgo.ParseWithClaims(token, claims, func(t *jwtgo.Token) (any, error) {
if _, ok := t.Method.(*jwtgo.SigningMethodHMAC); !ok {
Expand Down
13 changes: 12 additions & 1 deletion app/services/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/getfider/fider/app/pkg/env"
"github.com/getfider/fider/app/pkg/errors"
"github.com/getfider/fider/app/pkg/jsonq"
"github.com/getfider/fider/app/pkg/jwt"
"github.com/getfider/fider/app/pkg/validate"
"github.com/getfider/fider/app/pkg/web"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -155,7 +156,17 @@ func getOAuthAuthorizationURL(ctx context.Context, q *query.GetOAuthAuthorizatio
parameters.Add("scope", config.Scope)
parameters.Add("redirect_uri", fmt.Sprintf("%s/oauth/%s/callback", oauthBaseURL, q.Provider))
parameters.Add("response_type", "code")
parameters.Add("state", q.Redirect+"|"+q.Identifier)

state, err := jwt.Encode(jwt.OAuthStateClaims{
Redirect: q.Redirect,
Identifier: q.Identifier,
})

if err != nil {
return err
}

parameters.Add("state", state)

authURL.RawQuery = parameters.Encode()
q.Result = authURL.String()
Expand Down
36 changes: 31 additions & 5 deletions app/services/oauth/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/getfider/fider/app/pkg/bus"
"github.com/getfider/fider/app/pkg/errors"
"github.com/getfider/fider/app/pkg/jwt"
"github.com/getfider/fider/app/pkg/web"

. "github.com/getfider/fider/app/pkg/assert"
Expand Down Expand Up @@ -46,9 +47,14 @@ func TestGetAuthURL_Facebook(t *testing.T) {
Identifier: "456",
}

expectedState, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://example.org",
Identifier: "456",
})

err := bus.Dispatch(ctx, authURL)
Expect(err).IsNil()
Expect(authURL.Result).Equals("https://www.facebook.com/v3.2/dialog/oauth?client_id=FB_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2Ffacebook%2Fcallback&response_type=code&scope=public_profile+email&state=http%3A%2F%2Fexample.org%7C456")
Expect(authURL.Result).Equals("https://www.facebook.com/v3.2/dialog/oauth?client_id=FB_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2Ffacebook%2Fcallback&response_type=code&scope=public_profile+email&state="+expectedState)
}

func TestGetAuthURL_Google(t *testing.T) {
Expand All @@ -63,9 +69,14 @@ func TestGetAuthURL_Google(t *testing.T) {
Identifier: "123",
}

expectedState, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://example.org",
Identifier: "123",
})

err := bus.Dispatch(ctx, authURL)
Expect(err).IsNil()
Expect(authURL.Result).Equals("https://accounts.google.com/o/oauth2/v2/auth?client_id=GO_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2Fgoogle%2Fcallback&response_type=code&scope=profile+email&state=http%3A%2F%2Fexample.org%7C123")
Expect(authURL.Result).Equals("https://accounts.google.com/o/oauth2/v2/auth?client_id=GO_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2Fgoogle%2Fcallback&response_type=code&scope=profile+email&state="+expectedState)
}

func TestGetAuthURL_GitHub(t *testing.T) {
Expand All @@ -80,9 +91,14 @@ func TestGetAuthURL_GitHub(t *testing.T) {
Identifier: "456",
}

expectedState, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://example.org",
Identifier: "456",
})

err := bus.Dispatch(ctx, authURL)
Expect(err).IsNil()
Expect(authURL.Result).Equals("https://github.com/login/oauth/authorize?client_id=GH_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2Fgithub%2Fcallback&response_type=code&scope=user%3Aemail&state=http%3A%2F%2Fexample.org%7C456")
Expect(authURL.Result).Equals("https://github.com/login/oauth/authorize?client_id=GH_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2Fgithub%2Fcallback&response_type=code&scope=user%3Aemail&state="+expectedState)
}

func TestGetAuthURL_Custom(t *testing.T) {
Expand All @@ -108,9 +124,14 @@ func TestGetAuthURL_Custom(t *testing.T) {
Identifier: "456",
}

expectedState, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://example.org",
Identifier: "456",
})

err := bus.Dispatch(ctx, authURL)
Expect(err).IsNil()
Expect(authURL.Result).Equals("https://example.org/oauth/authorize?client_id=CU_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2F_custom%2Fcallback&response_type=code&scope=profile+email&state=http%3A%2F%2Fexample.org%7C456")
Expect(authURL.Result).Equals("https://example.org/oauth/authorize?client_id=CU_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2F_custom%2Fcallback&response_type=code&scope=profile+email&state="+expectedState)
}

func TestGetAuthURL_Twitch(t *testing.T) {
Expand All @@ -136,9 +157,14 @@ func TestGetAuthURL_Twitch(t *testing.T) {
Identifier: "456",
}

expectedState, _ := jwt.Encode(jwt.OAuthStateClaims{
Redirect: "http://example.org",
Identifier: "456",
})

err := bus.Dispatch(ctx, authURL)
Expect(err).IsNil()
Expect(authURL.Result).Equals("https://id.twitch.tv/oauth/authorize?claims=%7B%22userinfo%22%3A%7B%22preferred_username%22%3Anull%2C%22email%22%3Anull%2C%22email_verified%22%3Anull%7D%7D&client_id=CU_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2F_custom%2Fcallback&response_type=code&scope=openid&state=http%3A%2F%2Fexample.org%7C456")
Expect(authURL.Result).Equals("https://id.twitch.tv/oauth/authorize?claims=%7B%22userinfo%22%3A%7B%22preferred_username%22%3Anull%2C%22email%22%3Anull%2C%22email_verified%22%3Anull%7D%7D&client_id=CU_CL_ID&redirect_uri=http%3A%2F%2Flogin.test.fider.io%3A3000%2Foauth%2F_custom%2Fcallback&response_type=code&scope=openid&state="+expectedState)
}

func TestParseProfileResponse_AllFields(t *testing.T) {
Expand Down

0 comments on commit 149101e

Please sign in to comment.