address comments
arvindbr8 committed Dec 28, 2022
1 parent 330e01e commit 97fe966
Showing 3 changed files with 46 additions and 49 deletions.
13 changes: 4 additions & 9 deletions internal/transport/http2_client.go
@@ -785,9 +785,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 	h.streamID = t.nextID
 	t.nextID += 2
 
-	// drain client transport if nextID > MaxStreamID, which is set to
-	// 75% of math.MaxUint32, which signals gRPC that the connection is closed
-	// and a new one must be created for subsequent RPCs.
+	// Drain client transport if nextID > MaxStreamID, which signals gRPC that
+	// the connection is closed and a new one must be created for subsequent RPCs.
 	transportDrainRequired = t.nextID > MaxStreamID
 
 	s.id = h.streamID
@@ -1797,12 +1796,8 @@ func (t *http2Client) getOutFlowWindow() int64 {
 	}
 }
 
-func (t *http2Client) compareStateForTesting(s transportState) error {
+func (t *http2Client) stateForTesting() transportState {
 	t.mu.Lock()
 	defer t.mu.Unlock()
-	if t.state != s {
-		return fmt.Errorf("clientTransport.state: %v, want: %v", t.state, s)
-	}
-
-	return nil
+	return t.state
 }
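
For context on this hunk, here is a minimal, self-contained Go sketch of the stream-ID arithmetic the drain check relies on. It is not the gRPC-Go implementation — the transport type, reserveStreamID, and the tiny maxStreamID cutoff are illustrative stand-ins (the old comment puts the real constant at roughly 75% of math.MaxUint32) — but it shows why, with the cutoff shrunk to 3 as the test below does, stream 1 leaves the transport usable while stream 3 already forces a drain.

package main

import "fmt"

// maxStreamID is a stand-in for the transport's MaxStreamID cutoff. It is a
// variable so this demo, like the test in transport_test.go, can shrink it
// to force ID exhaustion quickly.
var maxStreamID = uint32(3)

// transport is a hypothetical stand-in holding only the ID counter.
type transport struct {
	nextID uint32 // next client-initiated stream ID; always odd
}

// reserveStreamID hands out the next stream ID and reports whether the
// transport must drain because the following ID would exceed the cutoff.
func (t *transport) reserveStreamID() (id uint32, drain bool) {
	id = t.nextID
	t.nextID += 2 // client-initiated stream IDs are odd: 1, 3, 5, ...
	return id, t.nextID > maxStreamID
}

func main() {
	t := &transport{nextID: 1}
	for i := 0; i < 3; i++ {
		id, drain := t.reserveStreamID()
		fmt.Printf("stream id=%d drainRequired=%v\n", id, drain)
	}
	// Prints:
	// stream id=1 drainRequired=false
	// stream id=3 drainRequired=true
	// stream id=5 drainRequired=true
}

Note that a drain does not fail the stream that triggered it: the stream with ID 3 is still created, but the connection is marked so that subsequent RPCs go out on a fresh transport, which is what the tests below verify.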
16 changes: 8 additions & 8 deletions internal/transport/transport_test.go
@@ -536,7 +536,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {
 	}
 }
 
-// tests that when streamID > MaxStreamId, the current client transport drains.
+// Tests that when streamID > MaxStreamID, the current client transport drains.
 func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
 	defer cancel()
@@ -545,7 +545,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 		Host:   "localhost",
 		Method: "foo.Small",
 	}
-	// override MaxStreamID.
+
 	originalMaxStreamID := MaxStreamID
 	MaxStreamID = 3
 	defer func() {
@@ -563,10 +563,11 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 		t.Fatalf("stream id: %d, want: 1", s.id)
 	}
 
-	if err := ct.compareStateForTesting(reachable); err != nil {
-		t.Fatal(err)
+	if got, want := ct.stateForTesting(), reachable; got != want {
+		t.Fatalf("client transport state %v, want %v", got, want)
 	}
 
+	// The expected stream ID here is 3 since stream IDs are incremented by 2.
 	s, err = ct.NewStream(ctx, callHdr)
 	if err != nil {
 		t.Fatalf("ct.NewStream() = %v", err)
@@ -575,11 +575,10 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 		t.Fatalf("stream id: %d, want: 3", s.id)
 	}
 
-	// verifying that ct.state is draining when next stream ID > MaxStreamId.
-	if err := ct.compareStateForTesting(draining); err != nil {
-		t.Fatal(err)
+	// Verifying that ct.state is draining when next stream ID > MaxStreamID.
+	if got, want := ct.stateForTesting(), draining; got != want {
+		t.Fatalf("client transport state %v, want %v", got, want)
 	}
-
 }
 
 func (s) TestClientSendAndReceive(t *testing.T) {
66 changes: 34 additions & 32 deletions test/end2end_test.go
@@ -7032,32 +7032,34 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) {
 	}()
 }
 
-type testAuthInfo struct {
+// Implementing AuthInfo with net.Conn to send conn info up to Peer.
+type authInfoWithConn struct {
 	credentials.CommonAuthInfo
 	connection net.Conn
 }
 
-func (ta *testAuthInfo) AuthType() string {
+func (ai *authInfoWithConn) AuthType() string {
 	return ""
 }
 
-// ConnWrapper wraps a net.Conn and pushes on a channel when closed.
-type ConnWrapper struct {
+// connWrapperWithCloseCh wraps a net.Conn and pushes on a channel when closed.
+type connWrapperWithCloseCh struct {
 	net.Conn
 	CloseCh *testutils.Channel
 }
 
 // Close closes the connection and sends a value on the close channel.
-func (cw *ConnWrapper) Close() error {
+func (cw *connWrapperWithCloseCh) Close() error {
 	err := cw.Conn.Close()
 	cw.CloseCh.Replace(nil)
 	return err
 }
 
+// These custom creds are used for storing the connections made by the client.
+// The CloseCh in conn can be used to detect when conn is closed.
 type transportRestartCheckCreds struct {
 	mu sync.Mutex
-	connections []ConnWrapper
 	credentials.TransportCredentials
+	connections []connWrapperWithCloseCh
 }
 
 func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
@@ -7066,10 +7068,9 @@ func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
-	closeCh := testutils.NewChannel()
-	conn := &ConnWrapper{Conn: rawConn, CloseCh: closeCh}
+	conn := &connWrapperWithCloseCh{Conn: rawConn, CloseCh: testutils.NewChannel()}
 	c.connections = append(c.connections, *conn)
-	authInfo := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}, connection: conn}
+	authInfo := &authInfoWithConn{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}, connection: conn}
 	return conn, authInfo, nil
 }
 func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo {
@@ -7090,17 +7091,16 @@ type streamWithPeer struct {
 }
 
 func (s *streamWithPeer) sendAndReceive() error {
-	req := &testpb.StreamingOutputCallRequest{}
-	if err := s.stream.Send(req); err != nil {
-		return fmt.Errorf("sending on stream1: %v", err)
+	if err := s.stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
+		return fmt.Errorf("sending on %s: %v", s.id, err)
 	}
 	if _, err := s.stream.Recv(); err != nil {
-		return fmt.Errorf("receiving on stream1: %v", err)
+		return fmt.Errorf("receiving on %s: %v", s.id, err)
 	}
 	return nil
 }
 
-// tests that the client transport drains and restarts when next stream ID exceeds
+// Tests that the client transport drains and restarts when next stream ID exceeds
 // MaxStreamID. This test also verifies that subsequent RPCs use a new client
 // transport and the old transport is closed.
 func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
@@ -7117,17 +7117,14 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
 				if _, err := stream.Recv(); err != nil {
 					return err
 				}
-
-				res := &testpb.StreamingOutputCallResponse{}
-				if err := stream.Send(res); err != nil {
+				if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil {
 					return err
 				}
 			}
 		},
 	}
 
 	creds := &transportRestartCheckCreds{}
-
 	if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil {
 		t.Fatalf("starting stubServer: %v", err)
 	}
@@ -7138,8 +7135,8 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
 
 	var streams []streamWithPeer
 
-	// setting up 3 streams and calling Send() and Recv() once on each.
-	for i := 0; i <= 2; i++ {
+	// Setting up 3 streams and calling Send() and Recv() once on each.
+	for i := 0; i < 3; i++ {
 		var p peer.Peer
 		s, err := ss.Client.FullDuplexCall(ctx, grpc.Peer(&p))
 		if err != nil {
@@ -7153,45 +7150,50 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
 			t.Fatal(err)
 		}
 	}
-	// verify only 2 connections are created so far
+	// Verifying only 2 connections are created so far.
 	if len(creds.connections) != 2 {
 		t.Fatalf("number of connections created: %v, want: 2", len(creds.connections))
 	}
 
-	// verifying that streams on the first conn still works.
+	// Verifying that streams on the first conn still work.
 	if err := streams[0].sendAndReceive(); err != nil {
 		t.Fatal(err)
 	}
 	if err := streams[1].sendAndReceive(); err != nil {
 		t.Fatal(err)
 	}
 
-	// closing all streams to get AuthInfo in Peer
+	// The peer passed via the call option is set up only after the RPC is complete,
+	// which is why we wait here until Recv() returns EOF.
 	for i := 0; i < 3; i++ {
 		if err := streams[i].stream.CloseSend(); err != nil {
 			t.Fatalf("CloseSend() on %s: %v", streams[i].id, err)
 		}
-		if _, err := streams[i].stream.Recv(); !strings.Contains(err.Error(), "EOF") {
-			t.Fatalf("receiving on %s got: %v, want: EOF", streams[i].id, err)
-		} else {
-			streams[i].conn = streams[i].peer.AuthInfo.(*testAuthInfo).connection
-		}
+		for {
+			_, err := streams[i].stream.Recv()
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				t.Fatalf("receiving on %s got: %v, want: EOF", streams[i].id, err)
+			}
+		}
+		streams[i].conn = streams[i].peer.AuthInfo.(*authInfoWithConn).connection
 	}
 
-	// verifying that the 3rd RPC was made on a different connection.
+	// Verifying the first and second RPCs were made on the same connection.
 	if streams[0].conn != streams[1].conn {
 		t.Fatal("got streams using different connections; want same.")
 	}
 
+	// Verifying the third and first/second RPCs were made on different connections.
 	if streams[2].conn == streams[0].conn {
 		t.Fatal("got streams using same connections; want different.")
 	}
 
-	// verifying if first connection was closed.
+	// Verifying first connection was closed.
 	if _, err := creds.connections[0].CloseCh.Receive(ctx); err != nil {
 		t.Fatal("timeout expired when waiting for first client transport to close")
 	}
-
 }
 
 func doHTTPHeaderTest(t *testing.T, errCode codes.Code, headerFields ...[]string) error {
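
The transportRestartCheckCreds pattern above — wrap every client connection so its Close becomes observable — can be isolated into a few lines. The sketch below is a stand-in under stated assumptions (a plain channel guarded by sync.Once instead of testutils.Channel, net.Pipe instead of a dialed gRPC connection, and the name closeNotifyConn is hypothetical); it shows how a test can block until the old transport's connection has actually been torn down, which is what CloseCh.Receive(ctx) does in the test above.

package main

import (
	"fmt"
	"net"
	"sync"
)

// closeNotifyConn mirrors the connWrapperWithCloseCh idea: it embeds a
// net.Conn and signals on a channel the first time Close is called.
type closeNotifyConn struct {
	net.Conn
	once   sync.Once
	closed chan struct{}
}

func (c *closeNotifyConn) Close() error {
	err := c.Conn.Close()
	c.once.Do(func() { close(c.closed) }) // tolerate repeated Close calls
	return err
}

func main() {
	// A loopback pipe stands in for the transport's TCP connection.
	client, server := net.Pipe()
	defer server.Close()

	conn := &closeNotifyConn{Conn: client, closed: make(chan struct{})}

	// Elsewhere, the old drained transport eventually closes its connection.
	go conn.Close()

	<-conn.closed // block, as CloseCh.Receive(ctx) does, until Close has run
	fmt.Println("observed connection close")
}

In the test above, the credentials implementation records each wrapped conn during ClientHandshake, so the test can later pick out creds.connections[0] and wait on its close signal to prove the first transport was torn down.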
