Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

internal: Refactor cert logic to support OAuth2 token exchange over mTLS #1886

Merged
merged 15 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

go run ../internal/ecp/test_signer.go testdata/rsa2048bit.pem
go run ../ecp/test_signer.go testdata/rsa2048bit.pem
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

go run ../internal/ecp/test_signer.go testdata/invalid.pem
go run ../ecp/test_signer.go testdata/invalid.pem
55 changes: 54 additions & 1 deletion internal/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ package internal

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"time"

"golang.org/x/oauth2"
"google.golang.org/api/internal/impersonate"
Expand Down Expand Up @@ -80,8 +84,25 @@ const (
// - Otherwise, executes standard OAuth 2.0 flow
// More details: google.aip.dev/auth/4111
func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*google.Credentials, error) {
var params google.CredentialsParams
params.Scopes = ds.GetScopes()

// Determine configurations for the OAuth2 transport, which is separate from the API transport.
// The OAuth2 transport and endpoint will be configured for mTLS if applicable.
clientCertSource, oauth2Endpoint, err := GetClientCertificateSourceAndEndpoint(oauth2DialSettings(ds))
if err != nil {
return nil, err
}
params.TokenURL = oauth2Endpoint
if clientCertSource != nil {
tlsConfig := &tls.Config{
GetClientCertificate: clientCertSource,
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, customHTTPClient(tlsConfig))
}

// By default, a standard OAuth 2.0 token source is created
cred, err := google.CredentialsFromJSON(ctx, data, ds.GetScopes()...)
cred, err := google.CredentialsFromJSONWithParams(ctx, data, params)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -157,3 +178,35 @@ func impersonateCredentials(ctx context.Context, creds *google.Credentials, ds *
ProjectID: creds.ProjectID,
}, nil
}

// oauth2DialSettings returns the settings to be used by the OAuth2 transport, which is separate from the API transport.
func oauth2DialSettings(ds *DialSettings) *DialSettings {
var ods DialSettings
ods.DefaultEndpoint = google.Endpoint.TokenURL
ods.DefaultMTLSEndpoint = google.MTLSTokenURL
ods.ClientCertSource = ds.ClientCertSource
return &ods
}

// customHTTPClient constructs an HTTPClient using the provided tlsConfig, to support mTLS.
func customHTTPClient(tlsConfig *tls.Config) *http.Client {
trans := baseTransport()
trans.TLSClientConfig = tlsConfig
return &http.Client{Transport: trans}
}

func baseTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
13 changes: 7 additions & 6 deletions transport/internal/dca/dca.go → internal/dca.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
//
// This package is not intended for use by end developers. Use the
// google.golang.org/api/option package to configure API clients.
package dca

// Package internal supports the options and transport packages.
package internal

import (
"net/url"
"os"
"strings"

"google.golang.org/api/internal"
"google.golang.org/api/transport/cert"
"google.golang.org/api/internal/cert"
)

const (
Expand All @@ -43,7 +44,7 @@ const (
// GetClientCertificateSourceAndEndpoint is a convenience function that invokes
// getClientCertificateSource and getEndpoint sequentially and returns the client
// cert source and endpoint as a tuple.
func GetClientCertificateSourceAndEndpoint(settings *internal.DialSettings) (cert.Source, string, error) {
func GetClientCertificateSourceAndEndpoint(settings *DialSettings) (cert.Source, string, error) {
clientCertSource, err := getClientCertificateSource(settings)
if err != nil {
return nil, "", err
Expand All @@ -65,7 +66,7 @@ func GetClientCertificateSourceAndEndpoint(settings *internal.DialSettings) (cer
// Important Note: For now, the environment variable GOOGLE_API_USE_CLIENT_CERTIFICATE
// must be set to "true" to allow certificate to be used (including user provided
// certificates). For details, see AIP-4114.
func getClientCertificateSource(settings *internal.DialSettings) (cert.Source, error) {
func getClientCertificateSource(settings *DialSettings) (cert.Source, error) {
if !isClientCertificateEnabled() {
return nil, nil
} else if settings.ClientCertSource != nil {
Expand Down Expand Up @@ -94,7 +95,7 @@ func isClientCertificateEnabled() bool {
// URL (ex. https://...), then the user-provided address will be merged into
// the default endpoint. For example, WithEndpoint("myhost:8000") and
// WithDefaultEndpoint("https://foo.com/bar/baz") will return "https://myhost:8080/bar/baz"
func getEndpoint(settings *internal.DialSettings, clientCertSource cert.Source) (string, error) {
func getEndpoint(settings *DialSettings, clientCertSource cert.Source) (string, error) {
if settings.Endpoint == "" {
mtlsMode := getMTLSMode()
if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
Expand Down
8 changes: 3 additions & 5 deletions transport/internal/dca/dca_test.go → internal/dca_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package dca
package internal

import (
"testing"

"crypto/tls"

"google.golang.org/api/internal"
)

func TestGetEndpoint(t *testing.T) {
Expand Down Expand Up @@ -51,7 +49,7 @@ func TestGetEndpoint(t *testing.T) {
}

for _, tc := range testCases {
got, err := getEndpoint(&internal.DialSettings{
got, err := getEndpoint(&DialSettings{
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
}, nil)
Expand Down Expand Up @@ -106,7 +104,7 @@ func TestGetEndpointWithClientCertSource(t *testing.T) {
}

for _, tc := range testCases {
got, err := getEndpoint(&internal.DialSettings{
got, err := getEndpoint(&DialSettings{
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
Expand Down
File renamed without changes.
3 changes: 1 addition & 2 deletions transport/grpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"golang.org/x/oauth2"
"google.golang.org/api/internal"
"google.golang.org/api/option"
"google.golang.org/api/transport/internal/dca"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
grpcgoogle "google.golang.org/grpc/credentials/google"
Expand Down Expand Up @@ -123,7 +122,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
if o.GRPCConn != nil {
return o.GRPCConn, nil
}
clientCertSource, endpoint, err := dca.GetClientCertificateSourceAndEndpoint(o)
clientCertSource, endpoint, err := internal.GetClientCertificateSourceAndEndpoint(o)
if err != nil {
return nil, err
}
Expand Down
5 changes: 2 additions & 3 deletions transport/http/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ import (
"golang.org/x/oauth2"
"google.golang.org/api/googleapi/transport"
"google.golang.org/api/internal"
"google.golang.org/api/internal/cert"
"google.golang.org/api/option"
"google.golang.org/api/transport/cert"
"google.golang.org/api/transport/http/internal/propagation"
"google.golang.org/api/transport/internal/dca"
)

// NewClient returns an HTTP client for use communicating with a Google cloud
Expand All @@ -34,7 +33,7 @@ func NewClient(ctx context.Context, opts ...option.ClientOption) (*http.Client,
if err != nil {
return nil, "", err
}
clientCertSource, endpoint, err := dca.GetClientCertificateSourceAndEndpoint(settings)
clientCertSource, endpoint, err := internal.GetClientCertificateSourceAndEndpoint(settings)
if err != nil {
return nil, "", err
}
Expand Down