use Peer.AuthInfo to verify connections
arvindbr8 committed Dec 28, 2022
1 parent 00f1002 commit 4b1c7cd
Showing 4 changed files with 174 additions and 68 deletions.
6 changes: 3 additions & 3 deletions internal/transport/defaults.go
@@ -48,7 +48,7 @@ const (
defaultServerMaxHeaderListSize = uint32(16 << 20)
)

-// MaxStreamIDForTesting is the upper bound for the stream ID before the current
-// transport gracefully closes and a new transport is created for subsequent RPCs.
+// MaxStreamID is the upper bound for the stream ID before the current
+// transport gracefully closes and new transport is created for subsequent RPCs.
// This is set to 75% of math.MaxUint32. It's exported so that tests can override it.
-var MaxStreamIDForTesting = uint32(float32(math.MaxUint32) * 0.75)
+var MaxStreamID = uint32(float32(math.MaxUint32) * 0.75)
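As a standalone sanity check (illustrative only, not part of the commit), the new limit works out as below. For reference, HTTP/2 stream identifiers are 31-bit values, so math.MaxInt32 (2147483647) is the on-the-wire ceiling.

    package main

    import (
    	"fmt"
    	"math"
    )

    func main() {
    	// Same expression as MaxStreamID in defaults.go above.
    	limit := uint32(float32(math.MaxUint32) * 0.75)
    	fmt.Println(limit) // 3221225472

    	// Client-initiated HTTP/2 stream IDs are odd (1, 3, 5, ...) and advance
    	// by 2 per stream, so one transport admits about limit/2 streams,
    	// roughly 1.6 billion RPCs, before it drains.
    	fmt.Println(limit / 2) // 1610612736
    }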
26 changes: 18 additions & 8 deletions internal/transport/http2_client.go
@@ -768,6 +768,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
}
firstTry := true
var ch chan struct{}
+transportDrainRequired := false
checkForStreamQuota := func(it interface{}) bool {
if t.streamQuota <= 0 { // Can go negative if server decreases it.
if firstTry {
@@ -783,6 +784,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
h := it.(*headerFrame)
h.streamID = t.nextID
t.nextID += 2
+
+// drain client transport if nextID > MaxStreamID, which is set to
+// 75% of math.MaxUint32, which signals gRPC that the connection is closed
+// and a new one must be created for subsequent RPCs.
+transportDrainRequired = t.nextID > MaxStreamID
+
s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.mu.Lock()
@@ -815,7 +822,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
}
return true
}
-transportDrainRequired := false
for {
success, err := t.controlBuf.executeAndPut(func(it interface{}) bool {
return checkForHeaderListSize(it) && checkForStreamQuota(it)
@@ -825,12 +831,6 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return nil, &NewStreamError{Err: err, AllowTransparentRetry: true}
}
if success {
-// drain client transport if nextID > MaxStreamID, which is set to
-// 75% of math.MaxUint32, which then signals gRPC to restart transport
-// for subsequent RPCs.
-if t.nextID > MaxStreamIDForTesting {
-transportDrainRequired = true
-}
break
}
if hdrListSizeErr != nil {
@@ -871,7 +871,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
}
if transportDrainRequired {
if logger.V(logLevel) {
logger.Infof("t.nextID > MaxStreamID. transport: draining")
logger.Infof("transport: t.nextID > MaxStreamID. Draining")
}
t.GracefulClose()
}
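For readers who don't want to trace the controlbuf machinery, here is a minimal, self-contained sketch of the pattern the hunks above implement: decide to drain at the moment a stream ID is allocated, then act on the decision only after the stream has been admitted. The names fakeTransport and allocStreamID are hypothetical; in the real code the check runs inside the checkForStreamQuota callback and the action is t.GracefulClose().

    package main

    import (
    	"fmt"
    	"math"
    	"sync"
    )

    // maxStreamID mirrors the expression in defaults.go.
    const maxStreamID = uint32(float32(math.MaxUint32) * 0.75)

    // fakeTransport is a simplified stand-in for http2Client.
    type fakeTransport struct {
    	mu     sync.Mutex
    	nextID uint32
    }

    // allocStreamID hands out the next client stream ID and reports whether
    // the transport should drain once this stream is admitted.
    func (t *fakeTransport) allocStreamID() (id uint32, drainAfter bool) {
    	t.mu.Lock()
    	defer t.mu.Unlock()
    	id = t.nextID
    	t.nextID += 2 // client-initiated stream IDs are odd: 1, 3, 5, ...
    	return id, t.nextID > maxStreamID
    }

    func main() {
    	t := &fakeTransport{nextID: 1}
    	id, drain := t.allocStreamID()
    	fmt.Println(id, drain) // 1 false
    	// In http2Client.NewStream, drain == true triggers t.GracefulClose(),
    	// so in-flight streams finish while new RPCs get a fresh transport.
    }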
@@ -1796,3 +1796,13 @@ func (t *http2Client) getOutFlowWindow() int64 {
return -2
}
}
+
+func (t *http2Client) compareStateForTesting(s transportState) error {
+t.mu.Lock()
+defer t.mu.Unlock()
+if t.state != s {
+return fmt.Errorf("clientTransport.state: %v, want: %v", t.state, s)
+}
+
+return nil
+}
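Unlike the direct ct.state != draining read it replaces in transport_test.go below, this helper compares t.state while holding t.mu, so the test's assertion cannot race with state transitions made by the transport's own goroutines.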
30 changes: 21 additions & 9 deletions internal/transport/transport_test.go
@@ -536,21 +536,20 @@ func (s) TestInflightStreamClosing(t *testing.T) {
}
}

-// TestClientTransportDrainsAfterStreamIdExhausted tests that when
-// streamID > MaxStreamId, the current client transport drains.
-func (s) TestClientTransportDrainsAfterStreamIdExhausted(t *testing.T) {
+// tests that when streamID > MaxStreamId, the current client transport drains.
+func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
defer cancel()
defer server.stop()
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Small",
}
-// override MaxStreamIDForTesting.
-originalMaxStreamID := MaxStreamIDForTesting
-MaxStreamIDForTesting = 1
+// override MaxStreamID.
+originalMaxStreamID := MaxStreamID
+MaxStreamID = 3
defer func() {
-MaxStreamIDForTesting = originalMaxStreamID
+MaxStreamID = originalMaxStreamID
}()

ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
@@ -564,10 +563,23 @@ func (s) TestClientTransportDrainsAfterStreamIdExhausted(t *testing.T) {
t.Fatalf("stream id: %d, want: 1", s.id)
}

+if err := ct.compareStateForTesting(reachable); err != nil {
+t.Fatal(err)
+}
+
+s, err = ct.NewStream(ctx, callHdr)
+if err != nil {
+t.Fatalf("ct.NewStream() = %v", err)
+}
+if s.id != 3 {
+t.Fatalf("stream id: %d, want: 3", s.id)
+}
+
// verifying that ct.state is draining when next stream ID > MaxStreamId.
-if ct.state != draining {
-t.Fatalf("ct.state: %v, want: %v", ct.state, draining)
+if err := ct.compareStateForTesting(draining); err != nil {
+t.Fatal(err)
}
+
}
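The arithmetic behind these assertions, for reference: with MaxStreamID overridden to 3, the first NewStream allocates stream ID 1 and advances nextID to 3 (3 > 3 is false, so the transport stays reachable); the second allocates stream ID 3 and advances nextID to 5 (5 > 3, so the transport moves to draining).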

func (s) TestClientSendAndReceive(t *testing.T) {
180 changes: 132 additions & 48 deletions test/end2end_test.go
@@ -1085,7 +1085,7 @@ func testFailFast(t *testing.T, e env) {
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
}
-// Stop the server and tear down all the existing connections.
+// Stop the server and tear down all the existing connection.
te.srv.Stop()
// Loop until the server teardown is propagated to the client.
for {
@@ -7032,79 +7032,163 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) {
}()
}

-// TestClientTransportRestartsAfterStreamIdExhausted tests that the client transport
-// drains and restarts when next stream ID exceeds MaxStreamID. This test also
-// verifies that subsequent RPCs use a new client transport and the old transport
-// is closed.
-func (s) TestClientTransportRestartsAfterStreamIdExhausted(t *testing.T) {
-// overriding MaxStreamIDForTesting.
-originalMaxStreamID := transport.MaxStreamIDForTesting
-transport.MaxStreamIDForTesting = 5
+type testAuthInfo struct {
+credentials.CommonAuthInfo
+connection net.Conn
+}
+
+func (ta *testAuthInfo) AuthType() string {
+return ""
+}
+
+// ConnWrapper wraps a net.Conn and pushes on a channel when closed.
+type ConnWrapper struct {
+net.Conn
+CloseCh *testutils.Channel
+}
+
+// Close closes the connection and sends a value on the close channel.
+func (cw *ConnWrapper) Close() error {
+err := cw.Conn.Close()
+cw.CloseCh.Replace(nil)
+return err
+}
+
+type transportRestartCheckCreds struct {
+mu sync.Mutex
+connections []ConnWrapper
+credentials.TransportCredentials
+}
+
+func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+return rawConn, nil, nil
+}
+func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+c.mu.Lock()
+defer c.mu.Unlock()
+closeCh := testutils.NewChannel()
+conn := &ConnWrapper{Conn: rawConn, CloseCh: closeCh}
+c.connections = append(c.connections, *conn)
+authInfo := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}, connection: conn}
+return conn, authInfo, nil
+}
+func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo {
+return credentials.ProtocolInfo{}
+}
+func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials {
+return &transportRestartCheckCreds{}
+}
+func (c *transportRestartCheckCreds) OverrideServerName(s string) error {
+return nil
+}
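The idea behind these credentials: ClientHandshake wraps every raw connection in a ConnWrapper, records it in c.connections, and returns a *testAuthInfo that carries the connection itself. gRPC attaches that AuthInfo to the transport and later exposes it through the peer.Peer filled in by the grpc.Peer call option, which is how the test below maps each stream back to the exact net.Conn it ran on, replacing the ListenerWrapper bookkeeping this commit removes.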
+
+type streamWithPeer struct {
+stream testpb.TestService_FullDuplexCallClient
+peer *peer.Peer
+id string
+conn net.Conn
+}
+
+func (s *streamWithPeer) sendAndReceive() error {
+req := &testpb.StreamingOutputCallRequest{}
+if err := s.stream.Send(req); err != nil {
+return fmt.Errorf("sending on stream1: %v", err)
+}
+if _, err := s.stream.Recv(); err != nil {
+return fmt.Errorf("receiving on stream1: %v", err)
+}
+return nil
+}
+
+// tests that the client transport drains and restarts when next stream ID exceeds
+// MaxStreamID. This test also verifies that subsequent RPCs use a new client
+// transport and the old transport is closed.
+func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
+// Set the transport's MaxStreamID to 5 to cause connection to drain after 2 RPCs.
+originalMaxStreamID := transport.MaxStreamID
+transport.MaxStreamID = 5
defer func() {
-transport.MaxStreamIDForTesting = originalMaxStreamID
+transport.MaxStreamID = originalMaxStreamID
}()

-// setting up StubServer.
-s := grpc.NewServer()
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
-for i := 0; i < 10; i++ {
-m := &grpc.PreparedMsg{}
-stream.SendMsg(m)
+for {
+if _, err := stream.Recv(); err != nil {
+return err
+}
+
+res := &testpb.StreamingOutputCallResponse{}
+if err := stream.Send(res); err != nil {
+return err
+}
}
-return nil
},
}
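Note the reworked handler: instead of sending ten prepared messages, the server now echoes one response per request until Recv fails, so each stream stays open until the client ends it. The test below relies on this, since a gracefully closing (draining) transport still services its established streams, and the test keeps exercising streams[0] and streams[1] after the drain before closing them.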
-testpb.RegisterTestServiceServer(s, ss)

-// setting up gRPC server with ListenerWrapper.
-lisWrap := testutils.NewListenerWrapper(t, nil)
-go s.Serve(lisWrap)
-defer s.Stop()
+creds := &transportRestartCheckCreds{}

-cc, err := grpc.Dial(lisWrap.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
-if err != nil {
-t.Fatalf("dial(%q): %v", lisWrap.Addr().String(), err)
+if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil {
+t.Fatalf("starting stubServer: %v", err)
}
-defer cc.Close()
+defer ss.Stop()

-ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

-client := testpb.NewTestServiceClient(cc)
+var streams []streamWithPeer

-createStreamAndRecv := func() {
-stream, err := client.FullDuplexCall(ctx)
+// setting up 3 streams and calling Send() and Recv() once on each.
+for i := 0; i <= 2; i++ {
+var p peer.Peer
+s, err := ss.Client.FullDuplexCall(ctx, grpc.Peer(&p))
if err != nil {
t.Fatalf("creating FullDuplex stream: %v", err)
}
-if _, err = stream.Recv(); err != nil && err != io.EOF {
-t.Fatalf("stream.Recv() = _, %v want: nil or EOF", err)
+streamID := fmt.Sprintf("stream %v", i)
+sp := streamWithPeer{stream: s, peer: &p, id: streamID}
+streams = append(streams, sp)
+
+if err := sp.sendAndReceive(); err != nil {
+t.Fatal(err)
+}
}
+// verify only 2 connections are created so far
+if len(creds.connections) != 2 {
+t.Fatalf("number of connections created: %v, want: 2", len(creds.connections))
+}

-// creating FullDuplexCall stream #1.
-createStreamAndRecv()
-// creating FullDuplexCall stream #2.
-createStreamAndRecv()
+// verifying that streams on the first conn still works.
+if err := streams[0].sendAndReceive(); err != nil {
+t.Fatal(err)
+}
+if err := streams[1].sendAndReceive(); err != nil {
+t.Fatal(err)
+}

-// verifying creation of new conn channel.
-val, err := lisWrap.NewConnCh.Receive(ctx)
-if err != nil {
-t.Fatal("timeout expired when waiting to create new conn channel")
+// closing all streams to get AuthInfo in Peer
+for i := 0; i < 3; i++ {
+if err := streams[i].stream.CloseSend(); err != nil {
+t.Fatalf("CloseSend() on %s: %v", streams[i].id, err)
}
+if _, err := streams[i].stream.Recv(); !strings.Contains(err.Error(), "EOF") {
+t.Fatalf("receiving on %s got: %v, want: EOF", streams[i].id, err)
+} else {
+streams[i].conn = streams[i].peer.AuthInfo.(*testAuthInfo).connection
+}
+}
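Closing each stream and draining it to EOF before reading the peer matters here: for streaming RPCs, the grpc.Peer call option populates its peer.Peer once the stream has completed, so peer.AuthInfo reliably holds the *testAuthInfo returned by ClientHandshake only after Recv has returned the final status.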
-conn1 := val.(*testutils.ConnWrapper)

-// this stream should be created in a new conn channel.
-createStreamAndRecv()
+// verifying that the 3rd RPC was made on a different connection.
+if streams[0].conn != streams[1].conn {
+t.Fatal("got streams using different connections; want same.")
}

-// verifying a new conn channel is created.
-if _, err = lisWrap.NewConnCh.Receive(ctx); err != nil {
-t.Fatal("timeout expired when waiting to create new conn channel")
+if streams[2].conn == streams[0].conn {
+t.Fatal("got streams using same connections; want different.")
}

-// verifying the connection to the old one is drained and closed.
-if _, err := conn1.CloseCh.Receive(ctx); err != nil {
+// verifying if first connection was closed.
+if _, err := creds.connections[0].CloseCh.Receive(ctx); err != nil {
t.Fatal("timeout expired when waiting for first client transport to close")
}

@@ -7536,7 +7620,7 @@ func (s) TestAuthorityHeader(t *testing.T) {
}

// wrapCloseListener tracks Accepts/Closes and maintains a counter of the
-// number of open connections.
+// number of open connection.
type wrapCloseListener struct {
net.Listener
connsOpen int32
