transport: drain client transport when streamID approaches maxStreamID #5889

Merged
merged 20 commits into from Jan 11, 2023
Changes from 1 commit
17 changes: 10 additions & 7 deletions internal/transport/controlbuf.go
@@ -650,16 +650,19 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error {
itl: &itemList{},
wq: h.wq,
}
str.itl.enqueue(h)
return l.originateStream(str)
return l.originateStream(str, h)
}

func (l *loopyWriter) originateStream(str *outStream) error {
hdr := str.itl.dequeue().(*headerFrame)
// originateStreamWithHeaderFrame calls the initStream function on the headerFrame and
// called writeHeader. If write succeeds the streamID is added to l.estdStreams
Member:
called->calls?

And end the last sentence with a period, please.

Member Author:
I nixed the comment itself. Seems like I really just wrote exactly what the 10 lines below do.

func (l *loopyWriter) originateStream(str *outStream, hdr *headerFrame) error {
// l.draining is set for an incomingGoAway. In which case, we want to avoid further
// writes
Member:
.

Also... we want to avoid creating new streams instead?

Member Author:
done

if l.draining {
hdr.onOrphaned(errStreamDrain)
Member:
// TODO: provide a better error with the reason we are in draining. e.g.

Member Author:
done

return nil
}
if err := hdr.initStream(str.id); err != nil {
if err == errStreamDrain { // errStreamDrain need not close transport
return nil
}
return err
}
if err := l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil {
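For reference, once the review feedback above is addressed (the redundant doc comment dropped, the draining comment reworded to talk about avoiding new streams, and the TODO added), originateStream presumably ends up roughly as sketched below. This is an approximation based on the thread, not the exact merged code.

func (l *loopyWriter) originateStream(str *outStream, hdr *headerFrame) error {
	// l.draining is set when handling an incoming GoAway. In that case, avoid
	// creating new streams on this transport and orphan the header frame.
	if l.draining {
		// TODO: provide a better error with the reason we are in draining.
		hdr.onOrphaned(errStreamDrain)
		return nil
	}
	if err := hdr.initStream(str.id); err != nil {
		return err
	}
	if err := l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil {
		return err
	}
	l.estdStreams[str.id] = str
	return nil
}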
13 changes: 5 additions & 8 deletions internal/transport/http2_client.go
@@ -742,15 +742,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
endStream: false,
initStream: func(id uint32) error {
t.mu.Lock()
if state := t.state; state != reachable {
// we want initStream to cleanup and return an error when transport is closing.
Member:
Nit: capitalize "We".

Member Author:
Nixing this comment as well. Seems like it doesn't add value when I read it now.

// initStream is never called when transport is draining.
if t.state == closing {
Member:
// TODO: handle transport closure in loopy instead and remove this. e.g.

Member Author:
done

t.mu.Unlock()
// Do a quick cleanup.
err := error(errStreamDrain)
if state == closing {
err = ErrConnClosing
}
cleanup(err)
return err
cleanup(ErrConnClosing)
return ErrConnClosing
}
if channelz.IsOn() {
atomic.AddInt64(&t.czData.streamsStarted, 1)
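The hunk above only shows the initStream cleanup; the drain trigger named in the PR title does not appear in this commit's diff. The idea is that NewStream checks the next client stream ID against transport.MaxStreamID and gracefully closes (drains) the transport once the limit would be exceeded, so later RPCs get a fresh connection. A rough sketch of that logic; the local flag transportDrainRequired and its exact placement are assumptions, not necessarily the merged code:

// Inside (*http2Client).NewStream, where the stream ID is assigned.
// Client-initiated stream IDs are odd and advance by 2 per stream.
h.streamID = t.nextID
t.nextID += 2

// If the next ID would exceed MaxStreamID, this connection must not carry
// any more new streams; remember to drain it.
transportDrainRequired := t.nextID > MaxStreamID

// ...later, once the stream has been handed off to the control buffer:
if transportDrainRequired {
	// GracefulClose puts the transport into draining: streams already in
	// flight finish, but new RPCs go out on a new transport.
	t.GracefulClose()
}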
96 changes: 26 additions & 70 deletions test/transport_end2end_test.go → test/transport_test.go
@@ -1,6 +1,6 @@
/*
*
* Copyright 2022 gRPC authors.
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,47 +27,31 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)

// authInfoWithConn wraps the underlying net.Conn, and makes it available
// to the test as part of the Peer call option.
type authInfoWithConn struct {
credentials.CommonAuthInfo
conn net.Conn
}

func (ai *authInfoWithConn) AuthType() string {
return ""
}

// connWrapperWithCloseCh wraps a net.Conn and pushes on a channel when closed.
type connWrapperWithCloseCh struct {
net.Conn
closeCh chan interface{}
close *grpcsync.Event
}

// Close closes the connection and sends a value on the close channel.
Member:
...and fires the close event. now

Member Author:
done

func (cw *connWrapperWithCloseCh) Close() error {
err := cw.Conn.Close()
for {
select {
case cw.closeCh <- nil:
return err
case <-cw.closeCh:
}
}
cw.close.Fire()
return err
Member:
Nit/optional:

    cw.close.Fire()
    return cw.Conn.Close()

Member Author:
done

}
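// For reference: grpcsync.Event (google.golang.org/grpc/internal/grpcsync) is
// a fire-once event. NewEvent creates it unfired, Fire marks it, and Done
// returns a channel that is closed once the event has fired. That is why
// Close above only needs cw.close.Fire(), and why the test below can simply
// select on creds.connections[0].close.Done(). A minimal usage sketch,
// assuming only the Fire/Done/HasFired surface:
//
//	ev := grpcsync.NewEvent()
//	go ev.Fire() // e.g. the connection being closed elsewhere
//	select {
//	case <-ev.Done():
//		// the event fired; ev.HasFired() now reports true
//	case <-time.After(time.Second):
//		// timed out waiting for the event
//	}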

// These custom creds are used for storing the connections made by the client.
// The closeCh in conn can be used to detect when conn is closed.
type transportRestartCheckCreds struct {
mu sync.Mutex
connections []connWrapperWithCloseCh
connections []*connWrapperWithCloseCh
}

func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
@@ -76,17 +60,15 @@ func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn
func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
c.mu.Lock()
defer c.mu.Unlock()
conn := &connWrapperWithCloseCh{Conn: rawConn, closeCh: make(chan interface{}, 1)}
c.connections = append(c.connections, *conn)
commonAuthInfo := credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}
authInfo := &authInfoWithConn{commonAuthInfo, conn}
return conn, authInfo, nil
conn := &connWrapperWithCloseCh{Conn: rawConn, close: grpcsync.NewEvent()}
c.connections = append(c.connections, conn)
return conn, nil, nil
}
func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials {
return &transportRestartCheckCreds{}
return c
}
func (c *transportRestartCheckCreds) OverrideServerName(s string) error {
return nil
@@ -96,22 +78,20 @@ func (c *transportRestartCheckCreds) OverrideServerName(s string) error {
// MaxStreamID. This test also verifies that subsequent RPCs use a new client
// transport and the old transport is closed.
func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
// Set the transport's MaxStreamID to 5 to cause connection to drain after 2 RPCs.
// Set the transport's MaxStreamID to 4 to cause connection to drain after 2 RPCs.
originalMaxStreamID := transport.MaxStreamID
transport.MaxStreamID = 5
transport.MaxStreamID = 4
defer func() {
transport.MaxStreamID = originalMaxStreamID
}()

ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
for i := 0; i < 2; i++ {
if _, err := stream.Recv(); err != nil {
return status.Errorf(codes.Internal, "unexpected error receiving: %v", err)
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil {
return status.Errorf(codes.Internal, "unexpected error sending: %v", err)
}
if _, err := stream.Recv(); err != nil {
return status.Errorf(codes.Internal, "unexpected error receiving: %v", err)
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil {
return status.Errorf(codes.Internal, "unexpected error sending: %v", err)
}
if recv, err := stream.Recv(); err != io.EOF {
return status.Errorf(codes.Internal, "Recv = %v, %v; want _, io.EOF", recv, err)
@@ -131,27 +111,21 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {

var streams []testpb.TestService_FullDuplexCallClient

// expectedNumConns when each stream is created.
expectedNumConns := []int{1, 1, 2}
const numStreams = 3
// expected number of conns when each stream is created i.e., 3rd stream is created
// on a new connection.
expectedNumConns := [numStreams]int{1, 1, 2}

// Set up 3 streams and call sendAndReceive() once on each.
for i := 0; i < 3; i++ {
// Set up 3 streams.
for i := 0; i < numStreams; i++ {
s, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Creating FullDuplex stream: %v", err)
}

streams = append(streams, s)
if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Fatalf("Sending on stream %d: %v", i, err)
}
if _, err := s.Recv(); err != nil {
t.Fatalf("Receiving on stream %d: %v", i, err)
}

// Verify expected num of conns.
// Verify expected num of conns after each stream is created.
if len(creds.connections) != expectedNumConns[i] {
t.Fatalf("Number of connections created: %v, want: %v", len(creds.connections), expectedNumConns[i])
t.Fatalf("Got number of connections created: %v, want: %v", len(creds.connections), expectedNumConns[i])
}
}

@@ -165,33 +139,15 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
}
}

var connPerStream []net.Conn

// The peer passed via the call option is set up only after the RPC is complete.
// Conn used by the stream is available in authInfo.
for i, stream := range streams {
if err := stream.CloseSend(); err != nil {
t.Fatalf("CloseSend() on stream %d: %v", i, err)
}
p, ok := peer.FromContext(stream.Context())
if !ok {
t.Fatalf("Getting peer from stream context for stream %d", i)
}
connPerStream = append(connPerStream, p.AuthInfo.(*authInfoWithConn).conn)
}

// Verifying the first and second RPCs were made on the same connection.
if connPerStream[0] != connPerStream[1] {
t.Fatal("Got streams using different connections; want same.")
}
// Verifying the third and first/second RPCs were made on different connections.
if connPerStream[2] == connPerStream[0] {
t.Fatal("Got streams using same connections; want different.")
}

// Verifying first connection was closed.
select {
case <-creds.connections[0].closeCh:
case <-creds.connections[0].close.Done():
case <-ctx.Done():
t.Fatal("Timeout expired when waiting for first client transport to close")
}
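A note on the numbers in the test: client stream IDs are odd (1, 3, 5, ...), so with MaxStreamID overridden to 4 the first two RPCs use IDs 1 and 3; assigning ID 3 advances the next ID to 5, which exceeds 4 and drains the transport, so the third RPC is created on a new connection (hence expectedNumConns of {1, 1, 2}). A tiny standalone sketch of that arithmetic:

package main

import "fmt"

// Toy model of the stream-ID bookkeeping the test relies on. With the
// transport's MaxStreamID overridden to 4, the connection drains after the
// second RPC, and the third RPC has to go out on a brand-new transport.
func main() {
	const maxStreamID uint32 = 4
	nextID := uint32(1) // client-initiated stream IDs start at 1 and are odd

	for rpc := 1; rpc <= 2; rpc++ {
		id := nextID
		nextID += 2
		fmt.Printf("RPC %d: streamID=%d, drain transport afterwards=%v\n",
			rpc, id, nextID > maxStreamID)
	}
	// Output:
	// RPC 1: streamID=1, drain transport afterwards=false
	// RPC 2: streamID=3, drain transport afterwards=true
}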