Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: expose API to set send compressor #5744

Merged
merged 13 commits into from Jan 31, 2023
39 changes: 39 additions & 0 deletions internal/grpcutil/compressor.go
Expand Up @@ -19,6 +19,7 @@
package grpcutil

import (
"fmt"
"strings"

"google.golang.org/grpc/internal/envconfig"
Expand All @@ -45,3 +46,41 @@ func RegisteredCompressors() string {
}
return strings.Join(RegisteredCompressorNames, ",")
}

// ValidateSendCompressor returns an error when given compressor name cannot be
// handled by the server or the client based on the advertised compressors.
func ValidateSendCompressor(name, clientAdvertisedCompressors string) error {
if name == "identity" {
return nil
}

if !IsCompressorNameRegistered(name) {
return fmt.Errorf("compressor not registered: %s", name)
}

if !compressorExists(name, clientAdvertisedCompressors) {
return fmt.Errorf("client does not support compressor: %s", name)
}

return nil
}

// compressorExists returns true when the given name exists in the comma
// separated compressor list.
func compressorExists(name, compressors string) bool {
var (
i = 0
length = len(compressors)
)
for j := 0; j <= length; j++ {
if j < length && compressors[j] != ',' {
jronak marked this conversation as resolved.
Show resolved Hide resolved
continue
}

if compressors[i:j] == name {
return true
}
i = j + 1
}
return false
}
42 changes: 42 additions & 0 deletions internal/grpcutil/compressor_test.go
Expand Up @@ -19,6 +19,7 @@
package grpcutil

import (
"fmt"
"testing"

"google.golang.org/grpc/internal/envconfig"
Expand All @@ -44,3 +45,44 @@ func TestRegisteredCompressors(t *testing.T) {
}
}
}

func TestValidateSendCompressors(t *testing.T) {
defer func(c []string) { RegisteredCompressorNames = c }(RegisteredCompressorNames)
RegisteredCompressorNames = []string{"gzip", "snappy"}
tests := []struct {
desc string
name string
advertisedCompressors string
wantErr error
}{
{
desc: "success_when_identity_compressor",
name: "identity",
advertisedCompressors: "gzip,snappy",
},
{
desc: "success_when_compressor_exists",
name: "snappy",
advertisedCompressors: "testcomp,gzip,snappy",
},
{
desc: "failure_when_compressor_not_registered",
name: "testcomp",
advertisedCompressors: "testcomp,gzip,snappy",
wantErr: fmt.Errorf("compressor not registered: testcomp"),
},
{
desc: "failure_when_compressor_not_advertised",
name: "gzip",
advertisedCompressors: "testcomp,snappy",
wantErr: fmt.Errorf("client does not support compressor: gzip"),
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
if err := ValidateSendCompressor(tt.name, tt.advertisedCompressors); fmt.Sprint(err) != fmt.Sprint(tt.wantErr) {
t.Fatalf("Unexpected validation got:%v, want:%v", err, tt.wantErr)
}
})
}
}
9 changes: 9 additions & 0 deletions internal/transport/handler_server_test.go
Expand Up @@ -270,6 +270,10 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
if err != nil {
t.Error(err)
}
err = s.SetSendCompress("gzip")
easwars marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
t.Error(err)
}

md := metadata.Pairs("custom-header", "Another custom header value")
err = s.SendHeader(md)
Expand All @@ -286,6 +290,10 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
if err == nil {
t.Error("expected second SendHeader call to fail")
}
err = s.SetSendCompress("snappy")
if err == nil {
t.Error("expected second SetSendCompress call to fail")
}

st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
Expand All @@ -299,6 +307,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Custom-Header": {"Custom header value", "Another custom header value"},
"Grpc-Encoding": {"gzip"},
}
wantTrailer := http.Header{
"Grpc-Status": {"0"},
Expand Down
5 changes: 5 additions & 0 deletions internal/transport/http2_server.go
Expand Up @@ -27,6 +27,7 @@ import (
"net"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -456,6 +457,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
return false
}

if encodings := mdata["grpc-accept-encoding"]; len(encodings) != 0 {
s.clientAdvertisedCompressors = strings.Join(encodings, ",")
}

if !isGRPC || headerError {
t.controlBuf.put(&cleanupStream{
streamID: streamID,
Expand Down
21 changes: 20 additions & 1 deletion internal/transport/transport.go
Expand Up @@ -253,6 +253,9 @@ type Stream struct {
fc *inFlow
wq *writeQuota

// Holds compressor names passed in grpc-accept-encoding metadata from the
// client. This is empty for the client side Stream.
clientAdvertisedCompressors string
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
Expand Down Expand Up @@ -341,8 +344,24 @@ func (s *Stream) RecvCompress() string {
}

// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
func (s *Stream) SetSendCompress(str string) error {
jronak marked this conversation as resolved.
Show resolved Hide resolved
if s.isHeaderSent() || s.getState() == streamDone {
return status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done")
}

s.sendCompress = str
return nil
}

// SendCompress returns the send compressor name.
func (s *Stream) SendCompress() string {
return s.sendCompress
}

// ClientAdvertisedCompressors returns the advertised compressor names by the
// client.
jronak marked this conversation as resolved.
Show resolved Hide resolved
func (s *Stream) ClientAdvertisedCompressors() string {
return s.clientAdvertisedCompressors
}

// Done returns a channel which is closed when it receives the final status
Expand Down
44 changes: 40 additions & 4 deletions server.go
Expand Up @@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -1267,6 +1268,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
var comp, decomp encoding.Compressor
var cp Compressor
var dc Decompressor
var sendCompressorName string

// If dc is set and matches the stream's compression, use it. Otherwise, try
// to find a matching registered compressor for decomp.
Expand All @@ -1287,12 +1289,14 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
cp = s.opts.cp
stream.SetSendCompress(cp.Type())
_ = stream.SetSendCompress(cp.Type())
jronak marked this conversation as resolved.
Show resolved Hide resolved
sendCompressorName = cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
comp = encoding.GetCompressor(rc)
if comp != nil {
stream.SetSendCompress(rc)
_ = stream.SetSendCompress(rc)
sendCompressorName = comp.Name()
}
}

Expand Down Expand Up @@ -1379,6 +1383,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
opts := &transport.Options{Last: true}

if stream.SendCompress() != sendCompressorName {
jronak marked this conversation as resolved.
Show resolved Hide resolved
comp = encoding.GetCompressor(stream.SendCompress())
}
if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
Expand Down Expand Up @@ -1606,12 +1613,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
ss.cp = s.opts.cp
stream.SetSendCompress(s.opts.cp.Type())
_ = stream.SetSendCompress(s.opts.cp.Type())
ss.sendCompressorName = s.opts.cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
ss.comp = encoding.GetCompressor(rc)
if ss.comp != nil {
stream.SetSendCompress(rc)
_ = stream.SetSendCompress(rc)
ss.sendCompressorName = rc
}
}

Expand Down Expand Up @@ -1944,6 +1953,33 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
return nil
}

// SetSendCompressor sets the compressor that will be used when sending
jronak marked this conversation as resolved.
Show resolved Hide resolved
// RPC payload back to the client. It may be called at most once, and must not
// be called after any event that causes headers to be sent (see SetHeader for
// a complete list). Provided compressor is used when below conditions are met:
//
// - compressor is registered via encoding.RegisterCompressor
// - compressor name exists in the client advertised compressor names sent in
// grpc-accept-encoding metadata.
//
// The context provided must be the context passed to the server's handler.
//
// The error returned is compatible with the status package. However, the
// status code will often not match the RPC status as seen by the client
// application, and therefore, should not be relied upon for this purpose.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not return a status error here, as it does not represent the RPC's status (as nothing on the server side does, but even though we do return misleading status errors in some places, it should not be justification to do so more often).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
easwars marked this conversation as resolved.
Show resolved Hide resolved
if !ok || stream == nil {
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it common to log the context with %v? If there's truly going to be useful information in there then LGTM but I'm not sure how it looks TBH and we should remove it if it's too cluttered or just prints a pointer or something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does look very cluttered with all the fields being printed inline. Something like:

context.Background.WithValue(type string, val test-val)

Also, I don't think this can highlight any meaningful information. Dropped context from the error message

}

if err := grpcutil.ValidateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
easwars marked this conversation as resolved.
Show resolved Hide resolved
return status.Errorf(codes.Internal, "grpc: failed to set send compressor %v", err)
easwars marked this conversation as resolved.
Show resolved Hide resolved
jronak marked this conversation as resolved.
Show resolved Hide resolved
}

return stream.SetSendCompress(name)
}

// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// When called more than once, all the provided metadata will be merged.
//
Expand Down
7 changes: 7 additions & 0 deletions stream.go
Expand Up @@ -1481,6 +1481,8 @@ type serverStream struct {
comp encoding.Compressor
decomp encoding.Compressor

sendCompressorName string
easwars marked this conversation as resolved.
Show resolved Hide resolved

maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
Expand Down Expand Up @@ -1573,6 +1575,11 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
}()

if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
ss.comp = encoding.GetCompressor(sendCompressorsName)
ss.sendCompressorName = sendCompressorsName
}

// load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
if err != nil {
Expand Down