Skip to content

Commit

Permalink
server: expose API to set send compressor
Browse files Browse the repository at this point in the history
  • Loading branch information
jronak committed Oct 26, 2022
1 parent 3c09650 commit b19a201
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 5 deletions.
39 changes: 39 additions & 0 deletions internal/grpcutil/compressor.go
Expand Up @@ -19,6 +19,7 @@
package grpcutil

import (
"fmt"
"strings"

"google.golang.org/grpc/internal/envconfig"
Expand All @@ -45,3 +46,41 @@ func RegisteredCompressors() string {
}
return strings.Join(RegisteredCompressorNames, ",")
}

// ValidateSendCompressor returns an error when given compressor name cannot be
// handled by the server or the client based on the advertised compressors.
func ValidateSendCompressor(name, clientAdvertisedCompressors string) error {
if name == "identity" {
return nil
}

if !IsCompressorNameRegistered(name) {
return fmt.Errorf("compressor not registered: %s", name)
}

if !compressorExists(name, clientAdvertisedCompressors) {
return fmt.Errorf("client does not support compressor: %s", name)
}

return nil
}

// compressorExists returns true when the given name exists in the comma
// separated compressor list.
func compressorExists(name, compressors string) bool {
var (
i = 0
length = len(compressors)
)
for j := 0; j <= length; j++ {
if j < length && compressors[j] != ',' {
continue
}

if compressors[i:j] == name {
return true
}
i = j + 1
}
return false
}
42 changes: 42 additions & 0 deletions internal/grpcutil/compressor_test.go
Expand Up @@ -19,6 +19,7 @@
package grpcutil

import (
"fmt"
"testing"

"google.golang.org/grpc/internal/envconfig"
Expand All @@ -44,3 +45,44 @@ func TestRegisteredCompressors(t *testing.T) {
}
}
}

func TestValidateSendCompressors(t *testing.T) {
defer func(c []string) { RegisteredCompressorNames = c }(RegisteredCompressorNames)
RegisteredCompressorNames = []string{"gzip", "snappy"}
tests := []struct {
desc string
name string
advertisedCompressors string
wantErr error
}{
{
desc: "success_when_identity_compressor",
name: "identity",
advertisedCompressors: "gzip,snappy",
},
{
desc: "success_when_compressor_exists",
name: "snappy",
advertisedCompressors: "testcomp,gzip,snappy",
},
{
desc: "failure_when_compressor_not_registered",
name: "testcomp",
advertisedCompressors: "testcomp,gzip,snappy",
wantErr: fmt.Errorf("compressor not registered: testcomp"),
},
{
desc: "failure_when_compressor_not_advertised",
name: "gzip",
advertisedCompressors: "testcomp,snappy",
wantErr: fmt.Errorf("client does not support compressor: gzip"),
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
if err := ValidateSendCompressor(tt.name, tt.advertisedCompressors); fmt.Sprint(err) != fmt.Sprint(tt.wantErr) {
t.Fatalf("Unexpected validation got:%v, want:%v", err, tt.wantErr)
}
})
}
}
9 changes: 9 additions & 0 deletions internal/transport/handler_server_test.go
Expand Up @@ -270,6 +270,10 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
if err != nil {
t.Error(err)
}
err = s.SetSendCompress("gzip")
if err != nil {
t.Error(err)
}

md := metadata.Pairs("custom-header", "Another custom header value")
err = s.SendHeader(md)
Expand All @@ -286,6 +290,10 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
if err == nil {
t.Error("expected second SendHeader call to fail")
}
err = s.SetSendCompress("snappy")
if err == nil {
t.Error("expected second SetSendCompress call to fail")
}

st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
Expand All @@ -299,6 +307,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Custom-Header": {"Custom header value", "Another custom header value"},
"Grpc-Encoding": {"gzip"},
}
wantTrailer := http.Header{
"Grpc-Status": {"0"},
Expand Down
5 changes: 5 additions & 0 deletions internal/transport/http2_server.go
Expand Up @@ -27,6 +27,7 @@ import (
"net"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -456,6 +457,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
return false
}

if encodings := mdata["grpc-accept-encoding"]; len(encodings) != 0 {
s.clientAdvertisedCompressors = strings.Join(encodings, ",")
}

if !isGRPC || headerError {
t.controlBuf.put(&cleanupStream{
streamID: streamID,
Expand Down
21 changes: 20 additions & 1 deletion internal/transport/transport.go
Expand Up @@ -253,6 +253,9 @@ type Stream struct {
fc *inFlow
wq *writeQuota

// Holds compressor names passed in grpc-accept-encoding metadata from the
// client. This is empty for the client side Stream.
clientAdvertisedCompressors string
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
Expand Down Expand Up @@ -341,8 +344,24 @@ func (s *Stream) RecvCompress() string {
}

// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
func (s *Stream) SetSendCompress(str string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done")
}

s.sendCompress = str
return nil
}

// SendCompress returns the send compressor name.
func (s *Stream) SendCompress() string {
return s.sendCompress
}

// ClientAdvertisedCompressors returns the advertised compressor names by the
// client.
func (s *Stream) ClientAdvertisedCompressors() string {
return s.clientAdvertisedCompressors
}

// Done returns a channel which is closed when it receives the final status
Expand Down
44 changes: 40 additions & 4 deletions server.go
Expand Up @@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -1267,6 +1268,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
var comp, decomp encoding.Compressor
var cp Compressor
var dc Decompressor
var sendCompressorName string

// If dc is set and matches the stream's compression, use it. Otherwise, try
// to find a matching registered compressor for decomp.
Expand All @@ -1287,12 +1289,14 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
cp = s.opts.cp
stream.SetSendCompress(cp.Type())
_ = stream.SetSendCompress(cp.Type())
sendCompressorName = cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
comp = encoding.GetCompressor(rc)
if comp != nil {
stream.SetSendCompress(rc)
_ = stream.SetSendCompress(rc)
sendCompressorName = comp.Name()
}
}

Expand Down Expand Up @@ -1379,6 +1383,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
opts := &transport.Options{Last: true}

if stream.SendCompress() != sendCompressorName {
comp = encoding.GetCompressor(stream.SendCompress())
}
if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
Expand Down Expand Up @@ -1606,12 +1613,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
ss.cp = s.opts.cp
stream.SetSendCompress(s.opts.cp.Type())
_ = stream.SetSendCompress(s.opts.cp.Type())
ss.sendCompressorName = s.opts.cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
ss.comp = encoding.GetCompressor(rc)
if ss.comp != nil {
stream.SetSendCompress(rc)
_ = stream.SetSendCompress(rc)
ss.sendCompressorName = rc
}
}

Expand Down Expand Up @@ -1944,6 +1953,33 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
return nil
}

// SetSendCompressor sets the compressor that will be used when sending
// RPC payload back to the client. It may be called at most once, and must not
// be called after any event that causes headers to be sent (see SetHeader for
// a complete list). Provided compressor is used when below conditions are met:
//
// - compressor is registered via encoding.RegisterCompressor
// - compressor name exists in the client advertised compressor names sent in
// grpc-accept-encoding metadata.
//
// The context provided must be the context passed to the server's handler.
//
// The error returned is compatible with the status package. However, the
// status code will often not match the RPC status as seen by the client
// application, and therefore, should not be relied upon for this purpose.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
}

if err := grpcutil.ValidateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor %v", err)
}

return stream.SetSendCompress(name)
}

// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// When called more than once, all the provided metadata will be merged.
//
Expand Down
7 changes: 7 additions & 0 deletions stream.go
Expand Up @@ -1481,6 +1481,8 @@ type serverStream struct {
comp encoding.Compressor
decomp encoding.Compressor

sendCompressorName string

maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
Expand Down Expand Up @@ -1573,6 +1575,11 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
}()

if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
ss.comp = encoding.GetCompressor(sendCompressorsName)
ss.sendCompressorName = sendCompressorsName
}

// load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
if err != nil {
Expand Down

0 comments on commit b19a201

Please sign in to comment.