transport: drain client transport when streamID approaches maxStreamID #5889

Merged: 20 commits, Jan 11, 2023
Changes from 17 commits
17 changes: 10 additions & 7 deletions internal/transport/controlbuf.go
@@ -650,16 +650,19 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error {
itl: &itemList{},
wq: h.wq,
}
str.itl.enqueue(h)
return l.originateStream(str)
return l.originateStream(str, h)
}

func (l *loopyWriter) originateStream(str *outStream) error {
hdr := str.itl.dequeue().(*headerFrame)
// originateStreamWithHeaderFrame calls the initStream function on the headerFrame and
// called writeHeader. If write succeeds the streamID is added to l.estdStreams
Member: called->calls?

And end the last sentence with a period, please.

Member Author: I nixed the comment itself. Seems like I really just wrote exactly what the 10 lines below do.

func (l *loopyWriter) originateStream(str *outStream, hdr *headerFrame) error {
// l.draining is set for an incomingGoAway. In which case, we want to avoid further
// writes
Member: .

Also... we want to avoid creating new streams instead?

Member Author: done

if l.draining {
hdr.onOrphaned(errStreamDrain)
Member: // TODO: provide a better error with the reason we are in draining. e.g.

Member Author: done

return nil
}
if err := hdr.initStream(str.id); err != nil {
if err == errStreamDrain { // errStreamDrain need not close transport
return nil
}
return err
}
if err := l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil {
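The draining branch above fails the queued RPC through the frame's onOrphaned callback instead of writing anything to the wire. A minimal standalone sketch of that shape, using stand-in types (writer, headerFrame) rather than the real loopyWriter internals:

```go
package main

import (
	"errors"
	"fmt"
)

var errStreamDrain = errors.New("transport is draining; cannot create new streams")

// headerFrame stands in for the queued frame: it carries a callback used to
// fail the RPC if the stream is abandoned before it is ever established.
type headerFrame struct {
	onOrphaned func(error)
}

// writer stands in for loopyWriter: once draining, it refuses to originate
// new streams and reports the reason through the callback.
type writer struct {
	draining bool
}

func (w *writer) originateStream(h *headerFrame) error {
	if w.draining {
		h.onOrphaned(errStreamDrain) // fail the RPC without writing HEADERS
		return nil                   // not a transport-level error
	}
	// ... initialize the stream, write HEADERS, register in estdStreams ...
	return nil
}

func main() {
	w := &writer{draining: true}
	h := &headerFrame{onOrphaned: func(err error) { fmt.Println("orphaned:", err) }}
	_ = w.originateStream(h)
}
```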
5 changes: 5 additions & 0 deletions internal/transport/defaults.go
@@ -47,3 +47,8 @@ const (
defaultClientMaxHeaderListSize = uint32(16 << 20)
defaultServerMaxHeaderListSize = uint32(16 << 20)
)

// MaxStreamID is the upper bound for the stream ID before the current
// transport gracefully closes and new transport is created for subsequent RPCs.
// This is set to 75% of math.MaxUint32. It's exported so that tests can override it.
var MaxStreamID = uint32(3_221_225_472)
Comment: StreamID is a 31-bit unsigned integer. The max value a streamID could be before overflow is 2^31-1 = 2,147,483,647.

I'm not sure this will actually prevent this issue from happening, as the overflow would occur before 3_221_225_472. You can read more about this in RFC 7540: https://www.rfc-editor.org/rfc/rfc7540#section-5.1.1

Contributor: The spec:

   Streams are identified with an unsigned 31-bit integer.  Streams
   initiated by a client MUST use odd-numbered stream identifiers; those
   initiated by the server MUST use even-numbered stream identifiers.  A
   stream identifier of zero (0x0) is used for connection control
   messages; the stream identifier of zero cannot be used to establish a
   new stream.

Member Author: Thanks @patrick-ogrady! Great catch. I'm going to fix that.
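For reference, the arithmetic behind this catch, as a standalone snippet (illustration only, not part of the patch):

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// RFC 7540 §5.1.1: stream identifiers are unsigned 31-bit integers,
	// so the largest valid stream ID is 2^31 - 1.
	const maxValidStreamID = uint32(math.MaxInt32) // 2,147,483,647

	// The threshold in this revision, 75% of math.MaxUint32:
	const proposed = uint32(3_221_225_472)

	// Prints true: stream IDs overflow 31 bits long before the drain
	// threshold is ever reached, which is the bug flagged above.
	fmt.Println(proposed > maxValidStreamID)
}
```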

31 changes: 23 additions & 8 deletions internal/transport/http2_client.go
@@ -742,15 +742,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
endStream: false,
initStream: func(id uint32) error {
t.mu.Lock()
if state := t.state; state != reachable {
// we want initStream to cleanup and return an error when transport is closing.
Member: Nit: capitalize "We".

Member Author: Nixing this comment as well. Seems like it doesn't add value when I read it now.

// initStream is never called when transport is draining.
if t.state == closing {
Member: // TODO: handle transport closure in loopy instead and remove this. e.g.

Member Author: done

t.mu.Unlock()
// Do a quick cleanup.
err := error(errStreamDrain)
if state == closing {
err = ErrConnClosing
}
cleanup(err)
return err
cleanup(ErrConnClosing)
return ErrConnClosing
}
if channelz.IsOn() {
atomic.AddInt64(&t.czData.streamsStarted, 1)
@@ -768,6 +765,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
}
firstTry := true
var ch chan struct{}
transportDrainRequired := false
checkForStreamQuota := func(it interface{}) bool {
if t.streamQuota <= 0 { // Can go negative if server decreases it.
if firstTry {
@@ -783,6 +781,11 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
h := it.(*headerFrame)
h.streamID = t.nextID
t.nextID += 2

// Drain client transport if nextID > MaxStreamID which signals gRPC that
// the connection is closed and a new one must be created for subsequent RPCs.
transportDrainRequired = t.nextID > MaxStreamID

s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.mu.Lock()
@@ -862,6 +865,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
sh.HandleRPC(s.ctx, outHeader)
}
}
if transportDrainRequired {
if logger.V(logLevel) {
logger.Infof("transport: t.nextID > MaxStreamID. Draining")
}
t.GracefulClose()
}
return s, nil
}

@@ -1783,3 +1792,9 @@ func (t *http2Client) getOutFlowWindow() int64 {
return -2
}
}

func (t *http2Client) stateForTesting() transportState {
t.mu.Lock()
defer t.mu.Unlock()
return t.state
}
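The NewStream change makes the drain decision inside the queued checkForStreamQuota closure, where the stream ID is assigned, and only calls GracefulClose after the stream has been created successfully. A minimal sketch of just the allocation/threshold piece, with hypothetical names (idAllocator is not a real gRPC type), mirroring the unit test below that sets MaxStreamID to 3:

```go
package main

import "fmt"

// idAllocator models the client-side stream ID counter: IDs are odd and
// advance by 2, and crossing the maximum means the transport should drain.
type idAllocator struct {
	nextID uint32
	max    uint32
}

// allocate hands out the next stream ID and reports whether the transport
// should gracefully drain once the new stream is established.
func (a *idAllocator) allocate() (id uint32, drainAfter bool) {
	id = a.nextID
	a.nextID += 2
	return id, a.nextID > a.max
}

func main() {
	a := &idAllocator{nextID: 1, max: 3}
	for i := 0; i < 2; i++ {
		id, drain := a.allocate()
		fmt.Printf("stream ID %d, drain after: %v\n", id, drain)
	}
	// Output:
	// stream ID 1, drain after: false
	// stream ID 3, drain after: true
}
```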
46 changes: 46 additions & 0 deletions internal/transport/transport_test.go
@@ -536,6 +536,52 @@ func (s) TestInflightStreamClosing(t *testing.T) {
}
}

// Tests that when streamID > MaxStreamID, the current client transport drains.
func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
defer cancel()
defer server.stop()
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Small",
}

originalMaxStreamID := MaxStreamID
MaxStreamID = 3
defer func() {
MaxStreamID = originalMaxStreamID
}()

ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()

s, err := ct.NewStream(ctx, callHdr)
if err != nil {
t.Fatalf("ct.NewStream() = %v", err)
}
if s.id != 1 {
t.Fatalf("Stream id: %d, want: 1", s.id)
}

if got, want := ct.stateForTesting(), reachable; got != want {
t.Fatalf("Client transport state %v, want %v", got, want)
}

// The expected stream ID here is 3 since stream IDs are incremented by 2.
s, err = ct.NewStream(ctx, callHdr)
if err != nil {
t.Fatalf("ct.NewStream() = %v", err)
}
if s.id != 3 {
t.Fatalf("Stream id: %d, want: 3", s.id)
}

// Verifying that ct.state is draining when next stream ID > MaxStreamID.
if got, want := ct.stateForTesting(), draining; got != want {
t.Fatalf("Client transport state %v, want %v", got, want)
}
}

func (s) TestClientSendAndReceive(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
defer cancel()
154 changes: 154 additions & 0 deletions test/transport_test.go
@@ -0,0 +1,154 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package test

import (
"context"
"io"
"net"
"sync"
"testing"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)

// connWrapperWithCloseCh wraps a net.Conn and pushes on a channel when closed.
type connWrapperWithCloseCh struct {
net.Conn
close *grpcsync.Event
}

// Close closes the connection and sends a value on the close channel.
Member: "...and fires the close event" now.

Member Author: done

func (cw *connWrapperWithCloseCh) Close() error {
err := cw.Conn.Close()
cw.close.Fire()
return err
}

// These custom creds are used for storing the connections made by the client.
// The closeCh in conn can be used to detect when conn is closed.
type transportRestartCheckCreds struct {
mu sync.Mutex
connections []*connWrapperWithCloseCh
}

func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, nil, nil
}
func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
c.mu.Lock()
defer c.mu.Unlock()
conn := &connWrapperWithCloseCh{Conn: rawConn, close: grpcsync.NewEvent()}
c.connections = append(c.connections, conn)
return conn, nil, nil
}
func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials {
return c
}
func (c *transportRestartCheckCreds) OverrideServerName(s string) error {
return nil
}

// Tests that the client transport drains and restarts when next stream ID exceeds
// MaxStreamID. This test also verifies that subsequent RPCs use a new client
// transport and the old transport is closed.
func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
// Set the transport's MaxStreamID to 4 to cause connection to drain after 2 RPCs.
originalMaxStreamID := transport.MaxStreamID
transport.MaxStreamID = 4
defer func() {
transport.MaxStreamID = originalMaxStreamID
}()

ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
if _, err := stream.Recv(); err != nil {
return status.Errorf(codes.Internal, "unexpected error receiving: %v", err)
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil {
return status.Errorf(codes.Internal, "unexpected error sending: %v", err)
}
if recv, err := stream.Recv(); err != io.EOF {
return status.Errorf(codes.Internal, "Recv = %v, %v; want _, io.EOF", recv, err)
}
return nil
},
}

creds := &transportRestartCheckCreds{}
if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil {
t.Fatalf("Starting stubServer: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

var streams []testpb.TestService_FullDuplexCallClient

const numStreams = 3
// Expected number of conns when each stream is created, i.e., the 3rd stream
// is created on a new connection.
expectedNumConns := [numStreams]int{1, 1, 2}

// Set up 3 streams.
for i := 0; i < numStreams; i++ {
s, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Creating FullDuplex stream: %v", err)
}
streams = append(streams, s)
// Verify expected num of conns after each stream is created.
if len(creds.connections) != expectedNumConns[i] {
t.Fatalf("Got number of connections created: %v, want: %v", len(creds.connections), expectedNumConns[i])
}
}

// Verify all streams still work.
for i, stream := range streams {
if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Fatalf("Sending on stream %d: %v", i, err)
}
if _, err := stream.Recv(); err != nil {
t.Fatalf("Receiving on stream %d: %v", i, err)
}
}

for i, stream := range streams {
if err := stream.CloseSend(); err != nil {
t.Fatalf("CloseSend() on stream %d: %v", i, err)
}
}

// Verifying first connection was closed.
select {
case <-creds.connections[0].close.Done():
case <-ctx.Done():
t.Fatal("Timeout expired when waiting for first client transport to close")
}
}
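Why expectedNumConns is {1, 1, 2}: with MaxStreamID = 4, the first connection can host stream IDs 1 and 3; assigning ID 3 advances nextID to 5, which exceeds the maximum and triggers the graceful drain, so the third RPC dials a new connection. A standalone walk-through of that arithmetic (a simplified model, not the test code):

```go
package main

import "fmt"

func main() {
	const maxStreamID = 4
	conns := 1
	nextID := uint32(1) // client stream IDs are odd: 1, 3, 5, ...
	for stream := 1; stream <= 3; stream++ {
		if nextID > maxStreamID {
			// The previous connection drained; a fresh connection
			// (with a fresh ID space) serves this RPC.
			conns++
			nextID = 1
		}
		id := nextID
		nextID += 2
		fmt.Printf("stream %d: conn %d, stream ID %d\n", stream, conns, id)
	}
	// Output:
	// stream 1: conn 1, stream ID 1
	// stream 2: conn 1, stream ID 3
	// stream 3: conn 2, stream ID 1
}
```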