Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: expose API to set send compressor #5744

Merged
merged 13 commits into from Jan 31, 2023
3 changes: 1 addition & 2 deletions internal/transport/http2_server.go
Expand Up @@ -406,11 +406,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
isGRPC = true

case "grpc-accept-encoding":
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
if hf.Value == "" {
continue
}

mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
compressors := hf.Value
if s.clientAdvertisedCompressors != "" {
compressors = s.clientAdvertisedCompressors + "," + compressors
Expand Down
4 changes: 2 additions & 2 deletions server.go
Expand Up @@ -1954,7 +1954,7 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
return nil
}

// SetSendCompressor sets a compressor for outbound messages.
// SetSendCompressor sets a compressor for outbound messages from the server.
// It must not be called after any event that causes headers to be sent
// (see ServerStream.SetHeader for the complete list). Provided compressor is
// used when below conditions are met:
Expand All @@ -1975,7 +1975,7 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
//
easwars marked this conversation as resolved.
Show resolved Hide resolved
// # Experimental
jronak marked this conversation as resolved.
Show resolved Hide resolved
//
// Notice: This function_ is EXPERIMENTAL and may be changed or removed in a
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
easwars marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
72 changes: 57 additions & 15 deletions test/end2end_test.go
Expand Up @@ -5081,34 +5081,64 @@ func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) {
}
}

// wrapCompressor is a wrapper of encoding.Compressor which maintains count of
// Compressor method invokes.
type wrapCompressor struct {
encoding.Compressor
compressInvokes int32
}

func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
atomic.AddInt32(&wc.compressInvokes, 1)
return wc.Compressor.Compress(w)
}

func setupGzipWrapCompressor(t *testing.T) *wrapCompressor {
oldC := encoding.GetCompressor("gzip")
c := &wrapCompressor{Compressor: oldC}
encoding.RegisterCompressor(c)
t.Cleanup(func() {
encoding.RegisterCompressor(oldC)
})
return c
}

func (s) TestSetSendCompressorSuccess(t *testing.T) {
for _, tt := range []struct {
desc string
resCompressor string
wantErr error
name string
desc string
dialOpts []grpc.DialOption
resCompressor string
wantCompressInvokes int32
}{
{
desc: "gzip_response_compressor",
resCompressor: "gzip",
name: "identity_request_and_gzip_response",
desc: "request is uncompressed and response is gzip compressed",
resCompressor: "gzip",
wantCompressInvokes: 1,
},
{
desc: "identity_response_compressor",
resCompressor: "identity",
name: "gzip_request_and_identity_response",
desc: "request is gzip compressed and response is uncompressed with identity",
resCompressor: "identity",
dialOpts: []grpc.DialOption{grpc.WithCompressor(grpc.NewGZIPCompressor())},
jronak marked this conversation as resolved.
Show resolved Hide resolved
wantCompressInvokes: 0,
},
} {
t.Run(tt.desc, func(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Run("unary", func(t *testing.T) {
testUnarySetSendCompressorSuccess(t, tt.resCompressor)
testUnarySetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
})

t.Run("stream", func(t *testing.T) {
testStreamSetSendCompressorSuccess(t, tt.resCompressor)
testStreamSetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
})
})
}
}

func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string) {
func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
wc := setupGzipWrapCompressor(t)
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
Expand All @@ -5117,7 +5147,7 @@ func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string) {
return &testpb.Empty{}, nil
},
}
if err := ss.Start(nil); err != nil {
if err := ss.Start(nil, dialOpts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
Expand All @@ -5128,9 +5158,15 @@ func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string) {
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
}

compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
if compressInvokes != wantCompressInvokes {
t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
}
}

func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string) {
func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
wc := setupGzipWrapCompressor(t)
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
if _, err := stream.Recv(); err != nil {
Expand All @@ -5144,8 +5180,8 @@ func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string) {
return stream.Send(&testpb.StreamingOutputCallResponse{})
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v, want: nil", err)
if err := ss.Start(nil, dialOpts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

Expand All @@ -5164,6 +5200,11 @@ func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string) {
if _, err := s.Recv(); err != nil {
t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
}

compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
if compressInvokes != wantCompressInvokes {
t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
}
}

func (s) TestUnregisteredSetSendCompressorFailure(t *testing.T) {
Expand Down Expand Up @@ -5262,6 +5303,7 @@ func (s) TestUnarySetSendCompressorAfterHeaderSendFailure(t *testing.T) {
err := grpc.SetSendCompressor(ctx, "gzip")
if err == nil {
t.Error("Wanted set send compressor error")
jronak marked this conversation as resolved.
Show resolved Hide resolved
return &testpb.Empty{}, nil
}
return nil, err
},
Expand Down