Skip to content

Commit

Permalink
server: fix a few issues where grpc server uses RST_STREAM for non-HT…
Browse files Browse the repository at this point in the history
…TP/2 errors (#5893)

Fixes #5892
  • Loading branch information
jhump committed Jan 18, 2023
1 parent ace8082 commit 9b9b381
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 107 deletions.
4 changes: 2 additions & 2 deletions internal/transport/handler_server.go
Expand Up @@ -65,7 +65,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
if !validContentType {
msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType)
http.Error(w, msg, http.StatusBadRequest)
http.Error(w, msg, http.StatusUnsupportedMediaType)
return nil, errors.New(msg)
}
if _, ok := w.(http.Flusher); !ok {
Expand All @@ -87,7 +87,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := decodeTimeout(v)
if err != nil {
msg := fmt.Sprintf("malformed time-out: %v", err)
msg := fmt.Sprintf("malformed grpc-timeout: %v", err)
http.Error(w, msg, http.StatusBadRequest)
return nil, status.Error(codes.Internal, msg)
}
Expand Down
40 changes: 29 additions & 11 deletions internal/transport/handler_server_test.go
Expand Up @@ -41,11 +41,12 @@ import (

func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
type testCase struct {
name string
req *http.Request
wantErr string
modrw func(http.ResponseWriter) http.ResponseWriter
check func(*serverHandlerTransport, *testCase) error
name string
req *http.Request
wantErr string
wantErrCode int
modrw func(http.ResponseWriter) http.ResponseWriter
check func(*serverHandlerTransport, *testCase) error
}
tests := []testCase{
{
Expand All @@ -54,7 +55,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
ProtoMajor: 1,
ProtoMinor: 1,
},
wantErr: "gRPC requires HTTP/2",
wantErr: "gRPC requires HTTP/2",
wantErrCode: http.StatusBadRequest,
},
{
name: "bad method",
Expand All @@ -63,7 +65,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
Method: "GET",
Header: http.Header{},
},
wantErr: `invalid gRPC request method "GET"`,
wantErr: `invalid gRPC request method "GET"`,
wantErrCode: http.StatusBadRequest,
},
{
name: "bad content type",
Expand All @@ -74,7 +77,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
"Content-Type": {"application/foo"},
},
},
wantErr: `invalid gRPC request content-type "application/foo"`,
wantErr: `invalid gRPC request content-type "application/foo"`,
wantErrCode: http.StatusUnsupportedMediaType,
},
{
name: "not flusher",
Expand All @@ -93,7 +97,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
}
return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
},
wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
wantErrCode: http.StatusInternalServerError,
},
{
name: "valid",
Expand Down Expand Up @@ -153,7 +158,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
Path: "/service/foo.bar",
},
},
wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`,
wantErr: `rpc error: code = Internal desc = malformed grpc-timeout: transport: timeout unit is not recognized: "tomorrow"`,
wantErrCode: http.StatusBadRequest,
},
{
name: "with metadata",
Expand Down Expand Up @@ -187,7 +193,12 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
}

for _, tt := range tests {
rw := newTestHandlerResponseWriter()
rrec := httptest.NewRecorder()
rw := http.ResponseWriter(testHandlerResponseWriter{
ResponseRecorder: rrec,
closeNotify: make(chan bool, 1),
})

if tt.modrw != nil {
rw = tt.modrw(rw)
}
Expand All @@ -196,6 +207,13 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
continue
}
if tt.wantErrCode == 0 {
tt.wantErrCode = http.StatusOK
}
if rrec.Code != tt.wantErrCode {
t.Errorf("%s: code = %d; want %d", tt.name, rrec.Code, tt.wantErrCode)
continue
}
if gotErr != nil {
continue
}
Expand Down
46 changes: 34 additions & 12 deletions internal/transport/http2_server.go
Expand Up @@ -380,13 +380,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
fc: &inFlow{limit: uint32(t.initialWindowSize)},
}
var (
// If a gRPC Response-Headers has already been received, then it means
// that the peer is speaking gRPC and we are in gRPC mode.
isGRPC = false
mdata = make(map[string][]string)
httpMethod string
// headerError is set if an error is encountered while parsing the headers
headerError bool
// if false, content-type was missing or invalid
isGRPC = false
contentType = ""
mdata = make(map[string][]string)
httpMethod string
// these are set if an error is encountered while parsing the headers
protocolError bool
headerError *status.Status

timeoutSet bool
timeout time.Duration
Expand All @@ -397,6 +398,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
case "content-type":
contentSubtype, validContentType := grpcutil.ContentSubtype(hf.Value)
if !validContentType {
contentType = hf.Value
break
}
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
Expand All @@ -412,22 +414,22 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
timeoutSet = true
var err error
if timeout, err = decodeTimeout(hf.Value); err != nil {
headerError = true
headerError = status.Newf(codes.Internal, "malformed grpc-timeout: %v", err)
}
// "Transports must consider requests containing the Connection header
// as malformed." - A41
case "connection":
if logger.V(logLevel) {
logger.Errorf("transport: http2Server.operateHeaders parsed a :connection header which makes a request malformed as per the HTTP/2 spec")
}
headerError = true
protocolError = true
default:
if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) {
break
}
v, err := decodeMetadataHeader(hf.Name, hf.Value)
if err != nil {
headerError = true
headerError = status.Newf(codes.Internal, "malformed binary metadata %q in header %q: %v", hf.Value, hf.Name, err)
logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err)
break
}
Expand All @@ -446,7 +448,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
logger.Errorf("transport: %v", errMsg)
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: 400,
httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
Expand All @@ -455,7 +457,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
return nil
}

if !isGRPC || headerError {
if protocolError {
t.controlBuf.put(&cleanupStream{
streamID: streamID,
rst: true,
Expand All @@ -464,6 +466,26 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
})
return nil
}
if !isGRPC {
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusUnsupportedMediaType,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType),
rst: !frame.StreamEnded(),
})
return nil
}
if headerError != nil {
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: headerError,
rst: !frame.StreamEnded(),
})
return nil
}

// "If :authority is missing, Host must be renamed to :authority." - A41
if len(mdata[":authority"]) == 0 {
Expand Down

0 comments on commit 9b9b381

Please sign in to comment.