Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

credentials/alts: defer ALTS stream creation until handshake time #6077

Merged
50 changes: 33 additions & 17 deletions credentials/alts/internal/handshaker/handshaker.go
Expand Up @@ -138,14 +138,16 @@ func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
// and server options (server options struct does not exist now. When
// caller can provide endpoints, it should be created.

// altsHandshaker is used to complete a ALTS handshaking between client and
// altsHandshaker is used to complete an ALTS handshake between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server.
type altsHandshaker struct {
// RPC stream used to access the ALTS Handshaker service.
stream altsgrpc.HandshakerService_DoHandshakeClient
// the connection to the peer.
conn net.Conn
// a virtual connection to the ALTS handshaker service.
clientConn *grpc.ClientConn
// client handshake options.
clientOpts *ClientHandshakerOptions
// server handshake options.
Expand All @@ -154,39 +156,33 @@ type altsHandshaker struct {
side core.Side
}

// NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// NewClientHandshaker creates a core.Handshaker that performs a client-side
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
// service in the metadata server.
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
stream: nil,
conn: c,
clientConn: conn,
clientOpts: opts,
side: core.ClientSide,
}, nil
}

// NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// NewServerHandshaker creates a core.Handshaker that performs a server-side
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
// service in the metadata server.
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
stream: nil,
conn: c,
clientConn: conn,
serverOpts: opts,
side: core.ServerSide,
}, nil
}

// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire() {
Expand All @@ -198,6 +194,16 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
}

// TODO(matthewstevenson88): Change unit tests to use public APIs so
// that h.stream can unconditionally be set based on h.clientConn.
if h.stream == nil {
matthewstevenson88 marked this conversation as resolved.
Show resolved Hide resolved
easwars marked this conversation as resolved.
Show resolved Hide resolved
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
}
h.stream = stream
}

// Create target identities from service account list.
targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
for _, account := range h.clientOpts.TargetServiceAccounts {
Expand Down Expand Up @@ -229,7 +235,7 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
return conn, authInfo, nil
}

// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire() {
Expand All @@ -241,6 +247,16 @@ func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credent
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
}

// TODO(matthewstevenson88): Change unit tests to use public APIs so
// that h.stream can unconditionally be set based on h.clientConn.
if h.stream == nil {
matthewstevenson88 marked this conversation as resolved.
Show resolved Hide resolved
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
}
h.stream = stream
}

p := make([]byte, frameLimit)
n, err := h.conn.Read(p)
if err != nil {
Expand Down
64 changes: 64 additions & 0 deletions credentials/alts/internal/handshaker/handshaker_test.go
Expand Up @@ -25,6 +25,8 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
grpc "google.golang.org/grpc"
core "google.golang.org/grpc/credentials/alts/internal"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
Expand Down Expand Up @@ -283,3 +285,65 @@ func (s) TestPeerNotResponding(t *testing.T) {
t.Errorf("ClientHandshake() = %v, want %v", got, want)
}
}

func (s) TestNewClientHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ClientHandshakerOptions{}
hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
if err != nil {
t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
}
expectedHs := &altsHandshaker{
stream: nil,
conn: conn,
clientConn: clientConn,
clientOpts: opts,
serverOpts: nil,
side: core.ClientSide,
}
cmpOpts := []cmp.Option{
cmp.AllowUnexported(altsHandshaker{}),
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
}
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
}
if hs.(*altsHandshaker).stream != nil {
t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream")
}
if hs.(*altsHandshaker).clientConn != clientConn {
t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn")
}
}

func (s) TestNewServerHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ServerHandshakerOptions{}
hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
if err != nil {
t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
}
expectedHs := &altsHandshaker{
stream: nil,
conn: conn,
clientConn: clientConn,
clientOpts: nil,
serverOpts: opts,
side: core.ServerSide,
}
cmpOpts := []cmp.Option{
cmp.AllowUnexported(altsHandshaker{}),
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
}
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
}
if hs.(*altsHandshaker).stream != nil {
t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream")
}
if hs.(*altsHandshaker).clientConn != clientConn {
t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn")
}
}