Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

oidc: restrict use of context.Background() #364

Merged
merged 1 commit into from Feb 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion oidc/jwks.go
Expand Up @@ -60,7 +60,7 @@ func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time)
if now == nil {
now = time.Now
}
return &RemoteKeySet{jwksURL: jwksURL, ctx: cloneContext(ctx), now: now}
return &RemoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now}
}

// RemoteKeySet is a KeySet implementation that validates JSON web tokens against
Expand Down
66 changes: 43 additions & 23 deletions oidc/oidc.go
Expand Up @@ -14,6 +14,7 @@ import (
"mime"
"net/http"
"strings"
"sync"
"time"

"golang.org/x/oauth2"
Expand Down Expand Up @@ -58,15 +59,11 @@ func ClientContext(ctx context.Context, client *http.Client) context.Context {
return context.WithValue(ctx, oauth2.HTTPClient, client)
}

// cloneContext copies a context's bag-of-values into a new context that isn't
// associated with its cancellation. This is used to initialize remote keys sets
// which run in the background and aren't associated with the initial context.
func cloneContext(ctx context.Context) context.Context {
cp := context.Background()
func getClient(ctx context.Context) *http.Client {
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
cp = ClientContext(cp, c)
return c
}
return cp
return nil
}

// InsecureIssuerURLContext allows discovery to work when the issuer_url reported
Expand All @@ -90,7 +87,7 @@ func InsecureIssuerURLContext(ctx context.Context, issuerURL string) context.Con

func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
client := http.DefaultClient
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
if c := getClient(ctx); c != nil {
client = c
}
return client.Do(req.WithContext(ctx))
Expand All @@ -102,12 +99,33 @@ type Provider struct {
authURL string
tokenURL string
userInfoURL string
jwksURL string
algorithms []string

// Raw claims returned by the server.
rawClaims []byte

remoteKeySet KeySet
// Guards all of the following fields.
mu sync.Mutex
// HTTP client specified from the initial NewProvider request. This is used
// when creating the common key set.
client *http.Client
// A key set that uses context.Background() and is shared between all code paths
// that don't have a convinent way of supplying a unique context.
commonRemoteKeySet KeySet
}

func (p *Provider) remoteKeySet() KeySet {
p.mu.Lock()
defer p.mu.Unlock()
if p.commonRemoteKeySet == nil {
ctx := context.Background()
if p.client != nil {
ctx = ClientContext(ctx, p.client)
}
p.commonRemoteKeySet = NewRemoteKeySet(ctx, p.jwksURL)
}
return p.commonRemoteKeySet
}

type providerJSON struct {
Expand Down Expand Up @@ -167,12 +185,13 @@ type ProviderConfig struct {
// through discovery.
func (p *ProviderConfig) NewProvider(ctx context.Context) *Provider {
return &Provider{
issuer: p.IssuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
algorithms: p.Algorithms,
remoteKeySet: NewRemoteKeySet(cloneContext(ctx), p.JWKSURL),
issuer: p.IssuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL,
algorithms: p.Algorithms,
client: getClient(ctx),
}
}

Expand Down Expand Up @@ -221,13 +240,14 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
}
}
return &Provider{
issuer: issuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
algorithms: algs,
rawClaims: body,
remoteKeySet: NewRemoteKeySet(cloneContext(ctx), p.JWKSURL),
issuer: issuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL,
algorithms: algs,
rawClaims: body,
client: getClient(ctx),
}, nil
}

Expand Down Expand Up @@ -317,7 +337,7 @@ func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource)
ct := resp.Header.Get("Content-Type")
mediaType, _, parseErr := mime.ParseMediaType(ct)
if parseErr == nil && mediaType == "application/jwt" {
payload, err := p.remoteKeySet.VerifySignature(ctx, string(body))
payload, err := p.remoteKeySet().VerifySignature(ctx, string(body))
if err != nil {
return nil, fmt.Errorf("oidc: invalid userinfo jwt signature %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions oidc/oidc_test.go
Expand Up @@ -326,15 +326,15 @@ func TestNewProvider(t *testing.T) {
}
}

func TestCloneContext(t *testing.T) {
func TestGetClient(t *testing.T) {
ctx := context.Background()
if _, ok := cloneContext(ctx).Value(oauth2.HTTPClient).(*http.Client); ok {
if c := getClient(ctx); c != nil {
t.Errorf("cloneContext(): expected no *http.Client from empty context")
}

c := &http.Client{}
ctx = ClientContext(ctx, c)
if got, ok := cloneContext(ctx).Value(oauth2.HTTPClient).(*http.Client); !ok || c != got {
if got := getClient(ctx); got == nil || c != got {
t.Errorf("cloneContext(): expected *http.Client from context")
}
}
Expand Down
16 changes: 15 additions & 1 deletion oidc/verify.go
Expand Up @@ -120,16 +120,30 @@ type Config struct {
InsecureSkipSignatureCheck bool
}

// VerifierContext returns an IDTokenVerifier that uses the provider's key set to
// verify JWTs. As opposed to Verifier, the context is used for all requests to
// the upstream JWKs endpoint.
func (p *Provider) VerifierContext(ctx context.Context, config *Config) *IDTokenVerifier {
return p.newVerifier(NewRemoteKeySet(ctx, p.jwksURL), config)
}

// Verifier returns an IDTokenVerifier that uses the provider's key set to verify JWTs.
//
// The returned verifier uses a background context for all requests to the upstream
// JWKs endpoint. To control that context, use VerifierContext instead.
func (p *Provider) Verifier(config *Config) *IDTokenVerifier {
return p.newVerifier(p.remoteKeySet(), config)
}

func (p *Provider) newVerifier(keySet KeySet, config *Config) *IDTokenVerifier {
if len(config.SupportedSigningAlgs) == 0 && len(p.algorithms) > 0 {
// Make a copy so we don't modify the config values.
cp := &Config{}
*cp = *config
cp.SupportedSigningAlgs = p.algorithms
config = cp
}
return NewVerifier(p.issuer, p.remoteKeySet, config)
return NewVerifier(p.issuer, keySet, config)
}

func parseJWT(p string) ([]byte, error) {
Expand Down