Skip to content

Commit

Permalink
balancer: support injection of per-call metadata from LB policies (gr…
Browse files Browse the repository at this point in the history
  • Loading branch information
easwars committed Dec 21, 2022
1 parent 95e55f9 commit e9886ef
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 21 deletions.
8 changes: 8 additions & 0 deletions balancer/balancer.go
Expand Up @@ -279,6 +279,14 @@ type PickResult struct {
// type, Done may not be called. May be nil if the balancer does not wish
// to be notified when the RPC completes.
Done func(DoneInfo)

// Metadata provides a way for LB policies to inject arbitrary per-call
// metadata. Any metadata returned here will be merged with existing
// metadata added by the client application.
//
// LB policies with child policies are responsible for propagating metadata
// injected by their children to the ClientConn, as part of Pick().
Metatada metadata.MD
}

// TransientFailureError returns e. It exists for backward compatibility and
Expand Down
2 changes: 1 addition & 1 deletion clientconn.go
Expand Up @@ -934,7 +934,7 @@ func (cc *ClientConn) healthCheckConfig() *healthCheckConfig {
return cc.sc.healthCheckConfig
}

func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, func(balancer.DoneInfo), error) {
func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, balancer.PickResult, error) {
return cc.blockingpicker.pick(ctx, failfast, balancer.PickInfo{
Ctx: ctx,
FullMethodName: method,
Expand Down
28 changes: 17 additions & 11 deletions picker_wrapper.go
Expand Up @@ -58,12 +58,18 @@ func (pw *pickerWrapper) updatePicker(p balancer.Picker) {
pw.mu.Unlock()
}

func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) func(balancer.DoneInfo) {
// doneChannelzWrapper performs the following:
// - increments the calls started channelz counter
// - wraps the done function in the passed in result to increment the calls
// failed or calls succeeded channelz counter before invoking the actual
// done function.
func doneChannelzWrapper(acw *acBalancerWrapper, result *balancer.PickResult) {
acw.mu.Lock()
ac := acw.ac
acw.mu.Unlock()
ac.incrCallsStarted()
return func(b balancer.DoneInfo) {
done := result.Done
result.Done = func(b balancer.DoneInfo) {
if b.Err != nil && b.Err != io.EOF {
ac.incrCallsFailed()
} else {
Expand All @@ -82,15 +88,15 @@ func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) f
// - the current picker returns other errors and failfast is false.
// - the subConn returned by the current picker is not READY
// When one of these situations happens, pick blocks until the picker gets updated.
func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, func(balancer.DoneInfo), error) {
func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, balancer.PickResult, error) {
var ch chan struct{}

var lastPickErr error
for {
pw.mu.Lock()
if pw.done {
pw.mu.Unlock()
return nil, nil, ErrClientConnClosing
return nil, balancer.PickResult{}, ErrClientConnClosing
}

if pw.picker == nil {
Expand All @@ -111,9 +117,9 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
switch ctx.Err() {
case context.DeadlineExceeded:
return nil, nil, status.Error(codes.DeadlineExceeded, errStr)
return nil, balancer.PickResult{}, status.Error(codes.DeadlineExceeded, errStr)
case context.Canceled:
return nil, nil, status.Error(codes.Canceled, errStr)
return nil, balancer.PickResult{}, status.Error(codes.Canceled, errStr)
}
case <-ch:
}
Expand All @@ -125,7 +131,6 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
pw.mu.Unlock()

pickResult, err := p.Pick(info)

if err != nil {
if err == balancer.ErrNoSubConnAvailable {
continue
Expand All @@ -136,15 +141,15 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
if istatus.IsRestrictedControlPlaneCode(st) {
err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err)
}
return nil, nil, dropError{error: err}
return nil, balancer.PickResult{}, dropError{error: err}
}
// For all other errors, wait for ready RPCs should block and other
// RPCs should fail with unavailable.
if !failfast {
lastPickErr = err
continue
}
return nil, nil, status.Error(codes.Unavailable, err.Error())
return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error())
}

acw, ok := pickResult.SubConn.(*acBalancerWrapper)
Expand All @@ -154,9 +159,10 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
if t := acw.getAddrConn().getReadyTransport(); t != nil {
if channelz.IsOn() {
return t, doneChannelzWrapper(acw, pickResult.Done), nil
doneChannelzWrapper(acw, &pickResult)
return t, pickResult, nil
}
return t, pickResult.Done, nil
return t, pickResult, nil
}
if pickResult.Done != nil {
// Calling done with nil error, no bytes sent and no bytes received.
Expand Down
37 changes: 28 additions & 9 deletions stream.go
Expand Up @@ -438,7 +438,7 @@ func (a *csAttempt) getTransport() error {
cs := a.cs

var err error
a.t, a.done, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
a.t, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
if err != nil {
if de, ok := err.(dropError); ok {
err = de.error
Expand All @@ -455,6 +455,25 @@ func (a *csAttempt) getTransport() error {
func (a *csAttempt) newStream() error {
cs := a.cs
cs.callHdr.PreviousAttempts = cs.numRetries

// Merge metadata stored in PickResult, if any, with existing call metadata.
// It is safe to overwrite the csAttempt's context here, since all state
// maintained in it are local to the attempt. When the attempt has to be
// retried, a new instance of csAttempt will be created.
if a.pickResult.Metatada != nil {
// We currently do not have a function it the metadata package which
// merges given metadata with existing metadata in a context. Existing
// function `AppendToOutgoingContext()` takes a variadic argument of key
// value pairs.
//
// TODO: Make it possible to retrieve key value pairs from metadata.MD
// in a form passable to AppendToOutgoingContext(), or create a version
// of AppendToOutgoingContext() that accepts a metadata.MD.
md, _ := metadata.FromOutgoingContext(a.ctx)
md = metadata.Join(md, a.pickResult.Metatada)
a.ctx = metadata.NewOutgoingContext(a.ctx, md)
}

s, err := a.t.NewStream(a.ctx, cs.callHdr)
if err != nil {
nse, ok := err.(*transport.NewStreamError)
Expand Down Expand Up @@ -529,12 +548,12 @@ type clientStream struct {
// csAttempt implements a single transport stream attempt within a
// clientStream.
type csAttempt struct {
ctx context.Context
cs *clientStream
t transport.ClientTransport
s *transport.Stream
p *parser
done func(balancer.DoneInfo)
ctx context.Context
cs *clientStream
t transport.ClientTransport
s *transport.Stream
p *parser
pickResult balancer.PickResult

finished bool
dc Decompressor
Expand Down Expand Up @@ -1103,12 +1122,12 @@ func (a *csAttempt) finish(err error) {
tr = a.s.Trailer()
}

if a.done != nil {
if a.pickResult.Done != nil {
br := false
if a.s != nil {
br = a.s.BytesReceived()
}
a.done(balancer.DoneInfo{
a.pickResult.Done(balancer.DoneInfo{
Err: err,
Trailer: tr,
BytesSent: a.s != nil,
Expand Down
136 changes: 136 additions & 0 deletions test/balancer_test.go
Expand Up @@ -866,3 +866,139 @@ func (s) TestAuthorityInBuildOptions(t *testing.T) {
})
}
}

// wrappedPickFirstBalancerBuilder builds a custom balancer which wraps an
// underlying pick_first balancer.
type wrappedPickFirstBalancerBuilder struct {
name string
}

func (*wrappedPickFirstBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
builder := balancer.Get(grpc.PickFirstBalancerName)
wpfb := &wrappedPickFirstBalancer{
ClientConn: cc,
}
pf := builder.Build(wpfb, opts)
wpfb.Balancer = pf
return wpfb
}

func (wbb *wrappedPickFirstBalancerBuilder) Name() string {
return wbb.name
}

// wrappedPickFirstBalancer contains a pick_first balancer and forwards all
// calls from the ClientConn to it. For state updates from the pick_first
// balancer, it creates a custom picker which injects arbitrary metadata on a
// per-call basis.
type wrappedPickFirstBalancer struct {
balancer.Balancer
balancer.ClientConn
}

func (wb *wrappedPickFirstBalancer) UpdateState(state balancer.State) {
state.Picker = &wrappedPicker{p: state.Picker}
wb.ClientConn.UpdateState(state)
}

const (
metadataHeaderInjectedByBalancer = "metadata-header-injected-by-balancer"
metadataHeaderInjectedByApplication = "metadata-header-injected-by-application"
metadataValueInjectedByBalancer = "metadata-value-injected-by-balancer"
metadataValueInjectedByApplication = "metadata-value-injected-by-application"
)

// wrappedPicker wraps the picker returned by the pick_first
type wrappedPicker struct {
p balancer.Picker
}

func (wp *wrappedPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
res, err := wp.p.Pick(info)
if err != nil {
return balancer.PickResult{}, err
}

if res.Metatada == nil {
res.Metatada = metadata.Pairs(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
} else {
res.Metatada.Append(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
}
return res, nil
}

// TestMetadataInPickResult tests the scenario where an LB policy inject
// arbitrary metadata on a per-call basis and verifies that the injected
// metadata makes it all the way to the server RPC handler.
func (s) TestMetadataInPickResult(t *testing.T) {
t.Log("Starting test backend...")
mdChan := make(chan metadata.MD, 1)
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
md, _ := metadata.FromIncomingContext(ctx)
select {
case mdChan <- md:
case <-ctx.Done():
return nil, ctx.Err()
}
return &testpb.Empty{}, nil
},
}
if err := ss.StartServer(); err != nil {
t.Fatalf("Starting test backend: %v", err)
}
defer ss.Stop()
t.Logf("Started test backend at %q", ss.Address)

name := t.Name() + "wrappedPickFirstBalancer"
t.Logf("Registering test balancer with name %q...", name)
b := &wrappedPickFirstBalancerBuilder{name: t.Name() + "wrappedPickFirstBalancer"}
balancer.Register(b)

t.Log("Creating ClientConn to test backend...")
r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address}}})
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, b.Name())),
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("grpc.Dial(): %v", err)
}
defer cc.Close()
tc := testpb.NewTestServiceClient(cc)

t.Log("Making EmptyCall() RPC with custom metadata...")
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
md := metadata.Pairs(metadataHeaderInjectedByApplication, metadataValueInjectedByApplication)
ctx = metadata.NewOutgoingContext(ctx, md)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() RPC: %v", err)
}
t.Log("EmptyCall() RPC succeeded")

t.Log("Waiting for custom metadata to be received at the test backend...")
var gotMD metadata.MD
select {
case gotMD = <-mdChan:
case <-ctx.Done():
t.Fatalf("Timed out waiting for custom metadata to be received at the test backend")
}

t.Log("Verifying custom metadata added by the client application is received at the test backend...")
wantMDVal := []string{metadataValueInjectedByApplication}
gotMDVal := gotMD.Get(metadataHeaderInjectedByApplication)
if !cmp.Equal(gotMDVal, wantMDVal) {
t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
}

t.Log("Verifying custom metadata added by the LB policy is received at the test backend...")
wantMDVal = []string{metadataValueInjectedByBalancer}
gotMDVal = gotMD.Get(metadataHeaderInjectedByBalancer)
if !cmp.Equal(gotMDVal, wantMDVal) {
t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
}
}

0 comments on commit e9886ef

Please sign in to comment.