From 4b1c7cdf712ba7d558caba1124d5d6f45956169c Mon Sep 17 00:00:00 2001
From: Arvind Bright
Date: Tue, 27 Dec 2022 20:19:53 -0600
Subject: [PATCH] use Peer.AuthInfo to verify connections

---
 internal/transport/defaults.go       |   6 +-
 internal/transport/http2_client.go   |  26 ++--
 internal/transport/transport_test.go |  30 +++--
 test/end2end_test.go                 | 180 ++++++++++++++++++++-------
 4 files changed, 174 insertions(+), 68 deletions(-)

diff --git a/internal/transport/defaults.go b/internal/transport/defaults.go
index ddfa4fee5b23..7579624a2532 100644
--- a/internal/transport/defaults.go
+++ b/internal/transport/defaults.go
@@ -48,7 +48,7 @@ const (
 	defaultServerMaxHeaderListSize = uint32(16 << 20)
 )
 
-// MaxStreamIDForTesting is the upper bound for the stream ID before the current
-// transport gracefully closes and a new transport is created for subsequent RPCs.
+// MaxStreamID is the upper bound for the stream ID before the current
+// transport gracefully closes and a new transport is created for subsequent RPCs.
 // This is set to 75% of math.MaxUint32. It's exported so that tests can override it.
-var MaxStreamIDForTesting = uint32(float32(math.MaxUint32) * 0.75)
+var MaxStreamID = uint32(float32(math.MaxUint32) * 0.75)
diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go
index 901b8a9eba5b..c977b88c524d 100644
--- a/internal/transport/http2_client.go
+++ b/internal/transport/http2_client.go
@@ -768,6 +768,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 	}
 	firstTry := true
 	var ch chan struct{}
+	transportDrainRequired := false
 	checkForStreamQuota := func(it interface{}) bool {
 		if t.streamQuota <= 0 { // Can go negative if server decreases it.
 			if firstTry {
@@ -783,6 +784,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 		h := it.(*headerFrame)
 		h.streamID = t.nextID
 		t.nextID += 2
+
+		// Drain client transport if nextID > MaxStreamID, which is set to 75%
+		// of math.MaxUint32. This signals gRPC that the connection is closed and
+		// a new one must be created for subsequent RPCs.
+		transportDrainRequired = t.nextID > MaxStreamID
+
 		s.id = h.streamID
 		s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
 		t.mu.Lock()
@@ -815,7 +822,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 		}
 		return true
 	}
-	transportDrainRequired := false
 	for {
 		success, err := t.controlBuf.executeAndPut(func(it interface{}) bool {
 			return checkForHeaderListSize(it) && checkForStreamQuota(it)
@@ -825,12 +831,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 			return nil, &NewStreamError{Err: err, AllowTransparentRetry: true}
 		}
 		if success {
-			// drain client transport if nextID > MaxStreamID, which is set to
-			// 75% of math.MaxUint32, which then signals gRPC to restart transport
-			// for subsequent RPCs.
-			if t.nextID > MaxStreamIDForTesting {
-				transportDrainRequired = true
-			}
 			break
 		}
 		if hdrListSizeErr != nil {
@@ -871,7 +871,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
 	}
 	if transportDrainRequired {
 		if logger.V(logLevel) {
-			logger.Infof("t.nextID > MaxStreamID. transport: draining")
+			logger.Infof("transport: t.nextID > MaxStreamID. Draining")
 		}
 		t.GracefulClose()
 	}
@@ -1796,3 +1796,13 @@ func (t *http2Client) getOutFlowWindow() int64 {
 		return -2
 	}
 }
+
+func (t *http2Client) compareStateForTesting(s transportState) error {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.state != s {
+		return fmt.Errorf("clientTransport.state: %v, want: %v", t.state, s)
+	}
+
+	return nil
+}
diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go
index 8e656e91d2c2..ab20294ef744 100644
--- a/internal/transport/transport_test.go
+++ b/internal/transport/transport_test.go
@@ -536,9 +536,8 @@ func (s) TestInflightStreamClosing(t *testing.T) {
 	}
 }
 
-// TestClientTransportDrainsAfterStreamIdExhausted tests that when
-// streamID > MaxStreamId, the current client transport drains.
-func (s) TestClientTransportDrainsAfterStreamIdExhausted(t *testing.T) {
+// Tests that when streamID > MaxStreamID, the current client transport drains.
+func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
 	defer cancel()
 	defer server.stop()
@@ -546,11 +545,11 @@ func (s) TestClientTransportDrainsAfterStreamIdExhausted(t *testing.T) {
 		Host:   "localhost",
 		Method: "foo.Small",
 	}
-	// override MaxStreamIDForTesting.
-	originalMaxStreamID := MaxStreamIDForTesting
-	MaxStreamIDForTesting = 1
+	// override MaxStreamID.
+	originalMaxStreamID := MaxStreamID
+	MaxStreamID = 3
 	defer func() {
-		MaxStreamIDForTesting = originalMaxStreamID
+		MaxStreamID = originalMaxStreamID
 	}()
 
 	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
@@ -564,10 +563,23 @@ func (s) TestClientTransportDrainsAfterStreamIdExhausted(t *testing.T) {
 		t.Fatalf("stream id: %d, want: 1", s.id)
 	}
 
+	if err := ct.compareStateForTesting(reachable); err != nil {
+		t.Fatal(err)
+	}
+
+	s, err = ct.NewStream(ctx, callHdr)
+	if err != nil {
+		t.Fatalf("ct.NewStream() = %v", err)
+	}
+	if s.id != 3 {
+		t.Fatalf("stream id: %d, want: 3", s.id)
+	}
+
 	// verifying that ct.state is draining when next stream ID > MaxStreamId.
-	if ct.state != draining {
-		t.Fatalf("ct.state: %v, want: %v", ct.state, draining)
+	if err := ct.compareStateForTesting(draining); err != nil {
+		t.Fatal(err)
 	}
+
 }
 
 func (s) TestClientSendAndReceive(t *testing.T) {
diff --git a/test/end2end_test.go b/test/end2end_test.go
index ffa57cce84b4..68aa32c97e6d 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -1085,7 +1085,7 @@ func testFailFast(t *testing.T, e env) {
 	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
 		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
 	}
-	// Stop the server and tear down all the existing connection.
+	// Stop the server and tear down all the existing connections.
 	te.srv.Stop()
 	// Loop until the server teardown is propagated to the client.
 	for {
@@ -7032,79 +7032,163 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) {
 	}()
 }
 
-// TestClientTransportRestartsAfterStreamIdExhausted tests that the client transport
-// drains and restarts when next stream ID exceeds MaxStreamID. This test also
-// verifies that subsequent RPCs use a new client transport and the old transport
-// is closed.
-func (s) TestClientTransportRestartsAfterStreamIdExhausted(t *testing.T) {
-	// overriding MaxStreamIDForTesting.
-	originalMaxStreamID := transport.MaxStreamIDForTesting
-	transport.MaxStreamIDForTesting = 5
+type testAuthInfo struct {
+	credentials.CommonAuthInfo
+	connection net.Conn
+}
+
+func (ta *testAuthInfo) AuthType() string {
+	return ""
+}
+
+// ConnWrapper wraps a net.Conn and pushes on a channel when closed.
+type ConnWrapper struct {
+	net.Conn
+	CloseCh *testutils.Channel
+}
+
+// Close closes the connection and sends a value on the close channel.
+func (cw *ConnWrapper) Close() error {
+	err := cw.Conn.Close()
+	cw.CloseCh.Replace(nil)
+	return err
+}
+
+type transportRestartCheckCreds struct {
+	mu          sync.Mutex
+	connections []ConnWrapper
+
+	credentials.TransportCredentials
+}
+
+func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+	return rawConn, nil, nil
+}
+func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	closeCh := testutils.NewChannel()
+	conn := &ConnWrapper{Conn: rawConn, CloseCh: closeCh}
+	c.connections = append(c.connections, *conn)
+	authInfo := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}, connection: conn}
+	return conn, authInfo, nil
+}
+func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo {
+	return credentials.ProtocolInfo{}
+}
+func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials {
+	return &transportRestartCheckCreds{}
+}
+func (c *transportRestartCheckCreds) OverrideServerName(s string) error {
+	return nil
+}
+
+type streamWithPeer struct {
+	stream testpb.TestService_FullDuplexCallClient
+	peer   *peer.Peer
+	id     string
+	conn   net.Conn
+}
+
+func (s *streamWithPeer) sendAndReceive() error {
+	req := &testpb.StreamingOutputCallRequest{}
+	if err := s.stream.Send(req); err != nil {
+		return fmt.Errorf("sending on %s: %v", s.id, err)
+	}
+	if _, err := s.stream.Recv(); err != nil {
+		return fmt.Errorf("receiving on %s: %v", s.id, err)
+	}
+	return nil
+}
+
+// Tests that the client transport drains and restarts when the next stream ID
+// exceeds MaxStreamID. This test also verifies that subsequent RPCs use a new
+// client transport and that the old transport is closed.
+func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
+	// Set the transport's MaxStreamID to 4 to cause the connection to drain after 2 RPCs.
+	originalMaxStreamID := transport.MaxStreamID
+	transport.MaxStreamID = 4
 	defer func() {
-		transport.MaxStreamIDForTesting = originalMaxStreamID
+		transport.MaxStreamID = originalMaxStreamID
 	}()
-	// setting up StubServer.
-	s := grpc.NewServer()
 	ss := &stubserver.StubServer{
 		FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
-			for i := 0; i < 10; i++ {
-				m := &grpc.PreparedMsg{}
-				stream.SendMsg(m)
+			for {
+				if _, err := stream.Recv(); err != nil {
+					return err
+				}
+
+				res := &testpb.StreamingOutputCallResponse{}
+				if err := stream.Send(res); err != nil {
+					return err
+				}
 			}
-			return nil
 		},
 	}
-	testpb.RegisterTestServiceServer(s, ss)
-	// setting up gRPC server with ListenerWrapper.
-	lisWrap := testutils.NewListenerWrapper(t, nil)
-	go s.Serve(lisWrap)
-	defer s.Stop()
+	creds := &transportRestartCheckCreds{}
 
-	cc, err := grpc.Dial(lisWrap.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
-	if err != nil {
-		t.Fatalf("dial(%q): %v", lisWrap.Addr().String(), err)
+	if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil {
+		t.Fatalf("starting stubServer: %v", err)
 	}
-	defer cc.Close()
+	defer ss.Stop()
 
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
-	client := testpb.NewTestServiceClient(cc)
+	var streams []streamWithPeer
 
-	createStreamAndRecv := func() {
-		stream, err := client.FullDuplexCall(ctx)
+	// setting up 3 streams and calling Send() and Recv() once on each.
+	for i := 0; i <= 2; i++ {
+		var p peer.Peer
+		s, err := ss.Client.FullDuplexCall(ctx, grpc.Peer(&p))
 		if err != nil {
 			t.Fatalf("creating FullDuplex stream: %v", err)
 		}
-		if _, err = stream.Recv(); err != nil && err != io.EOF {
-			t.Fatalf("stream.Recv() = _, %v want: nil or EOF", err)
+		streamID := fmt.Sprintf("stream %v", i)
+		sp := streamWithPeer{stream: s, peer: &p, id: streamID}
+		streams = append(streams, sp)
+
+		if err := sp.sendAndReceive(); err != nil {
+			t.Fatal(err)
 		}
 	}
+	// verifying that only 2 connections have been created so far.
+	if len(creds.connections) != 2 {
+		t.Fatalf("number of connections created: %v, want: 2", len(creds.connections))
+	}
 
-	// creating FullDuplexCall stream #1.
-	createStreamAndRecv()
-	// creating FullDuplexCall stream #2.
-	createStreamAndRecv()
+	// verifying that streams on the first conn still work.
+	if err := streams[0].sendAndReceive(); err != nil {
+		t.Fatal(err)
+	}
+	if err := streams[1].sendAndReceive(); err != nil {
+		t.Fatal(err)
+	}
 
-	// verifying creation of new conn channel.
-	val, err := lisWrap.NewConnCh.Receive(ctx)
-	if err != nil {
-		t.Fatal("timeout expired when waiting to create new conn channel")
+	// closing all streams to get AuthInfo in Peer.
+	for i := 0; i < 3; i++ {
+		if err := streams[i].stream.CloseSend(); err != nil {
+			t.Fatalf("CloseSend() on %s: %v", streams[i].id, err)
+		}
+		if _, err := streams[i].stream.Recv(); !strings.Contains(err.Error(), "EOF") {
+			t.Fatalf("receiving on %s got: %v, want: EOF", streams[i].id, err)
+		} else {
+			streams[i].conn = streams[i].peer.AuthInfo.(*testAuthInfo).connection
+		}
 	}
-	conn1 := val.(*testutils.ConnWrapper)
-	// this stream should be created in a new conn channel.
-	createStreamAndRecv()
+	// verifying that the first two RPCs used the same connection, and the third a new one.
+	if streams[0].conn != streams[1].conn {
+		t.Fatal("got streams using different connections; want same.")
+	}
 
-	// verifying a new conn channel is created.
-	if _, err = lisWrap.NewConnCh.Receive(ctx); err != nil {
-		t.Fatal("timeout expired when waiting to create new conn channel")
+	if streams[2].conn == streams[0].conn {
+		t.Fatal("got streams using same connections; want different.")
 	}
-	// verifying the connection to the old one is drained and closed.
-	if _, err := conn1.CloseCh.Receive(ctx); err != nil {
+	// verifying that the first connection was closed.
+	if _, err := creds.connections[0].CloseCh.Receive(ctx); err != nil {
 		t.Fatal("timeout expired when waiting for first client transport to close")
 	}
@@ -7536,7 +7620,7 @@ func (s) TestAuthorityHeader(t *testing.T) {
 }
 
 // wrapCloseListener tracks Accepts/Closes and maintains a counter of the
-// number of open connection.
+// number of open connections.
 type wrapCloseListener struct {
 	net.Listener
 	connsOpen int32