Skip to content

Commit

Permalink
internal: Refactor cert logic to support OAuth2 token exchange over m…
Browse files Browse the repository at this point in the history
…TLS (#1886)

With Context Aware Access enabled, users must use the endpoint "https://oauth2.mtls.googleapis.com/token" for token exchange. This PR refactors the cert logic currently used by the transport layer to be reused by the internal credentials layer in order to inject an mTLS-enabled HTTPClient (via the "context" mechanism) for use by the OAuth2 transport (along with the mTLS OAuth2 endpoint if so).
  • Loading branch information
andyrzhao committed Mar 9, 2023
1 parent 8d4d70d commit 225fa6b
Show file tree
Hide file tree
Showing 22 changed files with 69 additions and 19 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

go run ../internal/ecp/test_signer.go testdata/rsa2048bit.pem
go run ../ecp/test_signer.go testdata/rsa2048bit.pem
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

go run ../internal/ecp/test_signer.go testdata/invalid.pem
go run ../ecp/test_signer.go testdata/invalid.pem
File renamed without changes.
55 changes: 54 additions & 1 deletion internal/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ package internal

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"time"

"golang.org/x/oauth2"
"google.golang.org/api/internal/impersonate"
Expand Down Expand Up @@ -80,8 +84,25 @@ const (
// - Otherwise, executes standard OAuth 2.0 flow
// More details: google.aip.dev/auth/4111
func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*google.Credentials, error) {
var params google.CredentialsParams
params.Scopes = ds.GetScopes()

// Determine configurations for the OAuth2 transport, which is separate from the API transport.
// The OAuth2 transport and endpoint will be configured for mTLS if applicable.
clientCertSource, oauth2Endpoint, err := GetClientCertificateSourceAndEndpoint(oauth2DialSettings(ds))
if err != nil {
return nil, err
}
params.TokenURL = oauth2Endpoint
if clientCertSource != nil {
tlsConfig := &tls.Config{
GetClientCertificate: clientCertSource,
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, customHTTPClient(tlsConfig))
}

// By default, a standard OAuth 2.0 token source is created
cred, err := google.CredentialsFromJSON(ctx, data, ds.GetScopes()...)
cred, err := google.CredentialsFromJSONWithParams(ctx, data, params)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -157,3 +178,35 @@ func impersonateCredentials(ctx context.Context, creds *google.Credentials, ds *
ProjectID: creds.ProjectID,
}, nil
}

// oauth2DialSettings returns the settings to be used by the OAuth2 transport, which is separate from the API transport.
func oauth2DialSettings(ds *DialSettings) *DialSettings {
var ods DialSettings
ods.DefaultEndpoint = google.Endpoint.TokenURL
ods.DefaultMTLSEndpoint = google.MTLSTokenURL
ods.ClientCertSource = ds.ClientCertSource
return &ods
}

// customHTTPClient constructs an HTTPClient using the provided tlsConfig, to support mTLS.
func customHTTPClient(tlsConfig *tls.Config) *http.Client {
trans := baseTransport()
trans.TLSClientConfig = tlsConfig
return &http.Client{Transport: trans}
}

func baseTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
13 changes: 7 additions & 6 deletions transport/internal/dca/dca.go → internal/dca.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
//
// This package is not intended for use by end developers. Use the
// google.golang.org/api/option package to configure API clients.
package dca

// Package internal supports the options and transport packages.
package internal

import (
"net/url"
"os"
"strings"

"google.golang.org/api/internal"
"google.golang.org/api/transport/cert"
"google.golang.org/api/internal/cert"
)

const (
Expand All @@ -43,7 +44,7 @@ const (
// GetClientCertificateSourceAndEndpoint is a convenience function that invokes
// getClientCertificateSource and getEndpoint sequentially and returns the client
// cert source and endpoint as a tuple.
func GetClientCertificateSourceAndEndpoint(settings *internal.DialSettings) (cert.Source, string, error) {
func GetClientCertificateSourceAndEndpoint(settings *DialSettings) (cert.Source, string, error) {
clientCertSource, err := getClientCertificateSource(settings)
if err != nil {
return nil, "", err
Expand All @@ -65,7 +66,7 @@ func GetClientCertificateSourceAndEndpoint(settings *internal.DialSettings) (cer
// Important Note: For now, the environment variable GOOGLE_API_USE_CLIENT_CERTIFICATE
// must be set to "true" to allow certificate to be used (including user provided
// certificates). For details, see AIP-4114.
func getClientCertificateSource(settings *internal.DialSettings) (cert.Source, error) {
func getClientCertificateSource(settings *DialSettings) (cert.Source, error) {
if !isClientCertificateEnabled() {
return nil, nil
} else if settings.ClientCertSource != nil {
Expand Down Expand Up @@ -94,7 +95,7 @@ func isClientCertificateEnabled() bool {
// URL (ex. https://...), then the user-provided address will be merged into
// the default endpoint. For example, WithEndpoint("myhost:8000") and
// WithDefaultEndpoint("https://foo.com/bar/baz") will return "https://myhost:8080/bar/baz"
func getEndpoint(settings *internal.DialSettings, clientCertSource cert.Source) (string, error) {
func getEndpoint(settings *DialSettings, clientCertSource cert.Source) (string, error) {
if settings.Endpoint == "" {
mtlsMode := getMTLSMode()
if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
Expand Down
8 changes: 3 additions & 5 deletions transport/internal/dca/dca_test.go → internal/dca_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package dca
package internal

import (
"testing"

"crypto/tls"

"google.golang.org/api/internal"
)

func TestGetEndpoint(t *testing.T) {
Expand Down Expand Up @@ -51,7 +49,7 @@ func TestGetEndpoint(t *testing.T) {
}

for _, tc := range testCases {
got, err := getEndpoint(&internal.DialSettings{
got, err := getEndpoint(&DialSettings{
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
}, nil)
Expand Down Expand Up @@ -106,7 +104,7 @@ func TestGetEndpointWithClientCertSource(t *testing.T) {
}

for _, tc := range testCases {
got, err := getEndpoint(&internal.DialSettings{
got, err := getEndpoint(&DialSettings{
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
Expand Down
File renamed without changes.
3 changes: 1 addition & 2 deletions transport/grpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"golang.org/x/oauth2"
"google.golang.org/api/internal"
"google.golang.org/api/option"
"google.golang.org/api/transport/internal/dca"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
grpcgoogle "google.golang.org/grpc/credentials/google"
Expand Down Expand Up @@ -123,7 +122,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
if o.GRPCConn != nil {
return o.GRPCConn, nil
}
clientCertSource, endpoint, err := dca.GetClientCertificateSourceAndEndpoint(o)
clientCertSource, endpoint, err := internal.GetClientCertificateSourceAndEndpoint(o)
if err != nil {
return nil, err
}
Expand Down
5 changes: 2 additions & 3 deletions transport/http/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ import (
"golang.org/x/oauth2"
"google.golang.org/api/googleapi/transport"
"google.golang.org/api/internal"
"google.golang.org/api/internal/cert"
"google.golang.org/api/option"
"google.golang.org/api/transport/cert"
"google.golang.org/api/transport/http/internal/propagation"
"google.golang.org/api/transport/internal/dca"
)

// NewClient returns an HTTP client for use communicating with a Google cloud
Expand All @@ -34,7 +33,7 @@ func NewClient(ctx context.Context, opts ...option.ClientOption) (*http.Client,
if err != nil {
return nil, "", err
}
clientCertSource, endpoint, err := dca.GetClientCertificateSourceAndEndpoint(settings)
clientCertSource, endpoint, err := internal.GetClientCertificateSourceAndEndpoint(settings)
if err != nil {
return nil, "", err
}
Expand Down

0 comments on commit 225fa6b

Please sign in to comment.