Commit 6de8f50

transport: drain client transport when streamID approaches maxStreamID (#5889)

Fixes #5600

arvindbr8 committed Jan 11, 2023
1 parent 42b7b63 commit 6de8f50
Showing 5 changed files with 236 additions and 15 deletions.
16 changes: 9 additions & 7 deletions internal/transport/controlbuf.go
@@ -650,16 +650,18 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error {
 		itl:   &itemList{},
 		wq:    h.wq,
 	}
-	str.itl.enqueue(h)
-	return l.originateStream(str)
+	return l.originateStream(str, h)
 }
 
-func (l *loopyWriter) originateStream(str *outStream) error {
-	hdr := str.itl.dequeue().(*headerFrame)
+func (l *loopyWriter) originateStream(str *outStream, hdr *headerFrame) error {
+	// l.draining is set when handling GoAway. In which case, we want to avoid
+	// creating new streams.
+	if l.draining {
+		// TODO: provide a better error with the reason we are in draining.
+		hdr.onOrphaned(errStreamDrain)
+		return nil
+	}
 	if err := hdr.initStream(str.id); err != nil {
-		if err == errStreamDrain { // errStreamDrain need not close transport
-			return nil
-		}
 		return err
 	}
 	if err := l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil {
5 changes: 5 additions & 0 deletions internal/transport/defaults.go
@@ -47,3 +47,8 @@ const (
 	defaultClientMaxHeaderListSize = uint32(16 << 20)
 	defaultServerMaxHeaderListSize = uint32(16 << 20)
 )
+
+// MaxStreamID is the upper bound for the stream ID before the current
+// transport gracefully closes and new transport is created for subsequent RPCs.
+// This is set to 75% of math.MaxUint32. It's exported so that tests can override it.
+var MaxStreamID = uint32(3_221_225_472)
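As a sanity check on the constant (a standalone sketch, not part of the commit): 3_221_225_472 is 3 << 30, which is exactly 75% of 2^32 and, rounding up, 75% of math.MaxUint32 (2^32 - 1):

	package main

	import (
		"fmt"
		"math"
	)

	func main() {
		const maxStreamID = uint32(3_221_225_472) // mirrors transport.MaxStreamID
		fmt.Println(maxStreamID == 3*(1<<30))                       // true: 3 << 30
		fmt.Println(float64(maxStreamID) / float64(math.MaxUint32)) // ~0.75
	}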
31 changes: 23 additions & 8 deletions internal/transport/http2_client.go
@@ -742,15 +742,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 		endStream: false,
 		initStream: func(id uint32) error {
 			t.mu.Lock()
-			if state := t.state; state != reachable {
+			// TODO: handle transport closure in loopy instead and remove this
+			// initStream is never called when transport is draining.
+			if t.state == closing {
 				t.mu.Unlock()
-				// Do a quick cleanup.
-				err := error(errStreamDrain)
-				if state == closing {
-					err = ErrConnClosing
-				}
-				cleanup(err)
-				return err
+				cleanup(ErrConnClosing)
+				return ErrConnClosing
 			}
 			if channelz.IsOn() {
 				atomic.AddInt64(&t.czData.streamsStarted, 1)
@@ -768,6 +765,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 	}
 	firstTry := true
 	var ch chan struct{}
+	transportDrainRequired := false
 	checkForStreamQuota := func(it interface{}) bool {
 		if t.streamQuota <= 0 { // Can go negative if server decreases it.
 			if firstTry {
@@ -783,6 +781,11 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 		h := it.(*headerFrame)
 		h.streamID = t.nextID
 		t.nextID += 2
+
+		// Drain client transport if nextID > MaxStreamID which signals gRPC that
+		// the connection is closed and a new one must be created for subsequent RPCs.
+		transportDrainRequired = t.nextID > MaxStreamID
+
 		s.id = h.streamID
 		s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
 		t.mu.Lock()
@@ -862,6 +865,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 			sh.HandleRPC(s.ctx, outHeader)
 		}
 	}
+	if transportDrainRequired {
+		if logger.V(logLevel) {
+			logger.Infof("transport: t.nextID > MaxStreamID. Draining")
+		}
+		t.GracefulClose()
+	}
 	return s, nil
 }
 
@@ -1783,3 +1792,9 @@ func (t *http2Client) getOutFlowWindow() int64 {
 		return -2
 	}
 }
+
+func (t *http2Client) stateForTesting() transportState {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	return t.state
+}
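One detail worth calling out (background from the HTTP/2 spec, not stated in the diff): client-initiated stream IDs are odd and increase by 2, and the check runs after t.nextID is advanced, so the stream whose ID equals MaxStreamID is still created and the transport drains right after it. A standalone sketch of that allocation loop, using a toy limit in place of transport.MaxStreamID:

	package main

	import "fmt"

	func main() {
		// Toy version of checkForStreamQuota's ID bookkeeping: odd IDs, step 2.
		const maxStreamID = uint32(9) // stand-in for transport.MaxStreamID
		nextID := uint32(1)
		for i := 1; ; i++ {
			id := nextID
			nextID += 2
			drain := nextID > maxStreamID // mirrors transportDrainRequired
			fmt.Printf("stream %d: id=%d drain=%v\n", i, id, drain)
			if drain {
				break // id 9 was still created; the transport drains after it
			}
		}
	}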
46 changes: 46 additions & 0 deletions internal/transport/transport_test.go
@@ -536,6 +536,52 @@ func (s) TestInflightStreamClosing(t *testing.T) {
 	}
 }
 
+// Tests that when streamID > MaxStreamId, the current client transport drains.
+func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
+	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	defer cancel()
+	defer server.stop()
+	callHdr := &CallHdr{
+		Host:   "localhost",
+		Method: "foo.Small",
+	}
+
+	originalMaxStreamID := MaxStreamID
+	MaxStreamID = 3
+	defer func() {
+		MaxStreamID = originalMaxStreamID
+	}()
+
+	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer ctxCancel()
+
+	s, err := ct.NewStream(ctx, callHdr)
+	if err != nil {
+		t.Fatalf("ct.NewStream() = %v", err)
+	}
+	if s.id != 1 {
+		t.Fatalf("Stream id: %d, want: 1", s.id)
+	}
+
+	if got, want := ct.stateForTesting(), reachable; got != want {
+		t.Fatalf("Client transport state %v, want %v", got, want)
+	}
+
+	// The expected stream ID here is 3 since stream IDs are incremented by 2.
+	s, err = ct.NewStream(ctx, callHdr)
+	if err != nil {
+		t.Fatalf("ct.NewStream() = %v", err)
+	}
+	if s.id != 3 {
+		t.Fatalf("Stream id: %d, want: 3", s.id)
+	}
+
+	// Verifying that ct.state is draining when next stream ID > MaxStreamId.
+	if got, want := ct.stateForTesting(), draining; got != want {
+		t.Fatalf("Client transport state %v, want %v", got, want)
+	}
+}
+
 func (s) TestClientSendAndReceive(t *testing.T) {
 	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
 	defer cancel()
153 changes: 153 additions & 0 deletions test/transport_test.go
@@ -0,0 +1,153 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+package test
+
+import (
+	"context"
+	"io"
+	"net"
+	"sync"
+	"testing"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/credentials"
+	"google.golang.org/grpc/internal/grpcsync"
+	"google.golang.org/grpc/internal/stubserver"
+	"google.golang.org/grpc/internal/transport"
+	"google.golang.org/grpc/status"
+	testpb "google.golang.org/grpc/test/grpc_testing"
+)
+
+// connWrapperWithCloseCh wraps a net.Conn and fires an event when closed.
+type connWrapperWithCloseCh struct {
+	net.Conn
+	close *grpcsync.Event
+}
+
+// Close closes the connection and sends a value on the close channel.
+func (cw *connWrapperWithCloseCh) Close() error {
+	cw.close.Fire()
+	return cw.Conn.Close()
+}
+
+// These custom creds are used for storing the connections made by the client.
+// The closeCh in conn can be used to detect when conn is closed.
+type transportRestartCheckCreds struct {
+	mu          sync.Mutex
+	connections []*connWrapperWithCloseCh
+}
+
+func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+	return rawConn, nil, nil
+}
+func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	conn := &connWrapperWithCloseCh{Conn: rawConn, close: grpcsync.NewEvent()}
+	c.connections = append(c.connections, conn)
+	return conn, nil, nil
+}
+func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo {
+	return credentials.ProtocolInfo{}
+}
+func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials {
+	return c
+}
+func (c *transportRestartCheckCreds) OverrideServerName(s string) error {
+	return nil
+}
+
+// Tests that the client transport drains and restarts when next stream ID exceeds
+// MaxStreamID. This test also verifies that subsequent RPCs use a new client
+// transport and the old transport is closed.
+func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
+	// Set the transport's MaxStreamID to 4 to cause connection to drain after 2 RPCs.
+	originalMaxStreamID := transport.MaxStreamID
+	transport.MaxStreamID = 4
+	defer func() {
+		transport.MaxStreamID = originalMaxStreamID
+	}()
+
+	ss := &stubserver.StubServer{
+		FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
+			if _, err := stream.Recv(); err != nil {
+				return status.Errorf(codes.Internal, "unexpected error receiving: %v", err)
+			}
+			if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil {
+				return status.Errorf(codes.Internal, "unexpected error sending: %v", err)
+			}
+			if recv, err := stream.Recv(); err != io.EOF {
+				return status.Errorf(codes.Internal, "Recv = %v, %v; want _, io.EOF", recv, err)
+			}
+			return nil
+		},
+	}
+
+	creds := &transportRestartCheckCreds{}
+	if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil {
+		t.Fatalf("Starting stubServer: %v", err)
+	}
+	defer ss.Stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	var streams []testpb.TestService_FullDuplexCallClient
+
+	const numStreams = 3
+	// expected number of conns when each stream is created i.e., 3rd stream is created
+	// on a new connection.
+	expectedNumConns := [numStreams]int{1, 1, 2}
+
+	// Set up 3 streams.
+	for i := 0; i < numStreams; i++ {
+		s, err := ss.Client.FullDuplexCall(ctx)
+		if err != nil {
+			t.Fatalf("Creating FullDuplex stream: %v", err)
+		}
+		streams = append(streams, s)
+		// Verify expected num of conns after each stream is created.
+		if len(creds.connections) != expectedNumConns[i] {
+			t.Fatalf("Got number of connections created: %v, want: %v", len(creds.connections), expectedNumConns[i])
+		}
+	}
+
+	// Verify all streams still work.
+	for i, stream := range streams {
+		if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
+			t.Fatalf("Sending on stream %d: %v", i, err)
+		}
+		if _, err := stream.Recv(); err != nil {
+			t.Fatalf("Receiving on stream %d: %v", i, err)
+		}
+	}
+
+	for i, stream := range streams {
+		if err := stream.CloseSend(); err != nil {
+			t.Fatalf("CloseSend() on stream %d: %v", i, err)
+		}
+	}
+
+	// Verifying first connection was closed.
+	select {
+	case <-creds.connections[0].close.Done():
+	case <-ctx.Done():
+		t.Fatal("Timeout expired when waiting for first client transport to close")
+	}
+}
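To see why expectedNumConns is {1, 1, 2} with transport.MaxStreamID = 4 (a toy model of the bookkeeping, not gRPC code): streams 1 and 2 take IDs 1 and 3 on the first connection; allocating ID 3 pushes nextID to 5 > 4, so that transport drains and stream 3 dials a second connection:

	package main

	import "fmt"

	func main() {
		const maxStreamID = uint32(4) // as overridden by the test
		conns, nextID := 1, uint32(1)
		for i := 1; i <= 3; i++ {
			if nextID > maxStreamID { // previous transport drained; new connection
				conns++
				nextID = 1
			}
			fmt.Printf("stream %d: conn=%d id=%d\n", i, conns, nextID)
			nextID += 2
		}
		// Prints conn counts 1, 1, 2, matching expectedNumConns.
	}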
