balancer: support injection of per-call metadata from LB policies #5853

Merged (5 commits, Dec 20, 2022)
Changes from 3 commits
8 changes: 8 additions & 0 deletions balancer/balancer.go
@@ -279,6 +279,14 @@ type PickResult struct {
// type, Done may not be called. May be nil if the balancer does not wish
// to be notified when the RPC completes.
Done func(DoneInfo)

// Metadata provides a way for LB policies to inject arbitrary per-call
// metadata. Any metadata returned here will be merged with existing
// metadata added by the client application.
//
// LB policies with child policies are responsible for propagating metadata
// injected by their children to the ClientConn, as part of Pick().
Metadata metadata.MD
}

// TransientFailureError returns e. It exists for backward compatibility and
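For illustration, the smallest possible use of the new PickResult.Metadata field is a picker that attaches a fixed header to every pick. The sketch below is not part of this change; the type name and header are invented, and it assumes a READY SubConn has already been chosen by the policy.

package example // hypothetical package, for illustration only

import (
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/metadata"
)

// staticPicker always returns the same SubConn and injects one metadata
// entry that will be merged into the outgoing headers of each RPC.
type staticPicker struct {
    sc balancer.SubConn // a READY SubConn chosen by the LB policy
}

func (p *staticPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
    return balancer.PickResult{
        SubConn:  p.sc,
        Metadata: metadata.Pairs("x-example-lb-header", "example-value"),
    }, nil
}
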
2 changes: 1 addition & 1 deletion clientconn.go
@@ -934,7 +934,7 @@ func (cc *ClientConn) healthCheckConfig() *healthCheckConfig {
return cc.sc.healthCheckConfig
}

func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, func(balancer.DoneInfo), error) {
func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, balancer.PickResult, error) {
return cc.blockingpicker.pick(ctx, failfast, balancer.PickInfo{
Ctx: ctx,
FullMethodName: method,
28 changes: 17 additions & 11 deletions picker_wrapper.go
@@ -58,12 +58,18 @@ func (pw *pickerWrapper) updatePicker(p balancer.Picker) {
pw.mu.Unlock()
}

func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) func(balancer.DoneInfo) {
// doneChannelzWrapper performs the following:
// - increments the calls started channelz counter
// - wraps the done function in the passed in result to increment the calls
// failed or calls succeeded channelz counter before invoking the actual
// done function.
func doneChannelzWrapper(acw *acBalancerWrapper, result *balancer.PickResult) {
acw.mu.Lock()
ac := acw.ac
acw.mu.Unlock()
ac.incrCallsStarted()
return func(b balancer.DoneInfo) {
done := result.Done
result.Done = func(b balancer.DoneInfo) {
if b.Err != nil && b.Err != io.EOF {
ac.incrCallsFailed()
} else {
Expand All @@ -82,15 +88,15 @@ func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) f
// - the current picker returns other errors and failfast is false.
// - the subConn returned by the current picker is not READY
// When one of these situations happens, pick blocks until the picker gets updated.
func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, func(balancer.DoneInfo), error) {
func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, balancer.PickResult, error) {
var ch chan struct{}

var lastPickErr error
for {
pw.mu.Lock()
if pw.done {
pw.mu.Unlock()
return nil, nil, ErrClientConnClosing
return nil, balancer.PickResult{}, ErrClientConnClosing
}

if pw.picker == nil {
Expand All @@ -111,9 +117,9 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
switch ctx.Err() {
case context.DeadlineExceeded:
return nil, nil, status.Error(codes.DeadlineExceeded, errStr)
return nil, balancer.PickResult{}, status.Error(codes.DeadlineExceeded, errStr)
case context.Canceled:
return nil, nil, status.Error(codes.Canceled, errStr)
return nil, balancer.PickResult{}, status.Error(codes.Canceled, errStr)
}
case <-ch:
}
Expand All @@ -125,7 +131,6 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
pw.mu.Unlock()

pickResult, err := p.Pick(info)

if err != nil {
if err == balancer.ErrNoSubConnAvailable {
continue
Expand All @@ -136,15 +141,15 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
if istatus.IsRestrictedControlPlaneCode(st) {
err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err)
}
return nil, nil, dropError{error: err}
return nil, balancer.PickResult{}, dropError{error: err}
}
// For all other errors, wait for ready RPCs should block and other
// RPCs should fail with unavailable.
if !failfast {
lastPickErr = err
continue
}
return nil, nil, status.Error(codes.Unavailable, err.Error())
return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error())
}

acw, ok := pickResult.SubConn.(*acBalancerWrapper)
Expand All @@ -154,9 +159,10 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
if t := acw.getAddrConn().getReadyTransport(); t != nil {
if channelz.IsOn() {
return t, doneChannelzWrapper(acw, pickResult.Done), nil
doneChannelzWrapper(acw, &pickResult)
return t, pickResult, nil
}
return t, pickResult.Done, nil
return t, pickResult, nil
}
if pickResult.Done != nil {
// Calling done with nil error, no bytes sent and no bytes received.
36 changes: 26 additions & 10 deletions stream.go
@@ -438,7 +438,7 @@ func (a *csAttempt) getTransport() error {
cs := a.cs

var err error
a.t, a.done, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
a.t, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
if err != nil {
if de, ok := err.(dropError); ok {
err = de.error
Expand All @@ -455,7 +455,23 @@ func (a *csAttempt) getTransport() error {
func (a *csAttempt) newStream() error {
cs := a.cs
cs.callHdr.PreviousAttempts = cs.numRetries
s, err := a.t.NewStream(a.ctx, cs.callHdr)

// Merge metadata stored in PickResult, if any, with existing call metadata.
ctx := a.ctx
if a.pickResult.Metadata != nil {
// We currently do not have a function in the metadata package which
// merges given metadata with existing metadata in a context. Existing
// function `AppendToOutgoingContext()` takes a variadic argument of key
// value pairs.
//
// TODO: Make it possible to retrieve key value pairs from metadata.MD
// in a form passable to AppendToOutgoingContext().
Review comment (Member): Maybe add to this sentence: "..., or create a version of AppendToOutgoingContext that accepts a metadata.MD"?

Reply (Contributor Author): Done.

md, _ := metadata.FromOutgoingContext(ctx)
md = metadata.Join(md, a.pickResult.Metadata)
ctx = metadata.NewOutgoingContext(a.ctx, md)
Review comment (Member): Should this not update a.ctx in place instead? I'm a bit worried if we ever use it again later, it will be the wrong one.

Reply (Contributor Author):
  • clientStream.ctx is the RPC's context (from what I can see, it is used directly and is not itself a derived context).
  • csAttempt.ctx is derived from the above context, and is used in the following places:
    • getTransport(): it is passed to Pick().
    • When sending/receiving messages, and when the attempt is closed, the context is passed to stats handler methods.

I feel that by not modifying the csAttempt's context, and instead deriving a new one (with metadata added by the LB policy) to be passed to NewStream(), we can ensure that we don't have any surprises when the attempt is retried. If we modify the csAttempt's context and the attempt is retried, and we call Pick() with the modified context, the LB policy's metadata might show up multiple times in there.

Reply (Member): Each attempt creates a new csAttempt, so that should not be a concern. Technically the attempt is not what's retried, but the RPC (clientStream).

Reply (Contributor Author):
> Each attempt creates a new csAttempt, so that should not be a concern.

Oh yes, that's right.

> I'm a bit worried if we ever use it again later, it will be the wrong one.

When will we ever use it again? From what I can see, the context is used in getTransport() and when invoking stats handler methods. Even if we add a few more usages of this context, why would it be helpful to pass the extra metadata added by the LB policy?

I'm not opposed to updating the attempt's context in place with metadata from the LB policy. But I thought it was safer to derive a new one from it, pass that to NewStream(), and continue using the existing one for everything else, since that minimizes the effect of this change.

Reply (Member): I'm not sure we ever will use it again. But if ClientStream.Context() doesn't contain the true outgoing metadata, then that could potentially be considered a bug.

Reply (Contributor Author): clientStream.Context() returns the context from the underlying transport stream:

    if cs.attempt.s != nil {

And we use the modified context (with all outgoing metadata) when creating the transport stream. So it looks like we are good here too.

Reply (Member): Honestly, it still feels wrong not to update the attempt's context when the metadata is updated by the LB policy. Unless there's a strong reason not to, I'd rather do it and be safe rather than sorry. The concern about it affecting subsequent attempts is a misunderstanding of the design of attempts (everything in an attempt should be local to that one attempt), and not a good reason to avoid updating in place.

Reply (Contributor Author): Done.

}

s, err := a.t.NewStream(ctx, cs.callHdr)
if err != nil {
nse, ok := err.(*transport.NewStreamError)
if !ok {
@@ -529,12 +545,12 @@ type clientStream struct {
// csAttempt implements a single transport stream attempt within a
// clientStream.
type csAttempt struct {
ctx context.Context
cs *clientStream
t transport.ClientTransport
s *transport.Stream
p *parser
done func(balancer.DoneInfo)
ctx context.Context
cs *clientStream
t transport.ClientTransport
s *transport.Stream
p *parser
pickResult balancer.PickResult

finished bool
dc Decompressor
@@ -1103,12 +1119,12 @@ func (a *csAttempt) finish(err error) {
tr = a.s.Trailer()
}

if a.done != nil {
if a.pickResult.Done != nil {
br := false
if a.s != nil {
br = a.s.BytesReceived()
}
a.done(balancer.DoneInfo{
a.pickResult.Done(balancer.DoneInfo{
Err: err,
Trailer: tr,
BytesSent: a.s != nil,
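For reference, the merge that newStream performs above can be reproduced in isolation with the existing metadata package API. This is a standalone sketch, not code from this PR, and the keys and values are invented.

package main // standalone sketch, for illustration only

import (
    "context"
    "fmt"

    "google.golang.org/grpc/metadata"
)

func main() {
    // Metadata added by the client application, as it appears on the RPC context.
    ctx := metadata.AppendToOutgoingContext(context.Background(), "app-key", "app-val")

    // Metadata injected by an LB policy via PickResult.Metadata.
    lbMD := metadata.Pairs("lb-key", "lb-val")

    // The same three steps as newStream: read the outgoing metadata, join it
    // with the LB-provided metadata, and derive a new context for NewStream.
    md, _ := metadata.FromOutgoingContext(ctx)
    md = metadata.Join(md, lbMD)
    ctx = metadata.NewOutgoingContext(ctx, md)

    merged, _ := metadata.FromOutgoingContext(ctx)
    fmt.Println(merged) // contains both the app-key and lb-key entries
}
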
136 changes: 136 additions & 0 deletions test/balancer_test.go
@@ -866,3 +866,139 @@ func (s) TestAuthorityInBuildOptions(t *testing.T) {
})
}
}

// wrappedPickFirstBalancerBuilder builds a custom balancer which wraps an
// underlying pick_first balancer.
type wrappedPickFirstBalancerBuilder struct {
name string
}

func (*wrappedPickFirstBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
builder := balancer.Get(grpc.PickFirstBalancerName)
wpfb := &wrappedPickFirstBalancer{
ClientConn: cc,
}
pf := builder.Build(wpfb, opts)
wpfb.Balancer = pf
return wpfb
}

func (wbb *wrappedPickFirstBalancerBuilder) Name() string {
return wbb.name
}

// wrappedPickFirstBalancer contains a pick_first balancer and forwards all
// calls from the ClientConn to it. For state updates from the pick_first
// balancer, it creates a custom picker which injects arbitrary metadata on a
// per-call basis.
type wrappedPickFirstBalancer struct {
balancer.Balancer
balancer.ClientConn
}

func (wb *wrappedPickFirstBalancer) UpdateState(state balancer.State) {
state.Picker = &wrappedPicker{p: state.Picker}
wb.ClientConn.UpdateState(state)
}

const (
metadataHeaderInjectedByBalancer = "metadata-header-injected-by-balancer"
metadataHeaderInjectedByApplication = "metadata-header-injected-by-application"
metadataValueInjectedByBalancer = "metadata-value-injected-by-balancer"
metadataValueInjectedByApplication = "metadata-value-injected-by-application"
)

// wrappedPicker wraps the picker returned by the pick_first balancer.
type wrappedPicker struct {
p balancer.Picker
}

func (wp *wrappedPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
res, err := wp.p.Pick(info)
if err != nil {
return balancer.PickResult{}, err
}

if res.Metadata == nil {
res.Metadata = metadata.Pairs(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
} else {
res.Metadata.Append(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
}
return res, nil
}

// TestMetadataInPickResult tests the scenario where an LB policy injects
// arbitrary metadata on a per-call basis and verifies that the injected
// metadata makes it all the way to the server RPC handler.
func (s) TestMetadataInPickResult(t *testing.T) {
t.Log("Starting test backend...")
mdChan := make(chan metadata.MD, 1)
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
md, _ := metadata.FromIncomingContext(ctx)
select {
case mdChan <- md:
case <-ctx.Done():
return nil, ctx.Err()
}
return &testpb.Empty{}, nil
},
}
if err := ss.StartServer(); err != nil {
t.Fatalf("Starting test backend: %v", err)
}
defer ss.Stop()
t.Logf("Started test backend at %q", ss.Address)

name := t.Name() + "wrappedPickFirstBalancer"
t.Logf("Registering test balancer with name %q...", name)
b := &wrappedPickFirstBalancerBuilder{name: t.Name() + "wrappedPickFirstBalancer"}
balancer.Register(b)

t.Log("Creating ClientConn to test backend...")
r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address}}})
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, b.Name())),
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("grpc.Dial(): %v", err)
}
defer cc.Close()
tc := testpb.NewTestServiceClient(cc)

t.Log("Making EmptyCall() RPC with custom metadata...")
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
md := metadata.Pairs(metadataHeaderInjectedByApplication, metadataValueInjectedByApplication)
ctx = metadata.NewOutgoingContext(ctx, md)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() RPC: %v", err)
}
t.Log("EmptyCall() RPC succeeded")

t.Log("Waiting for custom metadata to be received at the test backend...")
var gotMD metadata.MD
select {
case gotMD = <-mdChan:
case <-ctx.Done():
t.Fatalf("Timed out waiting for custom metadata to be received at the test backend")
}

t.Log("Verifying custom metadata added by the client application is received at the test backend...")
wantMDVal := []string{metadataValueInjectedByApplication}
gotMDVal := gotMD.Get(metadataHeaderInjectedByApplication)
if !cmp.Equal(gotMDVal, wantMDVal) {
t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
}

t.Log("Verifying custom metadata added by the LB policy is received at the test backend...")
wantMDVal = []string{metadataValueInjectedByBalancer}
gotMDVal = gotMD.Get(metadataHeaderInjectedByBalancer)
if !cmp.Equal(gotMDVal, wantMDVal) {
t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
}
}
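
Finally, the doc comment on PickResult.Metadata makes LB policies with child policies responsible for propagating metadata injected by their children. A hedged sketch of what that could look like in a parent policy's picker follows; the type name and header are invented, and this is not code from this PR.

package example // hypothetical package, for illustration only

import (
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/metadata"
)

// parentPicker delegates the pick to a child policy's picker and returns the
// child's PickResult, which already carries any metadata the child injected.
// metadata.Join is used so the parent's own entries do not drop the child's.
type parentPicker struct {
    childPicker balancer.Picker
}

func (p *parentPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
    res, err := p.childPicker.Pick(info)
    if err != nil {
        return balancer.PickResult{}, err
    }
    res.Metadata = metadata.Join(res.Metadata, metadata.Pairs("x-parent-header", "parent-value"))
    return res, nil
}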