diff --git a/go.mod b/go.mod index 4fafb24d2be..595a9d3c653 100644 --- a/go.mod +++ b/go.mod @@ -119,6 +119,7 @@ require ( github.com/bndr/gotabulate v1.1.2 github.com/hashicorp/go-version v1.6.0 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 + golang.org/x/sync v0.3.0 golang.org/x/tools/cmd/cover v0.1.0-deprecated modernc.org/sqlite v1.20.3 ) diff --git a/go.sum b/go.sum index bf252b10cb2..b9a439e07a5 100644 --- a/go.sum +++ b/go.sum @@ -917,6 +917,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/go/flags/endtoend/vtctld.txt b/go/flags/endtoend/vtctld.txt index 83eac50a108..906a9160b6f 100644 --- a/go/flags/endtoend/vtctld.txt +++ b/go/flags/endtoend/vtctld.txt @@ -55,6 +55,7 @@ Usage of vtctld: --grpc_server_initial_window_size int gRPC server initial window size --grpc_server_keepalive_enforcement_policy_min_time duration gRPC server minimum keepalive time (default 10s) --grpc_server_keepalive_enforcement_policy_permit_without_stream gRPC server permit client keepalive pings even when there are no active streams (RPCs) + --healthcheck-dial-concurrency int Maximum concurrency of new healthcheck connections. This should be less than the golang max thread limit of 10000. (default 1024) -h, --help display usage and exit --jaeger-agent-host string host and port to send spans to. if empty, no tracing will be done --keep_logs duration keep logs for this long (using ctime) (zero to keep forever) diff --git a/go/flags/endtoend/vtgate.txt b/go/flags/endtoend/vtgate.txt index f6b6ef51945..cbf04bf888c 100644 --- a/go/flags/endtoend/vtgate.txt +++ b/go/flags/endtoend/vtgate.txt @@ -59,6 +59,7 @@ Usage of vtgate: --grpc_server_keepalive_enforcement_policy_min_time duration gRPC server minimum keepalive time (default 10s) --grpc_server_keepalive_enforcement_policy_permit_without_stream gRPC server permit client keepalive pings even when there are no active streams (RPCs) --grpc_use_effective_callerid If set, and SSL is not used, will set the immediate caller id from the effective caller id's principal. + --healthcheck-dial-concurrency int Maximum concurrency of new healthcheck connections. This should be less than the golang max thread limit of 10000. (default 1024) --healthcheck_retry_delay duration health check retry delay (default 2ms) --healthcheck_timeout duration the health check timeout period (default 1m0s) -h, --help display usage and exit diff --git a/go/vt/discovery/healthcheck.go b/go/vt/discovery/healthcheck.go index be0d022ff98..9ffd90ad649 100644 --- a/go/vt/discovery/healthcheck.go +++ b/go/vt/discovery/healthcheck.go @@ -45,6 +45,7 @@ import ( "time" "github.com/spf13/pflag" + "golang.org/x/sync/semaphore" "vitess.io/vitess/go/netutil" "vitess.io/vitess/go/stats" @@ -88,6 +89,9 @@ var ( // topoReadConcurrency tells us how many topo reads are allowed in parallel. topoReadConcurrency = 32 + // healthCheckDialConcurrency tells us how many healthcheck connections can be opened to tablets at once. This should be less than the golang max thread limit of 10000. + healthCheckDialConcurrency int64 = 1024 + // How much to sleep between each check. waitAvailableTabletInterval = 100 * time.Millisecond ) @@ -166,6 +170,7 @@ func registerWebUIFlags(fs *pflag.FlagSet) { fs.DurationVar(&refreshInterval, "tablet_refresh_interval", 1*time.Minute, "Tablet refresh interval.") fs.BoolVar(&refreshKnownTablets, "tablet_refresh_known_tablets", true, "Whether to reload the tablet's address/port map from topo in case they change.") fs.IntVar(&topoReadConcurrency, "topo_read_concurrency", 32, "Concurrency of topo reads.") + fs.Int64Var(&healthCheckDialConcurrency, "healthcheck-dial-concurrency", 1024, "Maximum concurrency of new healthcheck connections. This should be less than the golang max thread limit of 10000.") ParseTabletURLTemplateFromFlag() } @@ -282,6 +287,8 @@ type HealthCheckImpl struct { subMu sync.Mutex // subscribers subscribers map[chan *TabletHealth]struct{} + // healthCheckDialSem is used to limit how many healthcheck connections can be opened to tablets at once. + healthCheckDialSem *semaphore.Weighted } // NewHealthCheck creates a new HealthCheck object. @@ -316,6 +323,7 @@ func NewHealthCheck(ctx context.Context, retryDelay, healthCheckTimeout time.Dur cell: localCell, retryDelay: retryDelay, healthCheckTimeout: healthCheckTimeout, + healthCheckDialSem: semaphore.NewWeighted(healthCheckDialConcurrency), healthByAlias: make(map[tabletAliasString]*tabletHealthCheck), healthData: make(map[KeyspaceShardTabletType]map[tabletAliasString]*TabletHealth), healthy: make(map[KeyspaceShardTabletType][]*TabletHealth), @@ -700,30 +708,8 @@ func (hc *HealthCheckImpl) WaitForAllServingTablets(ctx context.Context, targets return hc.waitForTablets(ctx, targets, true) } -// FilterTargetsByKeyspaces only returns the targets that are part of the provided keyspaces -func FilterTargetsByKeyspaces(keyspaces []string, targets []*query.Target) []*query.Target { - filteredTargets := make([]*query.Target, 0) - - // Keep them all if there are no keyspaces to watch - if len(KeyspacesToWatch) == 0 { - return append(filteredTargets, targets...) - } - - // Let's remove from the target shards that are not in the keyspaceToWatch list. - for _, target := range targets { - for _, keyspaceToWatch := range keyspaces { - if target.Keyspace == keyspaceToWatch { - filteredTargets = append(filteredTargets, target) - } - } - } - return filteredTargets -} - // waitForTablets is the internal method that polls for tablets. func (hc *HealthCheckImpl) waitForTablets(ctx context.Context, targets []*query.Target, requireServing bool) error { - targets = FilterTargetsByKeyspaces(KeyspacesToWatch, targets) - for { // We nil targets as we find them. allPresent := true @@ -800,7 +786,7 @@ func (hc *HealthCheckImpl) TabletConnection(alias *topodata.TabletAlias, target // TODO: test that throws this error return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "tablet: %v is either down or nonexistent", alias) } - return thc.Connection(), nil + return thc.Connection(hc), nil } // getAliasByCell should only be called while holding hc.mu diff --git a/go/vt/discovery/healthcheck_test.go b/go/vt/discovery/healthcheck_test.go index e61915a9b85..35cd1f17d05 100644 --- a/go/vt/discovery/healthcheck_test.go +++ b/go/vt/discovery/healthcheck_test.go @@ -645,27 +645,6 @@ func TestWaitForAllServingTablets(t *testing.T) { err = hc.WaitForAllServingTablets(ctx, targets) assert.NotNil(t, err, "error should not be nil (there are no tablets on this keyspace") - - targets = []*querypb.Target{ - - { - Keyspace: tablet.Keyspace, - Shard: tablet.Shard, - TabletType: tablet.Type, - }, - { - Keyspace: "newkeyspace", - Shard: tablet.Shard, - TabletType: tablet.Type, - }, - } - - KeyspacesToWatch = []string{tablet.Keyspace} - - err = hc.WaitForAllServingTablets(ctx, targets) - assert.Nil(t, err, "error should be nil. Keyspace with no tablets is filtered") - - KeyspacesToWatch = []string{} } // TestRemoveTablet tests the behavior when a tablet goes away. diff --git a/go/vt/discovery/tablet_health_check.go b/go/vt/discovery/tablet_health_check.go index f0ad9b0a2ac..05ab47dee05 100644 --- a/go/vt/discovery/tablet_health_check.go +++ b/go/vt/discovery/tablet_health_check.go @@ -19,6 +19,7 @@ package discovery import ( "context" "fmt" + "net" "strings" "sync" "time" @@ -34,12 +35,16 @@ import ( "vitess.io/vitess/go/vt/vttablet/queryservice" "vitess.io/vitess/go/vt/vttablet/tabletconn" + "google.golang.org/grpc" "google.golang.org/protobuf/proto" "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/topodata" ) +// withDialerContextOnce ensures grpc.WithDialContext() is added once to the options. +var withDialerContextOnce sync.Once + // tabletHealthCheck maintains the health status of a tablet. A map of this // structure is maintained in HealthCheck. type tabletHealthCheck struct { @@ -123,8 +128,8 @@ func (thc *tabletHealthCheck) setServingState(serving bool, reason string) { } // stream streams healthcheck responses to callback. -func (thc *tabletHealthCheck) stream(ctx context.Context, callback func(*query.StreamHealthResponse) error) error { - conn := thc.Connection() +func (thc *tabletHealthCheck) stream(ctx context.Context, hc *HealthCheckImpl, callback func(*query.StreamHealthResponse) error) error { + conn := thc.Connection(hc) if conn == nil { // This signals the caller to retry return nil @@ -137,14 +142,34 @@ func (thc *tabletHealthCheck) stream(ctx context.Context, callback func(*query.S return err } -func (thc *tabletHealthCheck) Connection() queryservice.QueryService { +func (thc *tabletHealthCheck) Connection(hc *HealthCheckImpl) queryservice.QueryService { thc.connMu.Lock() defer thc.connMu.Unlock() - return thc.connectionLocked() + return thc.connectionLocked(hc) +} + +func healthCheckDialerFactory(hc *HealthCheckImpl) func(ctx context.Context, addr string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + // Limit the number of healthcheck connections opened in parallel to avoid high OS-thread + // usage due to blocking networking syscalls (eg: DNS lookups, TCP connection opens, + // etc). Without this limit it is possible for vtgates watching >10k tablets to hit + // the panic: 'runtime: program exceeds 10000-thread limit'. + if err := hc.healthCheckDialSem.Acquire(ctx, 1); err != nil { + return nil, err + } + defer hc.healthCheckDialSem.Release(1) + var dialer net.Dialer + return dialer.DialContext(ctx, "tcp", addr) + } } -func (thc *tabletHealthCheck) connectionLocked() queryservice.QueryService { +func (thc *tabletHealthCheck) connectionLocked(hc *HealthCheckImpl) queryservice.QueryService { if thc.Conn == nil { + withDialerContextOnce.Do(func() { + grpcclient.RegisterGRPCDialOptions(func(opts []grpc.DialOption) ([]grpc.DialOption, error) { + return append(opts, grpc.WithContextDialer(healthCheckDialerFactory(hc))), nil + }) + }) conn, err := tabletconn.GetDialer()(thc.Tablet, grpcclient.FailFast(true)) if err != nil { thc.LastError = err @@ -273,7 +298,7 @@ func (thc *tabletHealthCheck) checkConn(hc *HealthCheckImpl) { }() // Read stream health responses. - err := thc.stream(streamCtx, func(shr *query.StreamHealthResponse) error { + err := thc.stream(streamCtx, hc, func(shr *query.StreamHealthResponse) error { // We received a message. Reset the back-off. retryDelay = hc.retryDelay // Don't block on send to avoid deadlocks. diff --git a/go/vt/grpcclient/client.go b/go/vt/grpcclient/client.go index 8ad995721da..be239518bf9 100644 --- a/go/vt/grpcclient/client.go +++ b/go/vt/grpcclient/client.go @@ -21,6 +21,7 @@ package grpcclient import ( "context" "crypto/tls" + "sync" "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" @@ -39,6 +40,7 @@ import ( ) var ( + grpcDialOptionsMu sync.Mutex keepaliveTime = 10 * time.Second keepaliveTimeout = 10 * time.Second initialConnWindowSize int @@ -88,6 +90,8 @@ var grpcDialOptions []func(opts []grpc.DialOption) ([]grpc.DialOption, error) // RegisterGRPCDialOptions registers an implementation of AuthServer. func RegisterGRPCDialOptions(grpcDialOptionsFunc func(opts []grpc.DialOption) ([]grpc.DialOption, error)) { + grpcDialOptionsMu.Lock() + defer grpcDialOptionsMu.Unlock() grpcDialOptions = append(grpcDialOptions, grpcDialOptionsFunc) } @@ -137,12 +141,14 @@ func DialContext(ctx context.Context, target string, failFast FailFast, opts ... newopts = append(newopts, opts...) var err error + grpcDialOptionsMu.Lock() for _, grpcDialOptionInitializer := range grpcDialOptions { newopts, err = grpcDialOptionInitializer(newopts) if err != nil { log.Fatalf("There was an error initializing client grpc.DialOption: %v", err) } } + grpcDialOptionsMu.Unlock() newopts = append(newopts, interceptors()...) diff --git a/go/vt/grpcclient/client_auth_static.go b/go/vt/grpcclient/client_auth_static.go index 22f69569956..bbb91a9fa55 100644 --- a/go/vt/grpcclient/client_auth_static.go +++ b/go/vt/grpcclient/client_auth_static.go @@ -20,24 +20,35 @@ import ( "context" "encoding/json" "os" + "os/signal" + "sync" + "syscall" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + + "vitess.io/vitess/go/vt/servenv" ) var ( credsFile string // registered as --grpc_auth_static_client_creds in RegisterFlags // StaticAuthClientCreds implements client interface to be able to WithPerRPCCredentials _ credentials.PerRPCCredentials = (*StaticAuthClientCreds)(nil) + + clientCreds *StaticAuthClientCreds + clientCredsCancel context.CancelFunc + clientCredsErr error + clientCredsMu sync.Mutex + clientCredsSigChan chan os.Signal ) -// StaticAuthClientCreds holder for client credentials +// StaticAuthClientCreds holder for client credentials. type StaticAuthClientCreds struct { Username string Password string } -// GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds +// GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds. func (c *StaticAuthClientCreds) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { return map[string]string{ "username": c.Username, @@ -47,30 +58,82 @@ func (c *StaticAuthClientCreds) GetRequestMetadata(context.Context, ...string) ( // RequireTransportSecurity indicates whether the credentials requires transport security. // Given that people can use this with or without TLS, at the moment we are not enforcing -// transport security +// transport security. func (c *StaticAuthClientCreds) RequireTransportSecurity() bool { return false } // AppendStaticAuth optionally appends static auth credentials if provided. func AppendStaticAuth(opts []grpc.DialOption) ([]grpc.DialOption, error) { - if credsFile == "" { - return opts, nil - } - data, err := os.ReadFile(credsFile) + creds, err := getStaticAuthCreds() if err != nil { return nil, err } - clientCreds := &StaticAuthClientCreds{} - err = json.Unmarshal(data, clientCreds) + if creds != nil { + grpcCreds := grpc.WithPerRPCCredentials(creds) + opts = append(opts, grpcCreds) + } + return opts, nil +} + +// ResetStaticAuth resets the static auth credentials. +func ResetStaticAuth() { + clientCredsMu.Lock() + defer clientCredsMu.Unlock() + if clientCredsCancel != nil { + clientCredsCancel() + clientCredsCancel = nil + } + clientCreds = nil + clientCredsErr = nil +} + +// getStaticAuthCreds returns the static auth creds and error. +func getStaticAuthCreds() (*StaticAuthClientCreds, error) { + clientCredsMu.Lock() + defer clientCredsMu.Unlock() + if credsFile != "" && clientCreds == nil { + var ctx context.Context + ctx, clientCredsCancel = context.WithCancel(context.Background()) + go handleClientCredsSignals(ctx) + clientCreds, clientCredsErr = loadStaticAuthCredsFromFile(credsFile) + } + return clientCreds, clientCredsErr +} + +// handleClientCredsSignals handles signals to reload client creds. +func handleClientCredsSignals(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-clientCredsSigChan: + if newCreds, err := loadStaticAuthCredsFromFile(credsFile); err == nil { + clientCredsMu.Lock() + clientCreds = newCreds + clientCredsErr = err + clientCredsMu.Unlock() + } + } + } +} + +// loadStaticAuthCredsFromFile loads static auth credentials from a file. +func loadStaticAuthCredsFromFile(path string) (*StaticAuthClientCreds, error) { + data, err := os.ReadFile(path) if err != nil { return nil, err } - creds := grpc.WithPerRPCCredentials(clientCreds) - opts = append(opts, creds) - return opts, nil + creds := &StaticAuthClientCreds{} + err = json.Unmarshal(data, creds) + return creds, err } func init() { + servenv.OnInit(func() { + clientCredsSigChan = make(chan os.Signal, 1) + signal.Notify(clientCredsSigChan, syscall.SIGHUP) + _, _ = getStaticAuthCreds() // preload static auth credentials + }) RegisterGRPCDialOptions(AppendStaticAuth) } diff --git a/go/vt/grpcclient/client_auth_static_test.go b/go/vt/grpcclient/client_auth_static_test.go new file mode 100644 index 00000000000..e14ace527d1 --- /dev/null +++ b/go/vt/grpcclient/client_auth_static_test.go @@ -0,0 +1,126 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package grpcclient + +import ( + "errors" + "fmt" + "os" + "reflect" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +func TestAppendStaticAuth(t *testing.T) { + { + clientCreds = nil + clientCredsErr = nil + opts, err := AppendStaticAuth([]grpc.DialOption{}) + assert.Nil(t, err) + assert.Len(t, opts, 0) + } + { + clientCreds = nil + clientCredsErr = errors.New("test err") + opts, err := AppendStaticAuth([]grpc.DialOption{}) + assert.NotNil(t, err) + assert.Len(t, opts, 0) + } + { + clientCreds = &StaticAuthClientCreds{Username: "test", Password: "123456"} + clientCredsErr = nil + opts, err := AppendStaticAuth([]grpc.DialOption{}) + assert.Nil(t, err) + assert.Len(t, opts, 1) + } +} + +func TestGetStaticAuthCreds(t *testing.T) { + tmp, err := os.CreateTemp("", t.Name()) + assert.Nil(t, err) + defer os.Remove(tmp.Name()) + credsFile = tmp.Name() + clientCredsSigChan = make(chan os.Signal, 1) + + // load old creds + fmt.Fprint(tmp, `{"Username": "old", "Password": "123456"}`) + ResetStaticAuth() + creds, err := getStaticAuthCreds() + assert.Nil(t, err) + assert.Equal(t, &StaticAuthClientCreds{Username: "old", Password: "123456"}, creds) + + // write new creds to the same file + _ = tmp.Truncate(0) + _, _ = tmp.Seek(0, 0) + fmt.Fprint(tmp, `{"Username": "new", "Password": "123456789"}`) + + // test the creds did not change yet + creds, err = getStaticAuthCreds() + assert.Nil(t, err) + assert.Equal(t, &StaticAuthClientCreds{Username: "old", Password: "123456"}, creds) + + // test SIGHUP signal triggers reload + credsOld := creds + clientCredsSigChan <- syscall.SIGHUP + timeoutChan := time.After(time.Second * 10) + for { + select { + case <-timeoutChan: + assert.Fail(t, "timed out waiting for SIGHUP reload of static auth creds") + return + default: + // confirm new creds get loaded + creds, err = getStaticAuthCreds() + if reflect.DeepEqual(creds, credsOld) { + continue // not changed yet + } + assert.Nil(t, err) + assert.Equal(t, &StaticAuthClientCreds{Username: "new", Password: "123456789"}, creds) + return + } + } +} + +func TestLoadStaticAuthCredsFromFile(t *testing.T) { + { + f, err := os.CreateTemp("", t.Name()) + if !assert.Nil(t, err) { + assert.FailNowf(t, "cannot create temp file: %s", err.Error()) + } + defer os.Remove(f.Name()) + fmt.Fprint(f, `{ + "Username": "test", + "Password": "correct horse battery staple" + }`) + if !assert.Nil(t, err) { + assert.FailNowf(t, "cannot read auth file: %s", err.Error()) + } + + creds, err := loadStaticAuthCredsFromFile(f.Name()) + assert.Nil(t, err) + assert.Equal(t, "test", creds.Username) + assert.Equal(t, "correct horse battery staple", creds.Password) + } + { + _, err := loadStaticAuthCredsFromFile(`does-not-exist`) + assert.NotNil(t, err) + } +} diff --git a/go/vt/srvtopo/discover.go b/go/vt/srvtopo/discover.go index 91aaea9daf6..2997dc42e21 100644 --- a/go/vt/srvtopo/discover.go +++ b/go/vt/srvtopo/discover.go @@ -29,20 +29,23 @@ import ( topodatapb "vitess.io/vitess/go/vt/proto/topodata" ) -// FindAllTargets goes through all serving shards in the topology -// for the provided tablet types. It returns one Target object per -// keyspace / shard / matching TabletType. -func FindAllTargets(ctx context.Context, ts Server, cell string, tabletTypes []topodatapb.TabletType) ([]*querypb.Target, error) { - ksNames, err := ts.GetSrvKeyspaceNames(ctx, cell, true) - if err != nil { - return nil, err +// FindAllTargets goes through all serving shards in the topology for the provided keyspaces +// and tablet types. If no keyspaces are provided all available keyspaces in the topo are +// fetched. It returns one Target object per keyspace/shard/matching TabletType. +func FindAllTargets(ctx context.Context, ts Server, cell string, keyspaces []string, tabletTypes []topodatapb.TabletType) ([]*querypb.Target, error) { + var err error + if len(keyspaces) == 0 { + keyspaces, err = ts.GetSrvKeyspaceNames(ctx, cell, true) + if err != nil { + return nil, err + } } var targets []*querypb.Target var wg sync.WaitGroup var mu sync.Mutex var errRecorder concurrency.AllErrorRecorder - for _, ksName := range ksNames { + for _, ksName := range keyspaces { wg.Add(1) go func(keyspace string) { defer wg.Done() diff --git a/go/vt/srvtopo/discover_test.go b/go/vt/srvtopo/discover_test.go index c076ba0e7b7..503f98ace1e 100644 --- a/go/vt/srvtopo/discover_test.go +++ b/go/vt/srvtopo/discover_test.go @@ -18,11 +18,12 @@ package srvtopo import ( "context" - "reflect" "sort" "testing" "time" + "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/vt/topo/memorytopo" querypb "vitess.io/vitess/go/vt/proto/query" @@ -61,16 +62,12 @@ func TestFindAllTargets(t *testing.T) { rs := NewResilientServer(ts, "TestFindAllKeyspaceShards") // No keyspace / shards. - ks, err := FindAllTargets(ctx, rs, "cell1", []topodatapb.TabletType{topodatapb.TabletType_PRIMARY}) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if len(ks) > 0 { - t.Errorf("why did I get anything? %v", ks) - } + ks, err := FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY}) + assert.NoError(t, err) + assert.Len(t, ks, 0) // Add one. - if err := ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace", &topodatapb.SrvKeyspace{ + assert.NoError(t, ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace", &topodatapb.SrvKeyspace{ Partitions: []*topodatapb.SrvKeyspace_KeyspacePartition{ { ServedType: topodatapb.TabletType_PRIMARY, @@ -81,28 +78,34 @@ func TestFindAllTargets(t *testing.T) { }, }, }, - }); err != nil { - t.Fatalf("can't add srvKeyspace: %v", err) - } + })) // Get it. - ks, err = FindAllTargets(ctx, rs, "cell1", []topodatapb.TabletType{topodatapb.TabletType_PRIMARY}) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if !reflect.DeepEqual(ks, []*querypb.Target{ + ks, err = FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY}) + assert.NoError(t, err) + assert.EqualValues(t, []*querypb.Target{ { Cell: "cell1", Keyspace: "test_keyspace", Shard: "test_shard0", TabletType: topodatapb.TabletType_PRIMARY, }, - }) { - t.Errorf("got wrong value: %v", ks) - } + }, ks) + + // Get any keyspace. + ks, err = FindAllTargets(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY}) + assert.NoError(t, err) + assert.EqualValues(t, []*querypb.Target{ + { + Cell: "cell1", + Keyspace: "test_keyspace", + Shard: "test_shard0", + TabletType: topodatapb.TabletType_PRIMARY, + }, + }, ks) // Add another one. - if err := ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace2", &topodatapb.SrvKeyspace{ + assert.NoError(t, ts.UpdateSrvKeyspace(ctx, "cell1", "test_keyspace2", &topodatapb.SrvKeyspace{ Partitions: []*topodatapb.SrvKeyspace_KeyspacePartition{ { ServedType: topodatapb.TabletType_PRIMARY, @@ -121,17 +124,13 @@ func TestFindAllTargets(t *testing.T) { }, }, }, - }); err != nil { - t.Fatalf("can't add srvKeyspace: %v", err) - } + })) - // Get it for all types. - ks, err = FindAllTargets(ctx, rs, "cell1", []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) - if err != nil { - t.Errorf("unexpected error: %v", err) - } + // Get it for any keyspace, all types. + ks, err = FindAllTargets(ctx, rs, "cell1", nil, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) + assert.NoError(t, err) sort.Sort(TargetArray(ks)) - if !reflect.DeepEqual(ks, []*querypb.Target{ + assert.EqualValues(t, []*querypb.Target{ { Cell: "cell1", Keyspace: "test_keyspace", @@ -150,23 +149,40 @@ func TestFindAllTargets(t *testing.T) { Shard: "test_shard2", TabletType: topodatapb.TabletType_REPLICA, }, - }) { - t.Errorf("got wrong value: %v", ks) - } + }, ks) - // Only get the REPLICA targets. - ks, err = FindAllTargets(ctx, rs, "cell1", []topodatapb.TabletType{topodatapb.TabletType_REPLICA}) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if !reflect.DeepEqual(ks, []*querypb.Target{ + // Only get 1 keyspace for all types. + ks, err = FindAllTargets(ctx, rs, "cell1", []string{"test_keyspace2"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) + assert.NoError(t, err) + assert.EqualValues(t, []*querypb.Target{ + { + Cell: "cell1", + Keyspace: "test_keyspace2", + Shard: "test_shard1", + TabletType: topodatapb.TabletType_PRIMARY, + }, { Cell: "cell1", Keyspace: "test_keyspace2", Shard: "test_shard2", TabletType: topodatapb.TabletType_REPLICA, }, - }) { - t.Errorf("got wrong value: %v", ks) - } + }, ks) + + // Only get the REPLICA targets for any keyspace. + ks, err = FindAllTargets(ctx, rs, "cell1", []string{}, []topodatapb.TabletType{topodatapb.TabletType_REPLICA}) + assert.NoError(t, err) + assert.Equal(t, []*querypb.Target{ + { + Cell: "cell1", + Keyspace: "test_keyspace2", + Shard: "test_shard2", + TabletType: topodatapb.TabletType_REPLICA, + }, + }, ks) + + // Get non-existent keyspace. + ks, err = FindAllTargets(ctx, rs, "cell1", []string{"doesnt-exist"}, []topodatapb.TabletType{topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA}) + assert.NoError(t, err) + assert.Len(t, ks, 0) } diff --git a/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go b/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go index cf272fe3606..55a067807bd 100644 --- a/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go +++ b/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go @@ -108,6 +108,7 @@ func TestGRPCVTGateConnAuth(t *testing.T) { fs := pflag.NewFlagSet("", pflag.ContinueOnError) grpcclient.RegisterFlags(fs) + grpcclient.ResetStaticAuth() err = fs.Parse([]string{ "--grpc_auth_static_client_creds", f.Name(), @@ -148,6 +149,7 @@ func TestGRPCVTGateConnAuth(t *testing.T) { fs = pflag.NewFlagSet("", pflag.ContinueOnError) grpcclient.RegisterFlags(fs) + grpcclient.ResetStaticAuth() err = fs.Parse([]string{ "--grpc_auth_static_client_creds", f.Name(), diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 6b4efde2768..011b2e09b7a 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -207,7 +207,7 @@ func (gw *TabletGateway) WaitForTablets(tabletTypesToWait []topodatapb.TabletTyp } // Finds the targets to look for. - targets, err := srvtopo.FindAllTargets(ctx, gw.srvTopoServer, gw.localCell, tabletTypesToWait) + targets, err := srvtopo.FindAllTargets(ctx, gw.srvTopoServer, gw.localCell, discovery.KeyspacesToWatch, tabletTypesToWait) if err != nil { return err }