Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: after GracefulStop, ensure connections are closed when final RPC completes #5968

Merged
merged 1 commit into from Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 4 additions & 2 deletions internal/transport/controlbuf.go
Expand Up @@ -527,6 +527,9 @@ const minBatchSize = 1000
// As an optimization, to increase the batch size for each flush, loopy yields the processor, once
// if the batch size is too low to give stream goroutines a chance to fill it up.
func (l *loopyWriter) run() (err error) {
// Always flush the writer before exiting in case there are pending frames
// to be sent.
defer l.framer.writer.Flush()
for {
it, err := l.cbuf.get(true)
if err != nil {
Expand Down Expand Up @@ -759,7 +762,7 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
return err
}
}
if l.side == clientSide && l.draining && len(l.estdStreams) == 0 {
if l.draining && len(l.estdStreams) == 0 {
return errors.New("finished processing active streams while in draining mode")
}
return nil
Expand Down Expand Up @@ -814,7 +817,6 @@ func (l *loopyWriter) goAwayHandler(g *goAway) error {
}

func (l *loopyWriter) closeConnectionHandler() error {
l.framer.writer.Flush()
// Exit loopyWriter entirely by returning an error here. This will lead to
// the transport closing the connection, and, ultimately, transport
// closure.
Expand Down
51 changes: 51 additions & 0 deletions test/gracefulstop_test.go
Expand Up @@ -26,6 +26,7 @@ import (
"testing"
"time"

"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
Expand Down Expand Up @@ -164,3 +165,53 @@ func (s) TestGracefulStop(t *testing.T) {
cancel()
wg.Wait()
}

func (s) TestGracefulStopClosesConnAfterLastStream(t *testing.T) {
// This test ensures that a server closes the connections to its clients
// when the final stream has completed after a GOAWAY.

handlerCalled := make(chan struct{})
gracefulStopCalled := make(chan struct{})

ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error {
close(handlerCalled) // Initiate call to GracefulStop.
<-gracefulStopCalled // Wait for GOAWAYs to be received by the client.
return nil
}}

te := newTest(t, tcpClearEnv)
te.startServer(ts)
defer te.tearDown()

te.withServerTester(func(st *serverTester) {
st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", false)

<-handlerCalled // Wait for the server to invoke its handler.

// Gracefully stop the server.
gracefulStopDone := make(chan struct{})
go func() {
te.srv.GracefulStop()
close(gracefulStopDone)
}()
st.wantGoAway(http2.ErrCodeNo) // Server sends a GOAWAY due to GracefulStop.
pf := st.wantPing() // Server sends a ping to verify client receipt.
st.writePing(true, pf.Data) // Send ping ack to confirm.
st.wantGoAway(http2.ErrCodeNo) // Wait for subsequent GOAWAY to indicate no new stream processing.

close(gracefulStopCalled) // Unblock server handler.

fr := st.wantAnyFrame() // Wait for trailer.
hdr, ok := fr.(*http2.MetaHeadersFrame)
if !ok {
t.Fatalf("Received unexpected frame of type (%T) from server: %v; want HEADERS", fr, fr)
}
if !hdr.StreamEnded() {
t.Fatalf("Received unexpected HEADERS frame from server: %v; want END_STREAM set", fr)
}

st.wantRSTStream(http2.ErrCodeNo) // Server should send RST_STREAM because client did not half-close.

<-gracefulStopDone // Wait for GracefulStop to return.
})
}
35 changes: 31 additions & 4 deletions test/servertester.go
Expand Up @@ -138,19 +138,46 @@ func (st *serverTester) writeSettingsAck() {
}
}

func (st *serverTester) wantGoAway(errCode http2.ErrCode) *http2.GoAwayFrame {
f, err := st.readFrame()
if err != nil {
st.t.Fatalf("Error while expecting an RST frame: %v", err)
}
gaf, ok := f.(*http2.GoAwayFrame)
if !ok {
st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
}
if gaf.ErrCode != errCode {
st.t.Fatalf("expected GOAWAY error code '%v', got '%v'", errCode.String(), gaf.ErrCode.String())
}
return gaf
}

func (st *serverTester) wantPing() *http2.PingFrame {
f, err := st.readFrame()
if err != nil {
st.t.Fatalf("Error while expecting an RST frame: %v", err)
}
pf, ok := f.(*http2.PingFrame)
if !ok {
st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
}
return pf
}

func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame {
f, err := st.readFrame()
if err != nil {
st.t.Fatalf("Error while expecting an RST frame: %v", err)
}
sf, ok := f.(*http2.RSTStreamFrame)
rf, ok := f.(*http2.RSTStreamFrame)
if !ok {
st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f)
}
if sf.ErrCode != errCode {
st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), sf.ErrCode.String())
if rf.ErrCode != errCode {
st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), rf.ErrCode.String())
}
return sf
return rf
}

func (st *serverTester) wantSettings() *http2.SettingsFrame {
Expand Down
6 changes: 3 additions & 3 deletions test/stream_cleanup_test.go
Expand Up @@ -46,7 +46,7 @@ func (s) TestStreamCleanup(t *testing.T) {
return &testpb.Empty{}, nil
},
}
if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
if err := ss.Start(nil, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
Expand Down Expand Up @@ -79,7 +79,7 @@ func (s) TestStreamCleanupAfterSendStatus(t *testing.T) {
})
},
}
if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
if err := ss.Start(nil, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
Expand Down Expand Up @@ -132,6 +132,6 @@ func (s) TestStreamCleanupAfterSendStatus(t *testing.T) {
case <-gracefulStopDone:
timer.Stop()
case <-timer.C:
t.Fatalf("s.GracefulStop() didn't finish without 1 second after the last RPC")
t.Fatalf("s.GracefulStop() didn't finish within 1 second after the last RPC")
}
}