Skip to content

Commit

Permalink
oidc: restrict use of context.Background()
Browse files Browse the repository at this point in the history
  • Loading branch information
ericchiang committed Feb 4, 2023
1 parent a8ceb9a commit 2936eb3
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 28 deletions.
2 changes: 1 addition & 1 deletion oidc/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time)
if now == nil {
now = time.Now
}
return &RemoteKeySet{jwksURL: jwksURL, ctx: cloneContext(ctx), now: now}
return &RemoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now}
}

// RemoteKeySet is a KeySet implementation that validates JSON web tokens against
Expand Down
66 changes: 43 additions & 23 deletions oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"mime"
"net/http"
"strings"
"sync"
"time"

"golang.org/x/oauth2"
Expand Down Expand Up @@ -58,15 +59,11 @@ func ClientContext(ctx context.Context, client *http.Client) context.Context {
return context.WithValue(ctx, oauth2.HTTPClient, client)
}

// cloneContext copies a context's bag-of-values into a new context that isn't
// associated with its cancellation. This is used to initialize remote keys sets
// which run in the background and aren't associated with the initial context.
func cloneContext(ctx context.Context) context.Context {
cp := context.Background()
func getClient(ctx context.Context) *http.Client {
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
cp = ClientContext(cp, c)
return c
}
return cp
return nil
}

// InsecureIssuerURLContext allows discovery to work when the issuer_url reported
Expand All @@ -90,7 +87,7 @@ func InsecureIssuerURLContext(ctx context.Context, issuerURL string) context.Con

func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
client := http.DefaultClient
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
if c := getClient(ctx); c != nil {
client = c
}
return client.Do(req.WithContext(ctx))
Expand All @@ -102,12 +99,33 @@ type Provider struct {
authURL string
tokenURL string
userInfoURL string
jwksURL string
algorithms []string

// Raw claims returned by the server.
rawClaims []byte

remoteKeySet KeySet
// Guards all of the following fields.
mu sync.Mutex
// HTTP client specified from the initial NewProvider request. This is used
// when creating the common key set.
client *http.Client
// A key set that uses context.Background() and is shared between all code paths
// that don't have a convinent way of supplying a unique context.
commonRemoteKeySet KeySet
}

func (p *Provider) remoteKeySet() KeySet {
p.mu.Lock()
defer p.mu.Unlock()
if p.commonRemoteKeySet == nil {
ctx := context.Background()
if p.client != nil {
ctx = ClientContext(ctx, p.client)
}
p.commonRemoteKeySet = NewRemoteKeySet(ctx, p.jwksURL)
}
return p.commonRemoteKeySet
}

type providerJSON struct {
Expand Down Expand Up @@ -167,12 +185,13 @@ type ProviderConfig struct {
// through discovery.
func (p *ProviderConfig) NewProvider(ctx context.Context) *Provider {
return &Provider{
issuer: p.IssuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
algorithms: p.Algorithms,
remoteKeySet: NewRemoteKeySet(cloneContext(ctx), p.JWKSURL),
issuer: p.IssuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL,
algorithms: p.Algorithms,
client: getClient(ctx),
}
}

Expand Down Expand Up @@ -221,13 +240,14 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
}
}
return &Provider{
issuer: issuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
algorithms: algs,
rawClaims: body,
remoteKeySet: NewRemoteKeySet(cloneContext(ctx), p.JWKSURL),
issuer: issuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL,
algorithms: algs,
rawClaims: body,
client: getClient(ctx),
}, nil
}

Expand Down Expand Up @@ -317,7 +337,7 @@ func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource)
ct := resp.Header.Get("Content-Type")
mediaType, _, parseErr := mime.ParseMediaType(ct)
if parseErr == nil && mediaType == "application/jwt" {
payload, err := p.remoteKeySet.VerifySignature(ctx, string(body))
payload, err := p.remoteKeySet().VerifySignature(ctx, string(body))
if err != nil {
return nil, fmt.Errorf("oidc: invalid userinfo jwt signature %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,15 +326,15 @@ func TestNewProvider(t *testing.T) {
}
}

func TestCloneContext(t *testing.T) {
func TestGetClient(t *testing.T) {
ctx := context.Background()
if _, ok := cloneContext(ctx).Value(oauth2.HTTPClient).(*http.Client); ok {
if c := getClient(ctx); c != nil {
t.Errorf("cloneContext(): expected no *http.Client from empty context")
}

c := &http.Client{}
ctx = ClientContext(ctx, c)
if got, ok := cloneContext(ctx).Value(oauth2.HTTPClient).(*http.Client); !ok || c != got {
if got := getClient(ctx); got == nil || c != got {
t.Errorf("cloneContext(): expected *http.Client from context")
}
}
Expand Down
16 changes: 15 additions & 1 deletion oidc/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,30 @@ type Config struct {
InsecureSkipSignatureCheck bool
}

// VerifierContext returns an IDTokenVerifier that uses the provider's key set to
// verify JWTs. As opposed to Verifier, the context is used for all requests to
// the upstream JWKs endpoint.
func (p *Provider) VerifierContext(ctx context.Context, config *Config) *IDTokenVerifier {
return p.newVerifier(NewRemoteKeySet(ctx, p.jwksURL), config)
}

// Verifier returns an IDTokenVerifier that uses the provider's key set to verify JWTs.
//
// The returned verifier uses a background context for all requests to the upstream
// JWKs endpoint. To control that context, use VerifierContext instead.
func (p *Provider) Verifier(config *Config) *IDTokenVerifier {
return p.newVerifier(p.remoteKeySet(), config)
}

func (p *Provider) newVerifier(keySet KeySet, config *Config) *IDTokenVerifier {
if len(config.SupportedSigningAlgs) == 0 && len(p.algorithms) > 0 {
// Make a copy so we don't modify the config values.
cp := &Config{}
*cp = *config
cp.SupportedSigningAlgs = p.algorithms
config = cp
}
return NewVerifier(p.issuer, p.remoteKeySet, config)
return NewVerifier(p.issuer, keySet, config)
}

func parseJWT(p string) ([]byte, error) {
Expand Down

0 comments on commit 2936eb3

Please sign in to comment.