diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index c977b88c524d..a8a9e4d7b2aa 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -785,9 +785,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, h.streamID = t.nextID t.nextID += 2 - // drain client transport if nextID > MaxStreamID, which is set to - // 75% of math.MaxUint32, which signals gRPC that the connection is closed - // and a new one must be created for subsequent RPCs. + // Drain client transport if nextID > MaxStreamID which signals gRPC that + // the connection is closed and a new one must be created for subsequent RPCs. transportDrainRequired = t.nextID > MaxStreamID s.id = h.streamID @@ -1797,12 +1796,8 @@ func (t *http2Client) getOutFlowWindow() int64 { } } -func (t *http2Client) compareStateForTesting(s transportState) error { +func (t *http2Client) stateForTesting() transportState { t.mu.Lock() defer t.mu.Unlock() - if t.state != s { - return fmt.Errorf("clientTransport.state: %v, want: %v", t.state, s) - } - - return nil + return t.state } diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index ab20294ef744..70bc34c7512e 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -536,7 +536,7 @@ func (s) TestInflightStreamClosing(t *testing.T) { } } -// tests that when streamID > MaxStreamId, the current client transport drains. +// Tests that when streamID > MaxStreamId, the current client transport drains. func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) defer cancel() @@ -545,7 +545,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { Host: "localhost", Method: "foo.Small", } - // override MaxStreamID. + originalMaxStreamID := MaxStreamID MaxStreamID = 3 defer func() { @@ -563,10 +563,11 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { t.Fatalf("stream id: %d, want: 1", s.id) } - if err := ct.compareStateForTesting(reachable); err != nil { - t.Fatal(err) + if got, want := ct.stateForTesting(), reachable; got != want { + t.Fatalf("client transport state %v, want %v", got, want) } + // The expected stream ID here is 3 since stream IDs are incremented by 2. s, err = ct.NewStream(ctx, callHdr) if err != nil { t.Fatalf("ct.NewStream() = %v", err) @@ -575,11 +576,10 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { t.Fatalf("stream id: %d, want: 3", s.id) } - // verifying that ct.state is draining when next stream ID > MaxStreamId. - if err := ct.compareStateForTesting(draining); err != nil { - t.Fatal(err) + // Verifying that ct.state is draining when next stream ID > MaxStreamId. + if got, want := ct.stateForTesting(), draining; got != want { + t.Fatalf("client transport state %v, want %v", got, want) } - } func (s) TestClientSendAndReceive(t *testing.T) { diff --git a/test/end2end_test.go b/test/end2end_test.go index a2d5f19bcd32..c4232b5f2a2e 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7032,32 +7032,34 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) { }() } -type testAuthInfo struct { +// Implementing AuthInfo with net.Conn to send conn info up to Peer. +type authInfoWithConn struct { credentials.CommonAuthInfo connection net.Conn } -func (ta *testAuthInfo) AuthType() string { +func (ai *authInfoWithConn) AuthType() string { return "" } -// ConnWrapper wraps a net.Conn and pushes on a channel when closed. -type ConnWrapper struct { +// connWrapperWithCloseCh wraps a net.Conn and pushes on a channel when closed. +type connWrapperWithCloseCh struct { net.Conn CloseCh *testutils.Channel } // Close closes the connection and sends a value on the close channel. -func (cw *ConnWrapper) Close() error { +func (cw *connWrapperWithCloseCh) Close() error { err := cw.Conn.Close() cw.CloseCh.Replace(nil) return err } +// This custom creds are used for storing the connections made by the client. +// The CloseCh in conn can be used to detect when conn is closed. type transportRestartCheckCreds struct { mu sync.Mutex - connections []ConnWrapper - credentials.TransportCredentials + connections []connWrapperWithCloseCh } func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { @@ -7066,10 +7068,9 @@ func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { c.mu.Lock() defer c.mu.Unlock() - closeCh := testutils.NewChannel() - conn := &ConnWrapper{Conn: rawConn, CloseCh: closeCh} + conn := &connWrapperWithCloseCh{Conn: rawConn, CloseCh: testutils.NewChannel()} c.connections = append(c.connections, *conn) - authInfo := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}, connection: conn} + authInfo := &authInfoWithConn{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}, connection: conn} return conn, authInfo, nil } func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo { @@ -7090,17 +7091,16 @@ type streamWithPeer struct { } func (s *streamWithPeer) sendAndReceive() error { - req := &testpb.StreamingOutputCallRequest{} - if err := s.stream.Send(req); err != nil { - return fmt.Errorf("sending on stream1: %v", err) + if err := s.stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + return fmt.Errorf("sending on %s: %v", s.id, err) } if _, err := s.stream.Recv(); err != nil { - return fmt.Errorf("receiving on stream1: %v", err) + return fmt.Errorf("receiving on %s: %v", s.id, err) } return nil } -// tests that the client transport drains and restarts when next stream ID exceeds +// Tests that the client transport drains and restarts when next stream ID exceeds // MaxStreamID. This test also verifies that subsequent RPCs use a new client // transport and the old transport is closed. func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { @@ -7117,9 +7117,7 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { if _, err := stream.Recv(); err != nil { return err } - - res := &testpb.StreamingOutputCallResponse{} - if err := stream.Send(res); err != nil { + if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil { return err } } @@ -7127,7 +7125,6 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { } creds := &transportRestartCheckCreds{} - if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil { t.Fatalf("starting stubServer: %v", err) } @@ -7138,8 +7135,8 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { var streams []streamWithPeer - // setting up 3 streams and calling Send() and Recv() once on each. - for i := 0; i <= 2; i++ { + // Setting up 3 streams and calling Send() and Recv() once on each. + for i := 0; i < 3; i++ { var p peer.Peer s, err := ss.Client.FullDuplexCall(ctx, grpc.Peer(&p)) if err != nil { @@ -7153,12 +7150,12 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { t.Fatal(err) } } - // verify only 2 connections are created so far + // Verifying only 2 connections are created so far. if len(creds.connections) != 2 { t.Fatalf("number of connections created: %v, want: 2", len(creds.connections)) } - // verifying that streams on the first conn still works. + // Verifying that streams on the first conn still works. if err := streams[0].sendAndReceive(); err != nil { t.Fatal(err) } @@ -7166,32 +7163,37 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { t.Fatal(err) } - // closing all streams to get AuthInfo in Peer + // The peer passed via the call option is set up only after the RPC is complete, + // which is why we wait here until Recv() returns EOF. for i := 0; i < 3; i++ { if err := streams[i].stream.CloseSend(); err != nil { t.Fatalf("CloseSend() on %s: %v", streams[i].id, err) } - if _, err := streams[i].stream.Recv(); !strings.Contains(err.Error(), "EOF") { - t.Fatalf("receiving on %s got: %v, want: EOF", streams[i].id, err) - } else { - streams[i].conn = streams[i].peer.AuthInfo.(*testAuthInfo).connection + for { + _, err := streams[i].stream.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("receiving on %s got: %v, want: EOF", streams[i].id, err) + } } + streams[i].conn = streams[i].peer.AuthInfo.(*authInfoWithConn).connection } - // verifying that the 3rd RPC was made on a different connection. + // Verifying the first and second RPCs were made on the same connection. if streams[0].conn != streams[1].conn { t.Fatal("got streams using different connections; want same.") } - + // Verifying the third and first/second RPCs were made on different connections. if streams[2].conn == streams[0].conn { t.Fatal("got streams using same connections; want different.") } - // verifying if first connection was closed. + // Verifying first connection was closed. if _, err := creds.connections[0].CloseCh.Receive(ctx); err != nil { t.Fatal("timeout expired when waiting for first client transport to close") } - } func doHTTPHeaderTest(t *testing.T, errCode codes.Code, headerFields ...[]string) error {