New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
transport: drain client transport when streamID approaches maxStreamID #5889
Changes from all commits
32f93bf
ae31b25
4fe3f5e
1609a2a
0f55aca
cd1bd83
db42893
fac885e
b9b8465
560b3dd
41d1e77
84d7472
8d92634
6b93023
5df4d0f
9d3bada
e02ee57
5b1aad6
d9851b9
6749314
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,3 +47,8 @@ const ( | |
defaultClientMaxHeaderListSize = uint32(16 << 20) | ||
defaultServerMaxHeaderListSize = uint32(16 << 20) | ||
) | ||
|
||
// MaxStreamID is the upper bound for the stream ID before the current | ||
// transport gracefully closes and a new transport is created for subsequent RPCs. | ||
// This is set to 75% of math.MaxUint32. It's exported so that tests can override it. | ||
var MaxStreamID = uint32(3_221_225_472) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I'm not sure this will actually prevent this issue from happening, as the overflow would occur before There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @dfawley There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The spec:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @patrick-ogrady! Great catch. I'm going to fix that |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -742,15 +742,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, | |
endStream: false, | ||
initStream: func(id uint32) error { | ||
t.mu.Lock() | ||
if state := t.state; state != reachable { | ||
// TODO: handle transport closure in loopy instead and remove this | ||
// initStream is never called when transport is draining. | ||
if t.state == closing { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
t.mu.Unlock() | ||
// Do a quick cleanup. | ||
err := error(errStreamDrain) | ||
if state == closing { | ||
err = ErrConnClosing | ||
} | ||
cleanup(err) | ||
return err | ||
cleanup(ErrConnClosing) | ||
return ErrConnClosing | ||
} | ||
if channelz.IsOn() { | ||
atomic.AddInt64(&t.czData.streamsStarted, 1) | ||
|
@@ -768,6 +765,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, | |
} | ||
firstTry := true | ||
var ch chan struct{} | ||
transportDrainRequired := false | ||
checkForStreamQuota := func(it interface{}) bool { | ||
if t.streamQuota <= 0 { // Can go negative if server decreases it. | ||
if firstTry { | ||
|
@@ -783,6 +781,11 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, | |
h := it.(*headerFrame) | ||
h.streamID = t.nextID | ||
t.nextID += 2 | ||
|
||
// Drain client transport if nextID > MaxStreamID, which signals gRPC that | ||
// the connection is closed and a new one must be created for subsequent RPCs. | ||
transportDrainRequired = t.nextID > MaxStreamID | ||
|
||
s.id = h.streamID | ||
s.fc = &inFlow{limit: uint32(t.initialWindowSize)} | ||
t.mu.Lock() | ||
|
@@ -862,6 +865,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, | |
sh.HandleRPC(s.ctx, outHeader) | ||
} | ||
} | ||
if transportDrainRequired { | ||
if logger.V(logLevel) { | ||
logger.Infof("transport: t.nextID > MaxStreamID. Draining") | ||
} | ||
t.GracefulClose() | ||
} | ||
return s, nil | ||
} | ||
|
||
|
@@ -1783,3 +1792,9 @@ func (t *http2Client) getOutFlowWindow() int64 { | |
return -2 | ||
} | ||
} | ||
|
||
func (t *http2Client) stateForTesting() transportState { | ||
t.mu.Lock() | ||
defer t.mu.Unlock() | ||
return t.state | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
/* | ||
* | ||
* Copyright 2023 gRPC authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* | ||
*/ | ||
package test | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"net" | ||
"sync" | ||
"testing" | ||
|
||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/credentials" | ||
"google.golang.org/grpc/internal/grpcsync" | ||
"google.golang.org/grpc/internal/stubserver" | ||
"google.golang.org/grpc/internal/transport" | ||
"google.golang.org/grpc/status" | ||
testpb "google.golang.org/grpc/test/grpc_testing" | ||
) | ||
|
||
// connWrapperWithCloseCh wraps a net.Conn and fires an event when closed. | ||
type connWrapperWithCloseCh struct { | ||
net.Conn | ||
close *grpcsync.Event | ||
} | ||
|
||
// Close closes the connection and fires the close event. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
func (cw *connWrapperWithCloseCh) Close() error { | ||
cw.close.Fire() | ||
return cw.Conn.Close() | ||
} | ||
|
||
// These custom creds are used for storing the connections made by the client. | ||
// The close event in each conn can be used to detect when the conn is closed. | ||
type transportRestartCheckCreds struct { | ||
mu sync.Mutex | ||
connections []*connWrapperWithCloseCh | ||
} | ||
|
||
func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { | ||
return rawConn, nil, nil | ||
} | ||
func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
conn := &connWrapperWithCloseCh{Conn: rawConn, close: grpcsync.NewEvent()} | ||
c.connections = append(c.connections, conn) | ||
return conn, nil, nil | ||
} | ||
func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo { | ||
return credentials.ProtocolInfo{} | ||
} | ||
func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials { | ||
return c | ||
} | ||
func (c *transportRestartCheckCreds) OverrideServerName(s string) error { | ||
return nil | ||
} | ||
|
||
// Tests that the client transport drains and restarts when the next stream ID exceeds | ||
// MaxStreamID. This test also verifies that subsequent RPCs use a new client | ||
// transport and the old transport is closed. | ||
func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { | ||
// Set the transport's MaxStreamID to 4 to cause the connection to drain after 2 RPCs. | ||
originalMaxStreamID := transport.MaxStreamID | ||
transport.MaxStreamID = 4 | ||
defer func() { | ||
transport.MaxStreamID = originalMaxStreamID | ||
}() | ||
|
||
ss := &stubserver.StubServer{ | ||
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { | ||
if _, err := stream.Recv(); err != nil { | ||
return status.Errorf(codes.Internal, "unexpected error receiving: %v", err) | ||
} | ||
if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil { | ||
return status.Errorf(codes.Internal, "unexpected error sending: %v", err) | ||
} | ||
if recv, err := stream.Recv(); err != io.EOF { | ||
return status.Errorf(codes.Internal, "Recv = %v, %v; want _, io.EOF", recv, err) | ||
} | ||
return nil | ||
}, | ||
} | ||
|
||
creds := &transportRestartCheckCreds{} | ||
if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil { | ||
t.Fatalf("Starting stubServer: %v", err) | ||
} | ||
defer ss.Stop() | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) | ||
defer cancel() | ||
|
||
var streams []testpb.TestService_FullDuplexCallClient | ||
|
||
const numStreams = 3 | ||
// expected number of conns when each stream is created, i.e., the 3rd stream is created | ||
// on a new connection. | ||
expectedNumConns := [numStreams]int{1, 1, 2} | ||
|
||
// Set up 3 streams. | ||
for i := 0; i < numStreams; i++ { | ||
s, err := ss.Client.FullDuplexCall(ctx) | ||
if err != nil { | ||
t.Fatalf("Creating FullDuplex stream: %v", err) | ||
} | ||
streams = append(streams, s) | ||
// Verify expected num of conns after each stream is created. | ||
if len(creds.connections) != expectedNumConns[i] { | ||
t.Fatalf("Got number of connections created: %v, want: %v", len(creds.connections), expectedNumConns[i]) | ||
} | ||
} | ||
|
||
// Verify all streams still work. | ||
for i, stream := range streams { | ||
if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil { | ||
t.Fatalf("Sending on stream %d: %v", i, err) | ||
} | ||
if _, err := stream.Recv(); err != nil { | ||
t.Fatalf("Receiving on stream %d: %v", i, err) | ||
} | ||
} | ||
|
||
for i, stream := range streams { | ||
if err := stream.CloseSend(); err != nil { | ||
t.Fatalf("CloseSend() on stream %d: %v", i, err) | ||
} | ||
} | ||
|
||
// Verify that the first connection was closed. | ||
select { | ||
case <-creds.connections[0].close.Done(): | ||
case <-ctx.Done(): | ||
t.Fatal("Timeout expired when waiting for first client transport to close") | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// TODO: provide a better error with the reason we are in draining.
e.g.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done