diff --git a/pickfirst.go b/pickfirst.go index abe266b021d..6bf4701c534 100644 --- a/pickfirst.go +++ b/pickfirst.go @@ -153,6 +153,7 @@ func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state b } return } + b.state = state.ConnectivityState if state.ConnectivityState == connectivity.Shutdown { b.subConn = nil return @@ -165,21 +166,11 @@ func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state b Picker: &picker{result: balancer.PickResult{SubConn: subConn}}, }) case connectivity.Connecting: - if b.state == connectivity.TransientFailure { - // We stay in TransientFailure until we are Ready. See A62. - return - } b.cc.UpdateState(balancer.State{ ConnectivityState: state.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) case connectivity.Idle: - if b.state == connectivity.TransientFailure { - // We stay in TransientFailure until we are Ready. Also kick the - // subConn out of Idle into Connecting. See A62. - b.subConn.Connect() - return - } b.cc.UpdateState(balancer.State{ ConnectivityState: state.ConnectivityState, Picker: &idlePicker{subConn: subConn}, @@ -190,7 +181,6 @@ func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state b Picker: &picker{err: state.ConnectionError}, }) } - b.state = state.ConnectivityState } func (b *pickfirstBalancer) Close() { diff --git a/test/pickfirst_test.go b/test/pickfirst_test.go index 55659b928a5..800d2f4178c 100644 --- a/test/pickfirst_test.go +++ b/test/pickfirst_test.go @@ -20,20 +20,15 @@ package test import ( "context" - "sync" "testing" "time" "google.golang.org/grpc" - "google.golang.org/grpc/backoff" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/channelz" - "google.golang.org/grpc/internal/envconfig" - "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/stubserver" - "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/testutils/pickfirst" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" @@ -298,180 +293,3 @@ func (s) TestPickFirst_NewAddressWhileBlocking(t *testing.T) { case <-doneCh: } } - -func (s) TestPickFirst_StickyTransientFailure(t *testing.T) { - // Spin up a local server which closes the connection as soon as it receives - // one. It also sends a signal on a channel whenver it received a connection. - lis, err := testutils.LocalTCPListener() - if err != nil { - t.Fatalf("Failed to create listener: %v", err) - } - t.Cleanup(func() { lis.Close() }) - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - connCh := make(chan struct{}, 1) - go func() { - for { - conn, err := lis.Accept() - if err != nil { - return - } - select { - case connCh <- struct{}{}: - conn.Close() - case <-ctx.Done(): - return - } - } - }() - - // Dial the above server with a ConnectParams that does a constant backoff - // of defaultTestShortTimeout duration. - dopts := []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithDefaultServiceConfig(pickFirstServiceConfig), - grpc.WithConnectParams(grpc.ConnectParams{ - Backoff: backoff.Config{ - BaseDelay: defaultTestShortTimeout, - Multiplier: float64(0), - Jitter: float64(0), - MaxDelay: defaultTestShortTimeout, - }, - }), - } - cc, err := grpc.Dial(lis.Addr().String(), dopts...) - if err != nil { - t.Fatalf("Failed to dial server at %q: %v", lis.Addr(), err) - } - t.Cleanup(func() { cc.Close() }) - - var wg sync.WaitGroup - wg.Add(2) - // Spin up a goroutine that waits for the channel to move to - // TransientFailure. After that it checks that the channel stays in - // TransientFailure, until Shutdown. - go func() { - defer wg.Done() - for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Errorf("Timeout when waiting for state to change to TransientFailure. Current state is %s", state) - return - } - } - - // TODO(easwars): this waits for 10s. Need shorter deadline here. Basically once the second goroutine exits, we should exit from here too. - if cc.WaitForStateChange(ctx, connectivity.TransientFailure) { - if state := cc.GetState(); state != connectivity.Shutdown { - t.Errorf("Unexpected state change from TransientFailure to %s", cc.GetState()) - } - } - }() - // Spin up a goroutine which ensures that the pick_first LB policy is - // constantly trying to reconnect. - go func() { - defer wg.Done() - for i := 0; i < 10; i++ { - select { - case <-connCh: - case <-time.After(2 * defaultTestShortTimeout): - t.Error("Timeout when waiting for pick_first to reconnect") - } - } - }() - wg.Wait() -} - -// Tests the PF LB policy with shuffling enabled. -func (s) TestPickFirst_ShuffleAddressList(t *testing.T) { - defer func(old bool) { envconfig.PickFirstLBConfig = old }(envconfig.PickFirstLBConfig) - envconfig.PickFirstLBConfig = true - const serviceConfig = `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": true }}]}` - - // Install a shuffler that always reverses two entries. - origShuf := grpcrand.Shuffle - defer func() { grpcrand.Shuffle = origShuf }() - grpcrand.Shuffle = func(n int, f func(int, int)) { - if n != 2 { - t.Errorf("Shuffle called with n=%v; want 2", n) - return - } - f(0, 1) // reverse the two addresses - } - - // Set up our backends. - cc, r, backends := setupPickFirst(t, 2) - addrs := stubBackendsToResolverAddrs(backends) - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - - // Push an update with both addresses and shuffling disabled. We should - // connect to backend 0. - r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}}) - if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { - t.Fatal(err) - } - - // Send a config with shuffling enabled. This will reverse the addresses, - // but the channel should still be connected to backend 0. - shufState := resolver.State{ - ServiceConfig: parseServiceConfig(t, r, serviceConfig), - Addresses: []resolver.Address{addrs[0], addrs[1]}, - } - r.UpdateState(shufState) - if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { - t.Fatal(err) - } - - // Send a resolver update with no addresses. This should push the channel - // into TransientFailure. - r.UpdateState(resolver.State{}) - awaitState(ctx, t, cc, connectivity.TransientFailure) - - // Send the same config as last time with shuffling enabled. Since we are - // not connected to backend 0, we should connect to backend 1. - r.UpdateState(shufState) - if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { - t.Fatal(err) - } -} - -// Tests the PF LB policy with the environment variable support of address list -// shuffling disabled. -func (s) TestPickFirst_ShuffleAddressListDisabled(t *testing.T) { - defer func(old bool) { envconfig.PickFirstLBConfig = old }(envconfig.PickFirstLBConfig) - envconfig.PickFirstLBConfig = false - const serviceConfig = `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": true }}]}` - - // Install a shuffler that always reverses two entries. - origShuf := grpcrand.Shuffle - defer func() { grpcrand.Shuffle = origShuf }() - grpcrand.Shuffle = func(n int, f func(int, int)) { - if n != 2 { - t.Errorf("Shuffle called with n=%v; want 2", n) - return - } - f(0, 1) // reverse the two addresses - } - - // Set up our backends. - cc, r, backends := setupPickFirst(t, 2) - addrs := stubBackendsToResolverAddrs(backends) - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - - // Send a config with shuffling enabled. This will reverse the addresses, - // so we should connect to backend 1 if shuffling is supported. However - // with it disabled at the start of the test, we will connect to backend 0 - // instead. - shufState := resolver.State{ - ServiceConfig: parseServiceConfig(t, r, serviceConfig), - Addresses: []resolver.Address{addrs[0], addrs[1]}, - } - r.UpdateState(shufState) - if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { - t.Fatal(err) - } -}