diff --git a/pickfirst.go b/pickfirst.go index fc91b4d266d..89e54196e1e 100644 --- a/pickfirst.go +++ b/pickfirst.go @@ -119,7 +119,6 @@ func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state b } return } - b.state = state.ConnectivityState if state.ConnectivityState == connectivity.Shutdown { b.subConn = nil return @@ -132,11 +131,21 @@ func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state b Picker: &picker{result: balancer.PickResult{SubConn: subConn}}, }) case connectivity.Connecting: + if b.state == connectivity.TransientFailure { + // We stay in TransientFailure until we are Ready. See A62. + return + } b.cc.UpdateState(balancer.State{ ConnectivityState: state.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) case connectivity.Idle: + if b.state == connectivity.TransientFailure { + // We stay in TransientFailure until we are Ready. Also kick the + // subConn out of Idle into Connecting. See A62. + b.subConn.Connect() + return + } b.cc.UpdateState(balancer.State{ ConnectivityState: state.ConnectivityState, Picker: &idlePicker{subConn: subConn}, @@ -147,6 +156,7 @@ func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state b Picker: &picker{err: state.ConnectionError}, }) } + b.state = state.ConnectivityState } func (b *pickfirstBalancer) Close() { diff --git a/test/pickfirst_test.go b/test/pickfirst_test.go index 800d2f4178c..75cb2a659ed 100644 --- a/test/pickfirst_test.go +++ b/test/pickfirst_test.go @@ -20,15 +20,18 @@ package test import ( "context" + "sync" "testing" "time" "google.golang.org/grpc" + "google.golang.org/grpc/backoff" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/testutils/pickfirst" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" @@ -293,3 +296,86 @@ func (s) TestPickFirst_NewAddressWhileBlocking(t *testing.T) { case <-doneCh: } } + +func (s) TestPickFirst_StickyTransientFailure(t *testing.T) { + // Spin up a local server which closes the connection as soon as it receives + // one. It also sends a signal on a channel whenver it received a connection. + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + t.Cleanup(func() { lis.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + connCh := make(chan struct{}, 1) + go func() { + for { + conn, err := lis.Accept() + if err != nil { + return + } + select { + case connCh <- struct{}{}: + conn.Close() + case <-ctx.Done(): + return + } + } + }() + + // Dial the above server with a ConnectParams that does a constant backoff + // of defaultTestShortTimeout duration. + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(pickFirstServiceConfig), + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoff.Config{ + BaseDelay: defaultTestShortTimeout, + Multiplier: float64(0), + Jitter: float64(0), + MaxDelay: defaultTestShortTimeout, + }, + }), + } + cc, err := grpc.Dial(lis.Addr().String(), dopts...) + if err != nil { + t.Fatalf("Failed to dial server at %q: %v", lis.Addr(), err) + } + t.Cleanup(func() { cc.Close() }) + + var wg sync.WaitGroup + wg.Add(2) + // Spin up a goroutine that waits for the channel to move to + // TransientFailure. After that it checks that the channel stays in + // TransientFailure, until Shutdown. + go func() { + defer wg.Done() + for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { + if !cc.WaitForStateChange(ctx, state) { + t.Errorf("Timeout when waiting for state to change to TransientFailure. Current state is %s", state) + return + } + } + + // TODO(easwars): this waits for 10s. Need shorter deadline here. Basically once the second goroutine exits, we should exit from here too. + if cc.WaitForStateChange(ctx, connectivity.TransientFailure) { + if state := cc.GetState(); state != connectivity.Shutdown { + t.Errorf("Unexpected state change from TransientFailure to %s", cc.GetState()) + } + } + }() + // Spin up a goroutine which ensures that the pick_first LB policy is + // constantly trying to reconnect. + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + select { + case <-connCh: + case <-time.After(2 * defaultTestShortTimeout): + t.Error("Timeout when waiting for pick_first to reconnect") + } + } + }() + wg.Wait() +}