diff --git a/balancer/weightedroundrobin/balancer.go b/balancer/weightedroundrobin/balancer.go new file mode 100644 index 000000000000..15e780880511 --- /dev/null +++ b/balancer/weightedroundrobin/balancer.go @@ -0,0 +1,493 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package weightedroundrobin + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + "unsafe" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/balancer/weightedroundrobin/internal" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/orca" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" + + v3orcapb "github.com/cncf/xds/go/xds/data/orca/v3" +) + +// Name is the name of the weighted round robin balancer. +const Name = "weighted_round_robin_experimental" + +func init() { + balancer.Register(bb{}) +} + +type bb struct{} + +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { + b := &wrrBalancer{ + cc: cc, + subConns: resolver.NewAddressMap(), + csEvltr: &balancer.ConnectivityStateEvaluator{}, + scMap: make(map[balancer.SubConn]*weightedSubConn), + connectivityState: connectivity.Connecting, + } + b.logger = prefixLogger(b) + b.logger.Infof("Created") + b.regeneratePicker() + return b +} + +func (bb) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + lbCfg := &lbConfig{ + // Default values as documented in A58. + OOBReportingPeriod: 10 * time.Second, + BlackoutPeriod: 10 * time.Second, + WeightExpirationPeriod: 3 * time.Minute, + WeightUpdatePeriod: time.Second, + } + if err := json.Unmarshal(js, lbCfg); err != nil { + return nil, fmt.Errorf("wrr: unable to unmarshal LB policy config: %s, error: %v", string(js), err) + } + + if lbCfg.ErrorUtilizationPenalty < 0 { + return nil, fmt.Errorf("wrr: errorUtilizationPenalty must be non-negative") + } + + // For easier comparisons later, ensure the OOB reporting period is unset + // (0s) when OOB reports are disabled. + if !lbCfg.EnableOOBLoadReport { + lbCfg.OOBReportingPeriod = 0 + } + + // Impose lower bound of 100ms on weightUpdatePeriod. + if !internal.AllowAnyWeightUpdatePeriod && lbCfg.WeightUpdatePeriod < 100*time.Millisecond { + lbCfg.WeightUpdatePeriod = 100 * time.Millisecond + } + + return lbCfg, nil +} + +func (bb) Name() string { + return Name +} + +type wrrBalancer struct { + cc balancer.ClientConn + logger *grpclog.PrefixLogger + + // The following fields are only accessed on calls into the LB policy, and + // do not need a mutex. 
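+	// (gRPC guarantees that UpdateClientConnState, ResolverError,
+	// UpdateSubConnState, and Close are never called concurrently.)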
+ cfg *lbConfig // active config + subConns *resolver.AddressMap // active weightedSubConns mapped by address + scMap map[balancer.SubConn]*weightedSubConn + connectivityState connectivity.State // aggregate state + csEvltr *balancer.ConnectivityStateEvaluator + resolverErr error // the last error reported by the resolver; cleared on successful resolution + connErr error // the last connection error; cleared upon leaving TransientFailure + stopPicker func() +} + +func (b *wrrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { + b.logger.Infof("UpdateCCS: %v", ccs) + b.resolverErr = nil + cfg, ok := ccs.BalancerConfig.(*lbConfig) + if !ok { + return fmt.Errorf("wrr: received nil or illegal BalancerConfig: %v", ccs.BalancerConfig) + } + + b.cfg = cfg + b.updateAddressesLocked(ccs.ResolverState.Addresses) + + if len(ccs.ResolverState.Addresses) == 0 { + b.ResolverError(errors.New("resolver produced zero addresses")) // will call regeneratePicker + return balancer.ErrBadResolverState + } + + // Regenerate & send picker. + b.regeneratePicker() + + return nil +} + +func (b *wrrBalancer) updateAddressesLocked(addrs []resolver.Address) { + addrsSet := resolver.NewAddressMap() + + // Loop through new addresses + for _, addr := range addrs { + if _, ok := addrsSet.Get(addr); ok { + // Redundant address; skip. + continue + } + addrsSet.Set(addr, nil) + + var wsc *weightedSubConn + wsci, ok := b.subConns.Get(addr) + if ok { + wsc = wsci.(*weightedSubConn) + } else { + // addr is a new address (not existing in b.subConns). + sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{}) + if err != nil { + b.logger.Warningf("wrr: failed to create new SubConn: %v", err) + continue + } + wsc = &weightedSubConn{ + SubConn: sc, + logger: b.logger, + connectivityState: connectivity.Idle, + } + b.subConns.Set(addr, wsc) + b.scMap[sc] = wsc + b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle) + sc.Connect() + } + // Update config for existing weightedSubConn or send update for first + // time to new one. Ensures an OOB listener is running if needed. + wsc.updateConfig(b.cfg) + } + for _, addr := range b.subConns.Keys() { + if _, ok := addrsSet.Get(addr); ok { + // Existing address also in new address list; skip. + continue + } + // addr was removed by resolver. Remove. + wsci, _ := b.subConns.Get(addr) + wsc := wsci.(*weightedSubConn) + b.cc.RemoveSubConn(wsc.SubConn) + b.subConns.Delete(addr) + } +} + +func (b *wrrBalancer) ResolverError(err error) { + b.resolverErr = err + if b.subConns.Len() == 0 { + b.connectivityState = connectivity.TransientFailure + } + if b.connectivityState != connectivity.TransientFailure { + // No need to update the picker since no error is being returned. + return + } + b.regeneratePicker() +} + +func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + wsc := b.scMap[sc] + if wsc == nil { + b.logger.Errorf("wrr: got state changes for an unknown SubConn: %p, %v", sc, state) + return + } + + cs := state.ConnectivityState + + if cs == connectivity.TransientFailure { + // Save error to be reported via picker. 
+ b.connErr = state.ConnectionError + } + + if cs == connectivity.Shutdown { + delete(b.scMap, sc) + } + + oldCS := wsc.updateConnectivityState(cs) + b.connectivityState = b.csEvltr.RecordTransition(oldCS, cs) + + // Regenerate picker when one of the following happens: + // - this sc entered or left ready + // - the aggregated state of balancer is TransientFailure + // (may need to update error message) + if (cs == connectivity.Ready) != (oldCS == connectivity.Ready) || + b.connectivityState == connectivity.TransientFailure { + b.regeneratePicker() + } +} + +// Close stops timers that would cause our scheduler to be reupdated +func (b *wrrBalancer) Close() { + if b.stopPicker != nil { + b.stopPicker() + b.stopPicker = nil + } + for _, wsc := range b.scMap { + // Ensure any lingering OOB watchers are stopped. + wsc.updateConnectivityState(connectivity.Shutdown) + } +} + +// ExitIdle is ignored; we always connect to all backends. +func (b *wrrBalancer) ExitIdle() {} + +func (b *wrrBalancer) readySubConns() []*weightedSubConn { + var ret []*weightedSubConn + for _, v := range b.subConns.Values() { + wsc := v.(*weightedSubConn) + if wsc.connectivityState == connectivity.Ready { + ret = append(ret, wsc) + } + } + return ret +} + +// mergeErrors builds an error from the last connection error and the last +// resolver error. Must only be called if b.connectivityState is +// TransientFailure. +func (b *wrrBalancer) mergeErrors() error { + // connErr must always be non-nil unless there are no SubConns, in which + // case resolverErr must be non-nil. + if b.connErr == nil { + return fmt.Errorf("last resolver error: %v", b.resolverErr) + } + if b.resolverErr == nil { + return fmt.Errorf("last connection error: %v", b.connErr) + } + return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr) +} + +func (b *wrrBalancer) regeneratePicker() { + if b.stopPicker != nil { + b.stopPicker() + b.stopPicker = nil + } + + switch b.connectivityState { + case connectivity.TransientFailure: + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: base.NewErrPicker(b.mergeErrors()), + }) + return + case connectivity.Connecting, connectivity.Idle: + // Idle could happen very briefly if all subconns are Idle and we've + // asked them to connect but they haven't reported Connecting yet. + // Report the same as Connecting since this is temporary. 
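+		// ErrNoSubConnAvailable causes gRPC to queue picks until a new
+		// picker is provided via UpdateState.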
+ b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable), + }) + return + case connectivity.Ready: + b.connErr = nil + } + + p := &picker{ + cfg: b.cfg, + subConns: b.readySubConns(), + } + var ctx context.Context + ctx, b.stopPicker = context.WithCancel(context.Background()) + p.start(ctx) + b.cc.UpdateState(balancer.State{ + ConnectivityState: b.connectivityState, + Picker: p, + }) +} + +type picker struct { + idx uint32 // index used indirectly by the scheduler; accessed atomically + cfg *lbConfig // active config when picker created + subConns []*weightedSubConn // all READY subconns + scheduler unsafe.Pointer // *scheduler; accessed atomically +} + +func (p *picker) scWeights() []float64 { + ws := make([]float64, len(p.subConns), len(p.subConns)) + for i, wsc := range p.subConns { + ws[i] = wsc.weight(time.Now(), p.cfg.WeightExpirationPeriod, p.cfg.BlackoutPeriod) + } + return ws +} + +func (p *picker) nextIdx() uint32 { + return atomic.AddUint32(&p.idx, 1) +} + +func (p *picker) regenerateScheduler() { + newSched := newScheduler(p.scWeights(), p.nextIdx) + atomic.StorePointer(&p.scheduler, unsafe.Pointer(&newSched)) +} + +func (p *picker) start(ctx context.Context) { + p.regenerateScheduler() + if len(p.subConns) == 1 { + // No need to regenerate weights with only one backend. + return + } + go func() { + ticker := time.NewTicker(p.cfg.WeightUpdatePeriod) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + p.regenerateScheduler() + } + } + }() +} + +func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler)) + pickedSC := p.subConns[sched.nextIndex()] + pr := balancer.PickResult{SubConn: pickedSC.SubConn} + if !p.cfg.EnableOOBLoadReport { + pr.Done = func(info balancer.DoneInfo) { + if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil { + pickedSC.OnLoadReport(load) + } + } + } + return pr, nil +} + +type weightedSubConn struct { + balancer.SubConn + logger *grpclog.PrefixLogger + + // The following fields are only accessed on calls into the LB policy, and + // do not need a mutex. + connectivityState connectivity.State + stopORCAListener func() + + // The following field are accessed asynchronously and are protected by mu. 
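+	// (They are written from the ORCA OOB listener or an RPC's Done callback
+	// via OnLoadReport, and read from the picker's scheduler-update goroutine
+	// via weight.)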
+ mu sync.Mutex + weightVal float64 + nonEmptySince time.Time + lastUpdated time.Time + cfg *lbConfig +} + +func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) { + if w.logger.V(2) { + w.logger.Infof("wrr: received load report for subchannel %v: %v", w.SubConn, load) + } + // Update weights of this subchannel according to the reported load + if load.CpuUtilization == 0 || load.RpsFractional == 0 { + if w.logger.V(2) { + w.logger.Infof("wrr: ignoring empty load report for subchannel %v", w.SubConn) + } + return + } + + w.mu.Lock() + defer w.mu.Unlock() + + errorRate := load.Eps / load.RpsFractional + w.weightVal = load.RpsFractional / (load.CpuUtilization + errorRate*w.cfg.ErrorUtilizationPenalty) + if w.logger.V(2) { + w.logger.Infof("wrr: new weight for subchannel %v: %v", w.SubConn, w.weightVal) + } + + w.lastUpdated = time.Now() + if w.nonEmptySince == (time.Time{}) { + w.nonEmptySince = w.lastUpdated + } +} + +func (w *weightedSubConn) updateConfig(cfg *lbConfig) { + w.mu.Lock() + defer w.mu.Unlock() + oldCfg := w.cfg + if oldCfg == nil { + oldCfg = &lbConfig{EnableOOBLoadReport: false} + } + w.cfg = cfg + newPeriod := cfg.OOBReportingPeriod + if cfg.EnableOOBLoadReport == oldCfg.EnableOOBLoadReport && + newPeriod == oldCfg.OOBReportingPeriod { + // Load reporting wasn't enabled before or after, or load reporting was + // enabled before and after, and had the same period. (Note that with + // load reporting disabled, OOBReportingPeriod is always 0.) + return + } + // (Optionally stop and) start the listener. + if w.stopORCAListener != nil { + w.stopORCAListener() + } + if !cfg.EnableOOBLoadReport { + w.stopORCAListener = nil + return + } + w.logger.Infof("Registering listener for %v; %v", w.SubConn, newPeriod) + opts := orca.OOBListenerOptions{ReportInterval: newPeriod} + w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts) +} + +func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connectivity.State { + w.mu.Lock() + defer w.mu.Unlock() + + switch cs { + case connectivity.Idle: + // Always reconnect when idle. + w.SubConn.Connect() + case connectivity.Ready: + // If we transition back to READY state, restart the blackout period. + // Note that we cannot guarantee that we will never receive lingering + // callbacks for backend metric reports from the previous connection + // after the new connection has been established, but they should be + // masked by new backend metric reports from the new connection by the + // time the blackout period ends. + w.nonEmptySince = time.Time{} + case connectivity.Shutdown: + if w.stopORCAListener != nil { + w.stopORCAListener() + } + } + + old := w.connectivityState + + if old == connectivity.TransientFailure && + (cs == connectivity.Connecting || cs == connectivity.Idle) { + // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or + // CONNECTING transitions to prevent the aggregated state from being + // always CONNECTING when many backends exist but are all down. + return old + } + + w.connectivityState = cs + + return old +} + +func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration) float64 { + w.mu.Lock() + defer w.mu.Unlock() + // If the most recent update was longer ago than the expiration period, + // reset nonEmptySince so that we apply the blackout period again if we + // start getting data again in the future, and return 0. 
+ if now.Sub(w.lastUpdated) > weightExpirationPeriod { + w.nonEmptySince = time.Time{} + return 0 + } + // If we don't have at least blackoutPeriod worth of data, return 0. + if blackoutPeriod != 0 && (w.nonEmptySince == (time.Time{}) || now.Sub(w.nonEmptySince) < blackoutPeriod) { + return 0 + } + // Otherwise, return the weight. + return w.weightVal +} diff --git a/balancer/weightedroundrobin/balancer_test.go b/balancer/weightedroundrobin/balancer_test.go new file mode 100644 index 000000000000..4a562809894a --- /dev/null +++ b/balancer/weightedroundrobin/balancer_test.go @@ -0,0 +1,580 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package weightedroundrobin_test + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils/roundrobin" + "google.golang.org/grpc/orca" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/resolver" + + wrr "google.golang.org/grpc/balancer/weightedroundrobin" + iwrr "google.golang.org/grpc/balancer/weightedroundrobin/internal" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +const defaultTestTimeout = 10 * time.Second +const rrIterations = 100 + +type testServer struct { + *stubserver.StubServer + + oobMetrics orca.ServerMetricsRecorder // Attached to the OOB stream. + callMetrics orca.CallMetricsRecorder // Attached to per-call metrics. +} + +type reportType int + +const ( + reportNone reportType = iota + reportOOB + reportCall + reportBoth +) + +func startServer(t *testing.T, r reportType) *testServer { + t.Helper() + + smr := orca.NewServerMetricsRecorder() + cmr := orca.NewServerMetricsRecorder().(orca.CallMetricsRecorder) + + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if r := orca.CallMetricsRecorderFromContext(ctx); r != nil { + // Copy metrics from what the test set in cmr into r. 
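+				// (cmr was created from a ServerMetricsRecorder, so its
+				// concrete type also acts as a ServerMetricsProvider here.)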
+ sm := cmr.(orca.ServerMetricsProvider).ServerMetrics() + r.SetCPUUtilization(sm.CPUUtilization) + r.SetQPS(sm.QPS) + r.SetEPS(sm.EPS) + } + return &testpb.Empty{}, nil + }, + } + + var sopts []grpc.ServerOption + if r == reportBoth || r == reportCall { + sopts = append(sopts, orca.CallMetricsServerOption(nil)) + } + + if r == reportOOB || r == reportBoth { + oso := orca.ServiceOptions{ + ServerMetricsProvider: smr, + MinReportingInterval: 10 * time.Millisecond, + } + internal.ORCAAllowAnyMinReportingInterval.(func(so *orca.ServiceOptions))(&oso) + sopts = append(sopts, stubserver.RegisterServiceServerOption(func(s *grpc.Server) { + if err := orca.Register(s, oso); err != nil { + t.Fatalf("Failed to register orca service: %v", err) + } + })) + } + + if err := ss.StartServer(sopts...); err != nil { + t.Fatalf("Error starting server: %v", err) + } + t.Cleanup(ss.Stop) + + return &testServer{ + StubServer: ss, + oobMetrics: smr, + callMetrics: cmr, + } +} + +func svcConfig(t *testing.T, wrrCfg wrr.LBConfigForTesting) string { + t.Helper() + m, err := json.Marshal(wrrCfg) + if err != nil { + t.Fatalf("Error marshaling JSON %v: %v", wrrCfg, err) + } + sc := fmt.Sprintf(`{"loadBalancingConfig": [ {%q:%v} ] }`, wrr.Name, string(m)) + t.Logf("Marshaled service config: %v", sc) + return sc +} + +const weightUpdatePeriod = 50 * time.Millisecond +const oobReportingInterval = 10 * time.Millisecond + +func init() { + iwrr.AllowAnyWeightUpdatePeriod = true +} + +var ( + perCallConfig = wrr.LBConfigForTesting{ + EnableOOBLoadReport: false, + BlackoutPeriod: 1 * time.Nanosecond, + OOBReportingPeriod: 5 * time.Millisecond, + WeightExpirationPeriod: time.Minute, + WeightUpdatePeriod: weightUpdatePeriod, + } + oobConfig = wrr.LBConfigForTesting{ + EnableOOBLoadReport: true, + BlackoutPeriod: 1 * time.Nanosecond, + OOBReportingPeriod: 5 * time.Millisecond, + WeightExpirationPeriod: time.Minute, + WeightUpdatePeriod: weightUpdatePeriod, + } +) + +func (s) TestBalancer_OneAddress_ReportingDisabled(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv := startServer(t, reportNone) + + sc := svcConfig(t, perCallConfig) + if err := srv.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + + // Perform many RPCs to ensure the LB policy works with 1 address. + for i := 0; i < 100; i++ { + if _, err := srv.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("Error from EmptyCall: %v", err) + } + time.Sleep(time.Millisecond) // Delay; test will run 100ms and should perform ~10 weight updates + } +} + +func (s) TestBalancer_OneAddress_ReportingEnabledPerCall(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv := startServer(t, reportCall) + + sc := svcConfig(t, perCallConfig) + if err := srv.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + + // Perform many RPCs to ensure the LB policy works with 1 address. 
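+	// Vary the reported QPS on each call to exercise the per-call load
+	// reporting path.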
+ for i := 0; i < 100; i++ { + srv.callMetrics.SetQPS(float64(i)) + if _, err := srv.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("Error from EmptyCall: %v", err) + } + time.Sleep(time.Millisecond) // Delay; test will run 100ms and should perform ~10 weight updates + } +} + +func (s) TestBalancer_OneAddress_ReportingEnabledOOB(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv := startServer(t, reportOOB) + + sc := svcConfig(t, oobConfig) + if err := srv.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + + // Perform many RPCs to ensure the LB policy works with 1 address. + for i := 0; i < 100; i++ { + srv.oobMetrics.SetQPS(float64(i)) + if _, err := srv.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("Error from EmptyCall: %v", err) + } + time.Sleep(time.Millisecond) // Delay; test will run 100ms and should perform ~10 weight updates + } +} + +func (s) TestBalancer_TwoAddresses_ReportingDisabled(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportNone) + srv2 := startServer(t, reportNone) + + sc := svcConfig(t, perCallConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Perform many RPCs to ensure the LB policy works with 2 addresses. + for i := 0; i < 20; i++ { + roundrobin.CheckRoundRobinRPCs(ctx, srv1.Client, addrs) + } +} + +func (s) TestBalancer_TwoAddresses_ReportingEnabledPerCall(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportCall) + srv2 := startServer(t, reportCall) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). + srv1.callMetrics.SetQPS(10.0) + srv1.callMetrics.SetCPUUtilization(1.0) + + srv2.callMetrics.SetQPS(10.0) + srv2.callMetrics.SetCPUUtilization(.1) + + sc := svcConfig(t, perCallConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) +} + +func (s) TestBalancer_TwoAddresses_ReportingEnabledOOB(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportOOB) + srv2 := startServer(t, reportOOB) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). 
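+	// srv1 weight: 10.0 / 1.0 = 10; srv2 weight: 10.0 / 0.1 = 100 (EPS and the
+	// error penalty are zero), so srv2 should receive ~10x the RPCs.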
+ srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(1.0) + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(.1) + + sc := svcConfig(t, oobConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) +} + +func (s) TestBalancer_TwoAddresses_UpdateLoads(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportOOB) + srv2 := startServer(t, reportOOB) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). + srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(1.0) + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(.1) + + sc := svcConfig(t, oobConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) + + // Update the loads so srv2 is loaded and srv1 is not; ensure RPCs are + // routed disproportionately to srv1. + srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(.1) + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(1.0) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod + oobReportingInterval) + checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1}) +} + +func (s) TestBalancer_TwoAddresses_OOBThenPerCall(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportBoth) + srv2 := startServer(t, reportBoth) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). + srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(1.0) + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(.1) + + // For per-call metrics (not used initially), srv2 reports that it is + // loaded and srv1 reports low load. After confirming OOB works, switch to + // per-call and confirm the new routing weights are applied. + srv1.callMetrics.SetQPS(10.0) + srv1.callMetrics.SetCPUUtilization(.1) + + srv2.callMetrics.SetQPS(10.0) + srv2.callMetrics.SetCPUUtilization(1.0) + + sc := svcConfig(t, oobConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. 
+ ensureReached(ctx, t, srv1.Client, 2) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) + + // Update to per-call weights. + c := svcConfig(t, perCallConfig) + parsedCfg := srv1.R.CC.ParseServiceConfig(c) + if parsedCfg.Err != nil { + panic(fmt.Sprintf("Error parsing config %q: %v", c, parsedCfg.Err)) + } + srv1.R.UpdateState(resolver.State{Addresses: addrs, ServiceConfig: parsedCfg}) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1}) +} + +func (s) TestBalancer_TwoAddresses_ErrorPenalty(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportOOB) + srv2 := startServer(t, reportOOB) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). Errors are set (but ignored + // initially) such that RPCs will be routed 50/50. + srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(1.0) + srv1.oobMetrics.SetEPS(0) + // srv1 weight before: 10.0 / 1.0 = 10.0 + // srv1 weight after: 10.0 / 1.0 = 10.0 + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(.1) + srv2.oobMetrics.SetEPS(10.0) + // srv2 weight before: 10.0 / 0.1 = 100.0 + // srv2 weight after: 10.0 / 1.0 = 10.0 + + sc := svcConfig(t, oobConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) + + // Update to include an error penalty in the weights. + newCfg := oobConfig + newCfg.ErrorUtilizationPenalty = 0.9 + c := svcConfig(t, newCfg) + parsedCfg := srv1.R.CC.ParseServiceConfig(c) + if parsedCfg.Err != nil { + panic(fmt.Sprintf("Error parsing config %q: %v", c, parsedCfg.Err)) + } + srv1.R.UpdateState(resolver.State{Addresses: addrs, ServiceConfig: parsedCfg}) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod + oobReportingInterval) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1}) +} + +func (s) TestBalancer_TwoAddresses_BlackoutPeriod(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportOOB) + srv2 := startServer(t, reportOOB) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). 
+ srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(1.0) + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(.1) + + cfg := oobConfig + cfg.BlackoutPeriod = time.Second + sc := svcConfig(t, cfg) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + start := time.Now() + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + + // During the blackout period (1s) we should route roughly 50/50. + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1}) + + // Wait for the blackout period, then RPCs should be routed 10:1 to srv2. + time.Sleep(time.Until(start.Add(cfg.BlackoutPeriod))) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) +} + +func (s) TestBalancer_TwoAddresses_WeightExpiration(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportBoth) + srv2 := startServer(t, reportBoth) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). Because the OOB reporting interval + // is 1 minute but the weights expire in 1 second, routing will go to 50/50 + // after the weights expire. + srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(1.0) + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(.1) + + cfg := oobConfig + cfg.WeightExpirationPeriod = time.Second + cfg.OOBReportingPeriod = time.Minute + sc := svcConfig(t, cfg) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + start := time.Now() + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) + + // Wait for the weight expiration period so the weights have expired. + time.Sleep(time.Until(start.Add(cfg.WeightExpirationPeriod))) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1}) +} + +func ensureReached(ctx context.Context, t *testing.T, c testgrpc.TestServiceClient, n int) { + t.Helper() + reached := make(map[string]struct{}) + for len(reached) != n { + var peer peer.Peer + if _, err := c.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil { + t.Fatalf("Error from EmptyCall: %v", err) + } + reached[peer.Addr.String()] = struct{}{} + } +} + +type srvWeight struct { + srv *testServer + w int +} + +func checkWeights(ctx context.Context, t *testing.T, sws ...srvWeight) { + t.Helper() + + c := sws[0].srv.Client + + // Replace the weights with approximate counts of RPCs wanted given the + // iterations performed. 
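+	// For example, with rrIterations = 100 and weights {1, 10}, the expected
+	// counts become {9, 90} after integer division.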
+ weightSum := 0 + for _, sw := range sws { + weightSum += sw.w + } + for i := range sws { + sws[i].w = rrIterations * sws[i].w / weightSum + } + + for attempts := 0; attempts < 10; attempts++ { + serverCounts := make(map[string]int) + for i := 0; i < rrIterations; i++ { + var peer peer.Peer + if _, err := c.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil { + t.Fatalf("Error from EmptyCall: %v; timed out waiting for weighted RR behavior?", err) + } + serverCounts[peer.Addr.String()]++ + } + if len(serverCounts) != len(sws) { + continue + } + success := true + for _, sw := range sws { + c := serverCounts[sw.srv.Address] + if c < sw.w-2 || c > sw.w+2 { + success = false + break + } + } + if success { + t.Logf("Passed iteration %v; counts: %v", attempts, serverCounts) + return + } + t.Logf("Failed iteration %v; counts: %v; want %+v", attempts, serverCounts, sws) + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("Failed to route RPCs with proper ratio") +} diff --git a/balancer/weightedroundrobin/config.go b/balancer/weightedroundrobin/config.go new file mode 100644 index 000000000000..cfc8a72a52ed --- /dev/null +++ b/balancer/weightedroundrobin/config.go @@ -0,0 +1,38 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package weightedroundrobin + +import ( + "time" + + "google.golang.org/grpc/serviceconfig" +) + +type lbConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` + + EnableOOBLoadReport bool `json:"enableOobLoadReport,omitempty"` + OOBReportingPeriod time.Duration `json:"oobReportingPeriod,omitempty"` + BlackoutPeriod time.Duration `json:"blackoutPeriod,omitempty"` + WeightExpirationPeriod time.Duration `json:"weightExpirationPeriod,omitempty"` + WeightUpdatePeriod time.Duration `json:"weightUpdatePeriod,omitempty"` + ErrorUtilizationPenalty float64 `json:"errorUtilizationPenalty,omitempty"` +} + +type LBConfigForTesting = lbConfig diff --git a/balancer/weightedroundrobin/internal/internal.go b/balancer/weightedroundrobin/internal/internal.go new file mode 100644 index 000000000000..2f3fb96a140c --- /dev/null +++ b/balancer/weightedroundrobin/internal/internal.go @@ -0,0 +1,23 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package internal + +// AllowAnyWeightUpdatePeriod permits any setting of WeightUpdatePeriod for +// testing. Normally a minimum of 100ms is applied. 
+var AllowAnyWeightUpdatePeriod bool diff --git a/balancer/weightedroundrobin/logging.go b/balancer/weightedroundrobin/logging.go new file mode 100644 index 000000000000..2e62e43a26e4 --- /dev/null +++ b/balancer/weightedroundrobin/logging.go @@ -0,0 +1,34 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package weightedroundrobin + +import ( + "fmt" + + "google.golang.org/grpc/grpclog" + internalgrpclog "google.golang.org/grpc/internal/grpclog" +) + +const prefix = "[weighted-round-robin-lb %p] " + +var logger = grpclog.Component("xds") + +func prefixLogger(p *wrrBalancer) *internalgrpclog.PrefixLogger { + return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p)) +} diff --git a/balancer/weightedroundrobin/scheduler.go b/balancer/weightedroundrobin/scheduler.go new file mode 100644 index 000000000000..99ff79155e3d --- /dev/null +++ b/balancer/weightedroundrobin/scheduler.go @@ -0,0 +1,119 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package weightedroundrobin + +import ( + "math" +) + +type scheduler interface { + nextIndex() int +} + +func newScheduler(scWeights []float64, inc func() uint32) scheduler { + n := len(scWeights) + if n == 0 { + return nil + } + if n == 1 { + return &rrScheduler{numSCs: 1, inc: inc} + } + sum := float64(0) + numZero := 0 + max := float64(0) + for _, w := range scWeights { + sum += w + if w > max { + max = w + } + if w == 0 { + numZero++ + } + } + if numZero == n { + return &rrScheduler{numSCs: n, inc: inc} + } + unscaledMean := sum / float64(n-numZero) + scalingFactor := maxWeight / max + mean := uint16(math.Round(scalingFactor * unscaledMean)) + + weights := make([]uint16, n, n) + for i, w := range scWeights { + if w == 0 { + weights[i] = mean + } else { + weights[i] = uint16(math.Round(scalingFactor * w)) + } + } + + logger.Infof("using edf scheduler with weights: %v", weights) + return &edfScheduler{weights: weights, inc: inc} +} + +const maxWeight = math.MaxUint16 + +type edfScheduler struct { + inc func() uint32 + weights []uint16 +} + +// Returns the index in weights to choose. +func (s *edfScheduler) nextIndex() int { + const offset = maxWeight / 2 + + for { + idx := uint64(s.inc()) + + // The sequence number (idx) is split in two: the lower %n gives the + // index of the backend, and the rest gives the number of times we've + // iterated through all backends. 
`generation` is used to + // deterministically decide whether we pick or skip the backend on this + // iteration, in proportion to the backend's weight. + + backendIndex := idx % uint64(len(s.weights)) + generation := idx / uint64(len(s.weights)) + weight := uint64(s.weights[backendIndex]) + + // We pick a backend `weight` times per `maxWeight` generations. The + // multiply and modulus ~evenly spread out the picks for a given + // backend between different generations. The offset by `backendIndex` + // helps to reduce the chance of multiple consecutive non-picks: if we + // have two consecutive backends with an equal, say, 80% weight of the + // max, with no offset we would see 1/5 generations that skipped both. + // TODO(b/190488683): add test for offset efficacy. + mod := uint64(weight*generation+backendIndex*offset) % maxWeight + + if mod < maxWeight-weight { + continue + } + return int(backendIndex) + } +} + +// A simple RR scheduler to use for fallback when all weights are zero or only +// one subconn exists. +type rrScheduler struct { + inc func() uint32 + numSCs int +} + +func (s *rrScheduler) nextIndex() int { + idx := int(s.inc()) + return idx % s.numSCs +} diff --git a/balancer/weightedroundrobin/weightedroundrobin.go b/balancer/weightedroundrobin/weightedroundrobin.go index 6fc4d1910e67..cc5a96b8048b 100644 --- a/balancer/weightedroundrobin/weightedroundrobin.go +++ b/balancer/weightedroundrobin/weightedroundrobin.go @@ -16,16 +16,20 @@ * */ -// Package weightedroundrobin defines a weighted roundrobin balancer. +// Package weightedroundrobin provides an implementation of the weighted round +// robin LB policy, as defined in gRFC A58: +// https://github.com/grpc/proposal/blob/master/A58-client-side-weighted-round-robin-lb-policy.md +// +// # Experimental +// +// Notice: This package is EXPERIMENTAL and may be changed or removed in a +// later release. package weightedroundrobin import ( "google.golang.org/grpc/resolver" ) -// Name is the name of weighted_round_robin balancer. -const Name = "weighted_round_robin" - // attributeKey is the type used as the key to store AddrInfo in the // BalancerAttributes field of resolver.Address. type attributeKey struct{} @@ -44,11 +48,6 @@ func (a AddrInfo) Equal(o interface{}) bool { // SetAddrInfo returns a copy of addr in which the BalancerAttributes field is // updated with addrInfo. -// -// # Experimental -// -// Notice: This API is EXPERIMENTAL and may be changed or removed in a -// later release. func SetAddrInfo(addr resolver.Address, addrInfo AddrInfo) resolver.Address { addr.BalancerAttributes = addr.BalancerAttributes.WithValue(attributeKey{}, addrInfo) return addr @@ -56,11 +55,6 @@ func SetAddrInfo(addr resolver.Address, addrInfo AddrInfo) resolver.Address { // GetAddrInfo returns the AddrInfo stored in the BalancerAttributes field of // addr. -// -// # Experimental -// -// Notice: This API is EXPERIMENTAL and may be changed or removed in a -// later release. 
func GetAddrInfo(addr resolver.Address) AddrInfo { v := addr.BalancerAttributes.Value(attributeKey{}) ai, _ := v.(AddrInfo) diff --git a/xds/internal/balancer/clusterimpl/picker.go b/xds/internal/balancer/clusterimpl/picker.go index 360fc44c9e4d..3f354424f28e 100644 --- a/xds/internal/balancer/clusterimpl/picker.go +++ b/xds/internal/balancer/clusterimpl/picker.go @@ -160,7 +160,7 @@ func (d *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { d.loadStore.CallFinished(lIDStr, info.Err) load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport) - if !ok { + if !ok || load == nil { return } d.loadStore.CallServerLoad(lIDStr, serverLoadCPUName, load.CpuUtilization)
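For reference, a minimal client-side sketch (not part of the patch) of how the new policy could be selected through the default service config. The policy name `weighted_round_robin_experimental` and the `enableOobLoadReport` field come from this change; the target address and credentials are illustrative assumptions.

```go
package main

import (
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	// Importing the package registers the "weighted_round_robin_experimental" policy.
	_ "google.golang.org/grpc/balancer/weightedroundrobin"
)

func main() {
	// Enable the policy and OOB load reporting; other fields keep the A58 defaults.
	const svcCfg = `{"loadBalancingConfig": [{"weighted_round_robin_experimental": {"enableOobLoadReport": true}}]}`

	conn, err := grpc.Dial("dns:///example-backend:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(svcCfg),
	)
	if err != nil {
		log.Fatalf("grpc.Dial: %v", err)
	}
	defer conn.Close()
	// Create stubs from conn and issue RPCs; picks are weighted by backend-reported load.
}
```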