diff --git a/balancer/balancer.go b/balancer/balancer.go index 392b21fb2d8e..09d61dd1b55b 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -279,6 +279,14 @@ type PickResult struct { // type, Done may not be called. May be nil if the balancer does not wish // to be notified when the RPC completes. Done func(DoneInfo) + + // Metadata provides a way for LB policies to inject arbitrary per-call + // metadata. Any metadata returned here will be merged with existing + // metadata added by the client application. + // + // LB policies with child policies are responsible for propagating metadata + // injected by their children to the ClientConn, as part of Pick(). + Metatada metadata.MD } // TransientFailureError returns e. It exists for backward compatibility and diff --git a/clientconn.go b/clientconn.go index 045668904519..8402c19e3e9f 100644 --- a/clientconn.go +++ b/clientconn.go @@ -934,7 +934,7 @@ func (cc *ClientConn) healthCheckConfig() *healthCheckConfig { return cc.sc.healthCheckConfig } -func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, func(balancer.DoneInfo), error) { +func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, balancer.PickResult, error) { return cc.blockingpicker.pick(ctx, failfast, balancer.PickInfo{ Ctx: ctx, FullMethodName: method, diff --git a/picker_wrapper.go b/picker_wrapper.go index a5d5516ee060..c525dc070fc6 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -58,12 +58,18 @@ func (pw *pickerWrapper) updatePicker(p balancer.Picker) { pw.mu.Unlock() } -func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) func(balancer.DoneInfo) { +// doneChannelzWrapper performs the following: +// - increments the calls started channelz counter +// - wraps the done function in the passed in result to increment the calls +// failed or calls succeeded channelz counter before invoking the 
actual +// done function. +func doneChannelzWrapper(acw *acBalancerWrapper, result *balancer.PickResult) { acw.mu.Lock() ac := acw.ac acw.mu.Unlock() ac.incrCallsStarted() - return func(b balancer.DoneInfo) { + done := result.Done + result.Done = func(b balancer.DoneInfo) { if b.Err != nil && b.Err != io.EOF { ac.incrCallsFailed() } else { @@ -82,7 +88,7 @@ func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) f // - the current picker returns other errors and failfast is false. // - the subConn returned by the current picker is not READY // When one of these situations happens, pick blocks until the picker gets updated. -func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, func(balancer.DoneInfo), error) { +func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, balancer.PickResult, error) { var ch chan struct{} var lastPickErr error @@ -90,7 +96,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. pw.mu.Lock() if pw.done { pw.mu.Unlock() - return nil, nil, ErrClientConnClosing + return nil, balancer.PickResult{}, ErrClientConnClosing } if pw.picker == nil { @@ -111,9 +117,9 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. } switch ctx.Err() { case context.DeadlineExceeded: - return nil, nil, status.Error(codes.DeadlineExceeded, errStr) + return nil, balancer.PickResult{}, status.Error(codes.DeadlineExceeded, errStr) case context.Canceled: - return nil, nil, status.Error(codes.Canceled, errStr) + return nil, balancer.PickResult{}, status.Error(codes.Canceled, errStr) } case <-ch: } @@ -125,7 +131,6 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. 
pw.mu.Unlock() pickResult, err := p.Pick(info) - if err != nil { if err == balancer.ErrNoSubConnAvailable { continue @@ -136,7 +141,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. if istatus.IsRestrictedControlPlaneCode(st) { err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err) } - return nil, nil, dropError{error: err} + return nil, balancer.PickResult{}, dropError{error: err} } // For all other errors, wait for ready RPCs should block and other // RPCs should fail with unavailable. @@ -144,7 +149,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. lastPickErr = err continue } - return nil, nil, status.Error(codes.Unavailable, err.Error()) + return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error()) } acw, ok := pickResult.SubConn.(*acBalancerWrapper) @@ -154,9 +159,10 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. } if t := acw.getAddrConn().getReadyTransport(); t != nil { if channelz.IsOn() { - return t, doneChannelzWrapper(acw, pickResult.Done), nil + doneChannelzWrapper(acw, &pickResult) + return t, pickResult, nil } - return t, pickResult.Done, nil + return t, pickResult, nil } if pickResult.Done != nil { // Calling done with nil error, no bytes sent and no bytes received. 
diff --git a/stream.go b/stream.go index 0f8e6c0149da..175aee9583ea 100644 --- a/stream.go +++ b/stream.go @@ -438,7 +438,7 @@ func (a *csAttempt) getTransport() error { cs := a.cs var err error - a.t, a.done, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method) + a.t, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method) if err != nil { if de, ok := err.(dropError); ok { err = de.error @@ -455,6 +455,25 @@ func (a *csAttempt) getTransport() error { func (a *csAttempt) newStream() error { cs := a.cs cs.callHdr.PreviousAttempts = cs.numRetries + + // Merge metadata stored in PickResult, if any, with existing call metadata. + // It is safe to overwrite the csAttempt's context here, since all state + // maintained in it is local to the attempt. When the attempt has to be + // retried, a new instance of csAttempt will be created. + if a.pickResult.Metatada != nil { + // We currently do not have a function in the metadata package which + // merges given metadata with existing metadata in a context. Existing + // function `AppendToOutgoingContext()` takes a variadic argument of key + // value pairs. + // + // TODO: Make it possible to retrieve key value pairs from metadata.MD + // in a form passable to AppendToOutgoingContext(), or create a version + // of AppendToOutgoingContext() that accepts a metadata.MD. + md, _ := metadata.FromOutgoingContext(a.ctx) + md = metadata.Join(md, a.pickResult.Metatada) + a.ctx = metadata.NewOutgoingContext(a.ctx, md) + } + s, err := a.t.NewStream(a.ctx, cs.callHdr) if err != nil { nse, ok := err.(*transport.NewStreamError) @@ -529,12 +548,12 @@ type clientStream struct { // csAttempt implements a single transport stream attempt within a // clientStream.
type csAttempt struct { - ctx context.Context - cs *clientStream - t transport.ClientTransport - s *transport.Stream - p *parser - done func(balancer.DoneInfo) + ctx context.Context + cs *clientStream + t transport.ClientTransport + s *transport.Stream + p *parser + pickResult balancer.PickResult finished bool dc Decompressor @@ -1103,12 +1122,12 @@ func (a *csAttempt) finish(err error) { tr = a.s.Trailer() } - if a.done != nil { + if a.pickResult.Done != nil { br := false if a.s != nil { br = a.s.BytesReceived() } - a.done(balancer.DoneInfo{ + a.pickResult.Done(balancer.DoneInfo{ Err: err, Trailer: tr, BytesSent: a.s != nil, diff --git a/test/balancer_test.go b/test/balancer_test.go index c919f1e0f7c4..bd782ffa6e4f 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -866,3 +866,139 @@ func (s) TestAuthorityInBuildOptions(t *testing.T) { }) } } + +// wrappedPickFirstBalancerBuilder builds a custom balancer which wraps an +// underlying pick_first balancer. +type wrappedPickFirstBalancerBuilder struct { + name string +} + +func (*wrappedPickFirstBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + builder := balancer.Get(grpc.PickFirstBalancerName) + wpfb := &wrappedPickFirstBalancer{ + ClientConn: cc, + } + pf := builder.Build(wpfb, opts) + wpfb.Balancer = pf + return wpfb +} + +func (wbb *wrappedPickFirstBalancerBuilder) Name() string { + return wbb.name +} + +// wrappedPickFirstBalancer contains a pick_first balancer and forwards all +// calls from the ClientConn to it. For state updates from the pick_first +// balancer, it creates a custom picker which injects arbitrary metadata on a +// per-call basis. 
+type wrappedPickFirstBalancer struct { + balancer.Balancer + balancer.ClientConn +} + +func (wb *wrappedPickFirstBalancer) UpdateState(state balancer.State) { + state.Picker = &wrappedPicker{p: state.Picker} + wb.ClientConn.UpdateState(state) +} + +const ( + metadataHeaderInjectedByBalancer = "metadata-header-injected-by-balancer" + metadataHeaderInjectedByApplication = "metadata-header-injected-by-application" + metadataValueInjectedByBalancer = "metadata-value-injected-by-balancer" + metadataValueInjectedByApplication = "metadata-value-injected-by-application" +) + +// wrappedPicker wraps the picker returned by the pick_first balancer. +type wrappedPicker struct { + p balancer.Picker +} + +func (wp *wrappedPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + res, err := wp.p.Pick(info) + if err != nil { + return balancer.PickResult{}, err + } + + if res.Metatada == nil { + res.Metatada = metadata.Pairs(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer) + } else { + res.Metatada.Append(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer) + } + return res, nil +} + +// TestMetadataInPickResult tests the scenario where an LB policy injects +// arbitrary metadata on a per-call basis and verifies that the injected +// metadata makes it all the way to the server RPC handler.
+func (s) TestMetadataInPickResult(t *testing.T) { + t.Log("Starting test backend...") + mdChan := make(chan metadata.MD, 1) + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + md, _ := metadata.FromIncomingContext(ctx) + select { + case mdChan <- md: + case <-ctx.Done(): + return nil, ctx.Err() + } + return &testpb.Empty{}, nil + }, + } + if err := ss.StartServer(); err != nil { + t.Fatalf("Starting test backend: %v", err) + } + defer ss.Stop() + t.Logf("Started test backend at %q", ss.Address) + + name := t.Name() + "wrappedPickFirstBalancer" + t.Logf("Registering test balancer with name %q...", name) + b := &wrappedPickFirstBalancerBuilder{name: t.Name() + "wrappedPickFirstBalancer"} + balancer.Register(b) + + t.Log("Creating ClientConn to test backend...") + r := manual.NewBuilderWithScheme("whatever") + r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address}}}) + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithResolvers(r), + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, b.Name())), + } + cc, err := grpc.Dial(r.Scheme()+":///test.server", dopts...) 
+ if err != nil { + t.Fatalf("grpc.Dial(): %v", err) + } + defer cc.Close() + tc := testpb.NewTestServiceClient(cc) + + t.Log("Making EmptyCall() RPC with custom metadata...") + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + md := metadata.Pairs(metadataHeaderInjectedByApplication, metadataValueInjectedByApplication) + ctx = metadata.NewOutgoingContext(ctx, md) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall() RPC: %v", err) + } + t.Log("EmptyCall() RPC succeeded") + + t.Log("Waiting for custom metadata to be received at the test backend...") + var gotMD metadata.MD + select { + case gotMD = <-mdChan: + case <-ctx.Done(): + t.Fatalf("Timed out waiting for custom metadata to be received at the test backend") + } + + t.Log("Verifying custom metadata added by the client application is received at the test backend...") + wantMDVal := []string{metadataValueInjectedByApplication} + gotMDVal := gotMD.Get(metadataHeaderInjectedByApplication) + if !cmp.Equal(gotMDVal, wantMDVal) { + t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal) + } + + t.Log("Verifying custom metadata added by the LB policy is received at the test backend...") + wantMDVal = []string{metadataValueInjectedByBalancer} + gotMDVal = gotMD.Get(metadataHeaderInjectedByBalancer) + if !cmp.Equal(gotMDVal, wantMDVal) { + t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal) + } +}