balancer: support injection of per-call metadata from LB policies #5853

Merged
merged 5 commits into from Dec 20, 2022
Changes from 4 commits
8 changes: 8 additions & 0 deletions balancer/balancer.go
@@ -279,6 +279,14 @@ type PickResult struct {
// type, Done may not be called. May be nil if the balancer does not wish
// to be notified when the RPC completes.
Done func(DoneInfo)

// Metadata provides a way for LB policies to inject arbitrary per-call
// metadata. Any metadata returned here will be merged with existing
// metadata added by the client application.
//
// LB policies with child policies are responsible for propagating metadata
// injected by their children to the ClientConn, as part of Pick().
Metadata metadata.MD
}

// TransientFailureError returns e. It exists for backward compatibility and
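For illustration only, and not part of this diff: a minimal sketch of a picker that populates the new field. The type name examplePicker and the header key lb-token are hypothetical; values injected this way are merged with any metadata the application already set on the call.

package example

import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/metadata"
)

// examplePicker is a hypothetical picker that attaches a per-call header to
// every RPC routed through the SubConn it holds.
type examplePicker struct {
	sc balancer.SubConn
}

func (p *examplePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
	return balancer.PickResult{
		SubConn: p.sc,
		// Injected metadata is merged with metadata added by the application.
		Metadata: metadata.Pairs("lb-token", "token-from-balancer"),
	}, nil
}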
2 changes: 1 addition & 1 deletion clientconn.go
@@ -934,7 +934,7 @@ func (cc *ClientConn) healthCheckConfig() *healthCheckConfig {
return cc.sc.healthCheckConfig
}

func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, func(balancer.DoneInfo), error) {
func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, balancer.PickResult, error) {
return cc.blockingpicker.pick(ctx, failfast, balancer.PickInfo{
Ctx: ctx,
FullMethodName: method,
28 changes: 17 additions & 11 deletions picker_wrapper.go
@@ -58,12 +58,18 @@ func (pw *pickerWrapper) updatePicker(p balancer.Picker) {
pw.mu.Unlock()
}

func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) func(balancer.DoneInfo) {
// doneChannelzWrapper performs the following:
// - increments the calls started channelz counter
// - wraps the done function in the passed-in result to increment the calls
// failed or calls succeeded channelz counter before invoking the actual
// done function.
func doneChannelzWrapper(acw *acBalancerWrapper, result *balancer.PickResult) {
acw.mu.Lock()
ac := acw.ac
acw.mu.Unlock()
ac.incrCallsStarted()
return func(b balancer.DoneInfo) {
done := result.Done
result.Done = func(b balancer.DoneInfo) {
if b.Err != nil && b.Err != io.EOF {
ac.incrCallsFailed()
} else {
@@ -82,15 +88,15 @@ func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) f
// - the current picker returns other errors and failfast is false.
// - the subConn returned by the current picker is not READY
// When one of these situations happens, pick blocks until the picker gets updated.
func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, func(balancer.DoneInfo), error) {
func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, balancer.PickResult, error) {
var ch chan struct{}

var lastPickErr error
for {
pw.mu.Lock()
if pw.done {
pw.mu.Unlock()
return nil, nil, ErrClientConnClosing
return nil, balancer.PickResult{}, ErrClientConnClosing
}

if pw.picker == nil {
@@ -111,9 +117,9 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
switch ctx.Err() {
case context.DeadlineExceeded:
return nil, nil, status.Error(codes.DeadlineExceeded, errStr)
return nil, balancer.PickResult{}, status.Error(codes.DeadlineExceeded, errStr)
case context.Canceled:
return nil, nil, status.Error(codes.Canceled, errStr)
return nil, balancer.PickResult{}, status.Error(codes.Canceled, errStr)
}
case <-ch:
}
@@ -125,7 +131,6 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
pw.mu.Unlock()

pickResult, err := p.Pick(info)

if err != nil {
if err == balancer.ErrNoSubConnAvailable {
continue
@@ -136,15 +141,15 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
if istatus.IsRestrictedControlPlaneCode(st) {
err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err)
}
return nil, nil, dropError{error: err}
return nil, balancer.PickResult{}, dropError{error: err}
}
// For all other errors, wait for ready RPCs should block and other
// RPCs should fail with unavailable.
if !failfast {
lastPickErr = err
continue
}
return nil, nil, status.Error(codes.Unavailable, err.Error())
return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error())
}

acw, ok := pickResult.SubConn.(*acBalancerWrapper)
@@ -154,9 +159,10 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
if t := acw.getAddrConn().getReadyTransport(); t != nil {
if channelz.IsOn() {
return t, doneChannelzWrapper(acw, pickResult.Done), nil
doneChannelzWrapper(acw, &pickResult)
return t, pickResult, nil
}
return t, pickResult.Done, nil
return t, pickResult, nil
}
if pickResult.Done != nil {
// Calling done with nil error, no bytes sent and no bytes received.
36 changes: 27 additions & 9 deletions stream.go
@@ -438,7 +438,7 @@ func (a *csAttempt) getTransport() error {
cs := a.cs

var err error
a.t, a.done, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
a.t, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
if err != nil {
if de, ok := err.(dropError); ok {
err = de.error
@@ -455,6 +455,24 @@ func (a *csAttempt) getTransport() error {
func (a *csAttempt) newStream() error {
cs := a.cs
cs.callHdr.PreviousAttempts = cs.numRetries

// Merge metadata stored in PickResult, if any, with existing call metadata.
// It is safe to overwrite the csAttempt's context here, since all state
// maintained in it is local to the attempt. When the attempt has to be
// retried, a new instance of csAttempt will be created.
if a.pickResult.Metadata != nil {
// We currently do not have a function in the metadata package which
// merges given metadata with existing metadata in a context. Existing
// function `AppendToOutgoingContext()` takes a variadic argument of key
// value pairs.
//
// TODO: Make it possible to retrieve key value pairs from metadata.MD
// in a form passable to AppendToOutgoingContext(), or create a version
// of AppendToOutgoingContext() that accepts a metadata.MD.
Member:
Maybe add to this sentence: ..., or create a version of AppendToOutgoingContext that accepts a metadata.MD. ?

Contributor Author:
Done.

md, _ := metadata.FromOutgoingContext(a.ctx)
md = metadata.Join(md, a.pickResult.Metadata)
a.ctx = metadata.NewOutgoingContext(a.ctx, md)
}

s, err := a.t.NewStream(a.ctx, cs.callHdr)
if err != nil {
nse, ok := err.(*transport.NewStreamError)
@@ -529,12 +547,12 @@ type clientStream struct {
// csAttempt implements a single transport stream attempt within a
// clientStream.
type csAttempt struct {
ctx context.Context
cs *clientStream
t transport.ClientTransport
s *transport.Stream
p *parser
done func(balancer.DoneInfo)
ctx context.Context
cs *clientStream
t transport.ClientTransport
s *transport.Stream
p *parser
pickResult balancer.PickResult

finished bool
dc Decompressor
@@ -1103,12 +1121,12 @@ func (a *csAttempt) finish(err error) {
tr = a.s.Trailer()
}

if a.done != nil {
if a.pickResult.Done != nil {
br := false
if a.s != nil {
br = a.s.BytesReceived()
}
a.done(balancer.DoneInfo{
a.pickResult.Done(balancer.DoneInfo{
Err: err,
Trailer: tr,
BytesSent: a.s != nil,
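For illustration only, and not part of this diff: a standalone sketch of the merge that newStream performs above. It shows that metadata.Join keeps both the application-supplied value and the balancer-injected value for the same key; the key name lb-token is hypothetical.

package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/metadata"
)

func main() {
	// Metadata set by the client application on the outgoing context.
	ctx := metadata.NewOutgoingContext(context.Background(),
		metadata.Pairs("lb-token", "from-application"))

	// Hypothetical metadata injected by an LB policy via PickResult.Metadata.
	injected := metadata.Pairs("lb-token", "from-balancer")

	// The same merge newStream performs: read the outgoing metadata, join the
	// injected metadata, and write the result back to the context.
	md, _ := metadata.FromOutgoingContext(ctx)
	md = metadata.Join(md, injected)
	ctx = metadata.NewOutgoingContext(ctx, md)

	out, _ := metadata.FromOutgoingContext(ctx)
	fmt.Println(out.Get("lb-token")) // [from-application from-balancer]
}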
136 changes: 136 additions & 0 deletions test/balancer_test.go
@@ -866,3 +866,139 @@ func (s) TestAuthorityInBuildOptions(t *testing.T) {
})
}
}

// wrappedPickFirstBalancerBuilder builds a custom balancer which wraps an
// underlying pick_first balancer.
type wrappedPickFirstBalancerBuilder struct {
name string
}

func (*wrappedPickFirstBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
builder := balancer.Get(grpc.PickFirstBalancerName)
wpfb := &wrappedPickFirstBalancer{
ClientConn: cc,
}
pf := builder.Build(wpfb, opts)
wpfb.Balancer = pf
return wpfb
}

func (wbb *wrappedPickFirstBalancerBuilder) Name() string {
return wbb.name
}

// wrappedPickFirstBalancer contains a pick_first balancer and forwards all
// calls from the ClientConn to it. For state updates from the pick_first
// balancer, it creates a custom picker which injects arbitrary metadata on a
// per-call basis.
type wrappedPickFirstBalancer struct {
balancer.Balancer
balancer.ClientConn
}

func (wb *wrappedPickFirstBalancer) UpdateState(state balancer.State) {
state.Picker = &wrappedPicker{p: state.Picker}
wb.ClientConn.UpdateState(state)
}

const (
metadataHeaderInjectedByBalancer = "metadata-header-injected-by-balancer"
metadataHeaderInjectedByApplication = "metadata-header-injected-by-application"
metadataValueInjectedByBalancer = "metadata-value-injected-by-balancer"
metadataValueInjectedByApplication = "metadata-value-injected-by-application"
)

// wrappedPicker wraps the picker returned by the pick_first balancer.
type wrappedPicker struct {
p balancer.Picker
}

func (wp *wrappedPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
res, err := wp.p.Pick(info)
if err != nil {
return balancer.PickResult{}, err
}

if res.Metadata == nil {
res.Metadata = metadata.Pairs(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
} else {
res.Metadata.Append(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
}
return res, nil
}

// TestMetadataInPickResult tests the scenario where an LB policy injects
// arbitrary metadata on a per-call basis and verifies that the injected
// metadata makes it all the way to the server RPC handler.
func (s) TestMetadataInPickResult(t *testing.T) {
t.Log("Starting test backend...")
mdChan := make(chan metadata.MD, 1)
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
md, _ := metadata.FromIncomingContext(ctx)
select {
case mdChan <- md:
case <-ctx.Done():
return nil, ctx.Err()
}
return &testpb.Empty{}, nil
},
}
if err := ss.StartServer(); err != nil {
t.Fatalf("Starting test backend: %v", err)
}
defer ss.Stop()
t.Logf("Started test backend at %q", ss.Address)

name := t.Name() + "wrappedPickFirstBalancer"
t.Logf("Registering test balancer with name %q...", name)
b := &wrappedPickFirstBalancerBuilder{name: name}
balancer.Register(b)

t.Log("Creating ClientConn to test backend...")
r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address}}})
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, b.Name())),
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("grpc.Dial(): %v", err)
}
defer cc.Close()
tc := testpb.NewTestServiceClient(cc)

t.Log("Making EmptyCall() RPC with custom metadata...")
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
md := metadata.Pairs(metadataHeaderInjectedByApplication, metadataValueInjectedByApplication)
ctx = metadata.NewOutgoingContext(ctx, md)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() RPC: %v", err)
}
t.Log("EmptyCall() RPC succeeded")

t.Log("Waiting for custom metadata to be received at the test backend...")
var gotMD metadata.MD
select {
case gotMD = <-mdChan:
case <-ctx.Done():
t.Fatalf("Timed out waiting for custom metadata to be received at the test backend")
}

t.Log("Verifying custom metadata added by the client application is received at the test backend...")
wantMDVal := []string{metadataValueInjectedByApplication}
gotMDVal := gotMD.Get(metadataHeaderInjectedByApplication)
if !cmp.Equal(gotMDVal, wantMDVal) {
t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
}

t.Log("Verifying custom metadata added by the LB policy is received at the test backend...")
wantMDVal = []string{metadataValueInjectedByBalancer}
gotMDVal = gotMD.Get(metadataHeaderInjectedByBalancer)
if !cmp.Equal(gotMDVal, wantMDVal) {
t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
}
}