diff --git a/balancer/weightedroundrobin/balancer.go b/balancer/weightedroundrobin/balancer.go new file mode 100644 index 00000000000..e0d255222d5 --- /dev/null +++ b/balancer/weightedroundrobin/balancer.go @@ -0,0 +1,532 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package weightedroundrobin + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + "unsafe" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/balancer/weightedroundrobin/internal" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/orca" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" + + v3orcapb "github.com/cncf/xds/go/xds/data/orca/v3" +) + +// Name is the name of the weighted round robin balancer. +const Name = "weighted_round_robin_experimental" + +func init() { + balancer.Register(bb{}) +} + +type bb struct{} + +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { + b := &wrrBalancer{ + cc: cc, + subConns: resolver.NewAddressMap(), + csEvltr: &balancer.ConnectivityStateEvaluator{}, + scMap: make(map[balancer.SubConn]*weightedSubConn), + connectivityState: connectivity.Connecting, + } + b.logger = prefixLogger(b) + b.logger.Infof("Created") + return b +} + +func (bb) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + lbCfg := &lbConfig{ + // Default values as documented in A58. + OOBReportingPeriod: 10 * time.Second, + BlackoutPeriod: 10 * time.Second, + WeightExpirationPeriod: 3 * time.Minute, + WeightUpdatePeriod: time.Second, + ErrorUtilizationPenalty: 1, + } + if err := json.Unmarshal(js, lbCfg); err != nil { + return nil, fmt.Errorf("wrr: unable to unmarshal LB policy config: %s, error: %v", string(js), err) + } + + if lbCfg.ErrorUtilizationPenalty < 0 { + return nil, fmt.Errorf("wrr: errorUtilizationPenalty must be non-negative") + } + + // For easier comparisons later, ensure the OOB reporting period is unset + // (0s) when OOB reports are disabled. + if !lbCfg.EnableOOBLoadReport { + lbCfg.OOBReportingPeriod = 0 + } + + // Impose lower bound of 100ms on weightUpdatePeriod. + if !internal.AllowAnyWeightUpdatePeriod && lbCfg.WeightUpdatePeriod < 100*time.Millisecond { + lbCfg.WeightUpdatePeriod = 100 * time.Millisecond + } + + return lbCfg, nil +} + +func (bb) Name() string { + return Name +} + +// wrrBalancer implements the weighted round robin LB policy. +type wrrBalancer struct { + cc balancer.ClientConn + logger *grpclog.PrefixLogger + + // The following fields are only accessed on calls into the LB policy, and + // do not need a mutex. 
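+	// (gRPC serializes calls into an LB policy -- UpdateClientConnState,
+	// ResolverError, UpdateSubConnState, and Close -- which is what makes a
+	// mutex unnecessary for these fields.)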
+ cfg *lbConfig // active config + subConns *resolver.AddressMap // active weightedSubConns mapped by address + scMap map[balancer.SubConn]*weightedSubConn + connectivityState connectivity.State // aggregate state + csEvltr *balancer.ConnectivityStateEvaluator + resolverErr error // the last error reported by the resolver; cleared on successful resolution + connErr error // the last connection error; cleared upon leaving TransientFailure + stopPicker func() +} + +func (b *wrrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { + b.logger.Infof("UpdateCCS: %v", ccs) + b.resolverErr = nil + cfg, ok := ccs.BalancerConfig.(*lbConfig) + if !ok { + return fmt.Errorf("wrr: received nil or illegal BalancerConfig (type %T): %v", ccs.BalancerConfig, ccs.BalancerConfig) + } + + b.cfg = cfg + b.updateAddresses(ccs.ResolverState.Addresses) + + if len(ccs.ResolverState.Addresses) == 0 { + b.ResolverError(errors.New("resolver produced zero addresses")) // will call regeneratePicker + return balancer.ErrBadResolverState + } + + b.regeneratePicker() + + return nil +} + +func (b *wrrBalancer) updateAddresses(addrs []resolver.Address) { + addrsSet := resolver.NewAddressMap() + + // Loop through new address list and create subconns for any new addresses. + for _, addr := range addrs { + if _, ok := addrsSet.Get(addr); ok { + // Redundant address; skip. + continue + } + addrsSet.Set(addr, nil) + + var wsc *weightedSubConn + wsci, ok := b.subConns.Get(addr) + if ok { + wsc = wsci.(*weightedSubConn) + } else { + // addr is a new address (not existing in b.subConns). + sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{}) + if err != nil { + b.logger.Warningf("Failed to create new SubConn for address %v: %v", addr, err) + continue + } + wsc = &weightedSubConn{ + SubConn: sc, + logger: b.logger, + connectivityState: connectivity.Idle, + // Initially, we set load reports to off, because they are not + // running upon initial weightedSubConn creation. + cfg: &lbConfig{EnableOOBLoadReport: false}, + } + b.subConns.Set(addr, wsc) + b.scMap[sc] = wsc + b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle) + sc.Connect() + } + // Update config for existing weightedSubConn or send update for first + // time to new one. Ensures an OOB listener is running if needed + // (and stops the existing one if applicable). + wsc.updateConfig(b.cfg) + } + + // Loop through existing subconns and remove ones that are not in addrs. + for _, addr := range b.subConns.Keys() { + if _, ok := addrsSet.Get(addr); ok { + // Existing address also in new address list; skip. + continue + } + // addr was removed by resolver. Remove. + wsci, _ := b.subConns.Get(addr) + wsc := wsci.(*weightedSubConn) + b.cc.RemoveSubConn(wsc.SubConn) + b.subConns.Delete(addr) + } +} + +func (b *wrrBalancer) ResolverError(err error) { + b.resolverErr = err + if b.subConns.Len() == 0 { + b.connectivityState = connectivity.TransientFailure + } + if b.connectivityState != connectivity.TransientFailure { + // No need to update the picker since no error is being returned. 
+		return
+	}
+	b.regeneratePicker()
+}
+
+func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
+	wsc := b.scMap[sc]
+	if wsc == nil {
+		b.logger.Errorf("UpdateSubConnState called with an unknown SubConn: %p, %v", sc, state)
+		return
+	}
+	if b.logger.V(2) {
+		b.logger.Infof("UpdateSubConnState(%+v, %+v)", sc, state)
+	}
+
+	cs := state.ConnectivityState
+
+	if cs == connectivity.TransientFailure {
+		// Save error to be reported via picker.
+		b.connErr = state.ConnectionError
+	}
+
+	if cs == connectivity.Shutdown {
+		delete(b.scMap, sc)
+		// The subconn was removed from b.subConns when the address was
+		// removed in updateAddresses.
+	}
+
+	oldCS := wsc.updateConnectivityState(cs)
+	b.connectivityState = b.csEvltr.RecordTransition(oldCS, cs)
+
+	// Regenerate picker when one of the following happens:
+	//  - this sc entered or left ready
+	//  - the aggregated state of balancer is TransientFailure
+	//    (may need to update error message)
+	if (cs == connectivity.Ready) != (oldCS == connectivity.Ready) ||
+		b.connectivityState == connectivity.TransientFailure {
+		b.regeneratePicker()
+	}
+}
+
+// Close stops the balancer. It cancels any ongoing scheduler updates and
+// stops any ORCA listeners.
+func (b *wrrBalancer) Close() {
+	if b.stopPicker != nil {
+		b.stopPicker()
+		b.stopPicker = nil
+	}
+	for _, wsc := range b.scMap {
+		// Ensure any lingering OOB watchers are stopped.
+		wsc.updateConnectivityState(connectivity.Shutdown)
+	}
+}
+
+// ExitIdle is ignored; we always connect to all backends.
+func (b *wrrBalancer) ExitIdle() {}
+
+func (b *wrrBalancer) readySubConns() []*weightedSubConn {
+	var ret []*weightedSubConn
+	for _, v := range b.subConns.Values() {
+		wsc := v.(*weightedSubConn)
+		if wsc.connectivityState == connectivity.Ready {
+			ret = append(ret, wsc)
+		}
+	}
+	return ret
+}
+
+// mergeErrors builds an error from the last connection error and the last
+// resolver error. Must only be called if b.connectivityState is
+// TransientFailure.
+func (b *wrrBalancer) mergeErrors() error {
+	// connErr must always be non-nil unless there are no SubConns, in which
+	// case resolverErr must be non-nil.
+	if b.connErr == nil {
+		return fmt.Errorf("last resolver error: %v", b.resolverErr)
+	}
+	if b.resolverErr == nil {
+		return fmt.Errorf("last connection error: %v", b.connErr)
+	}
+	return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr)
+}
+
+func (b *wrrBalancer) regeneratePicker() {
+	if b.stopPicker != nil {
+		b.stopPicker()
+		b.stopPicker = nil
+	}
+
+	switch b.connectivityState {
+	case connectivity.TransientFailure:
+		b.cc.UpdateState(balancer.State{
+			ConnectivityState: connectivity.TransientFailure,
+			Picker:            base.NewErrPicker(b.mergeErrors()),
+		})
+		return
+	case connectivity.Connecting, connectivity.Idle:
+		// Idle could happen very briefly if all subconns are Idle and we've
+		// asked them to connect but they haven't reported Connecting yet.
+		// Report the same as Connecting since this is temporary.
+ b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable), + }) + return + case connectivity.Ready: + b.connErr = nil + } + + p := &picker{ + v: grpcrand.Uint32(), // start the scheduler at a random point + cfg: b.cfg, + subConns: b.readySubConns(), + } + var ctx context.Context + ctx, b.stopPicker = context.WithCancel(context.Background()) + p.start(ctx) + b.cc.UpdateState(balancer.State{ + ConnectivityState: b.connectivityState, + Picker: p, + }) +} + +// picker is the WRR policy's picker. It uses live-updating backend weights to +// update the scheduler periodically and ensure picks are routed proportional +// to those weights. +type picker struct { + scheduler unsafe.Pointer // *scheduler; accessed atomically + v uint32 // incrementing value used by the scheduler; accessed atomically + cfg *lbConfig // active config when picker created + subConns []*weightedSubConn // all READY subconns +} + +// scWeights returns a slice containing the weights from p.subConns in the same +// order as p.subConns. +func (p *picker) scWeights() []float64 { + ws := make([]float64, len(p.subConns)) + now := internal.TimeNow() + for i, wsc := range p.subConns { + ws[i] = wsc.weight(now, p.cfg.WeightExpirationPeriod, p.cfg.BlackoutPeriod) + } + return ws +} + +func (p *picker) inc() uint32 { + return atomic.AddUint32(&p.v, 1) +} + +func (p *picker) regenerateScheduler() { + s := newScheduler(p.scWeights(), p.inc) + atomic.StorePointer(&p.scheduler, unsafe.Pointer(&s)) +} + +func (p *picker) start(ctx context.Context) { + p.regenerateScheduler() + if len(p.subConns) == 1 { + // No need to regenerate weights with only one backend. + return + } + go func() { + ticker := time.NewTicker(p.cfg.WeightUpdatePeriod) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + p.regenerateScheduler() + } + } + }() +} + +func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + // Read the scheduler atomically. All scheduler operations are threadsafe, + // and if the scheduler is replaced during this usage, we want to use the + // scheduler that was live when the pick started. + sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler)) + + pickedSC := p.subConns[sched.nextIndex()] + pr := balancer.PickResult{SubConn: pickedSC.SubConn} + if !p.cfg.EnableOOBLoadReport { + pr.Done = func(info balancer.DoneInfo) { + if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil { + pickedSC.OnLoadReport(load) + } + } + } + return pr, nil +} + +// weightedSubConn is the wrapper of a subconn that holds the subconn and its +// weight (and other parameters relevant to computing the effective weight). +// When needed, it also tracks connectivity state, listens for metrics updates +// by implementing the orca.OOBListener interface and manages that listener. +type weightedSubConn struct { + balancer.SubConn + logger *grpclog.PrefixLogger + + // The following fields are only accessed on calls into the LB policy, and + // do not need a mutex. + connectivityState connectivity.State + stopORCAListener func() + + // The following fields are accessed asynchronously and are protected by + // mu. Note that mu may not be held when calling into the stopORCAListener + // or when registering a new listener, as those calls require the ORCA + // producer mu which is held when calling the listener, and the listener + // holds mu. 
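+	// (In other words, the lock order is: ORCA producer mu first, then this
+	// mu; acquiring them in the opposite order could deadlock.)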
+ mu sync.Mutex + weightVal float64 + nonEmptySince time.Time + lastUpdated time.Time + cfg *lbConfig +} + +func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) { + if w.logger.V(2) { + w.logger.Infof("Received load report for subchannel %v: %v", w.SubConn, load) + } + // Update weights of this subchannel according to the reported load + if load.CpuUtilization == 0 || load.RpsFractional == 0 { + if w.logger.V(2) { + w.logger.Infof("Ignoring empty load report for subchannel %v", w.SubConn) + } + return + } + + w.mu.Lock() + defer w.mu.Unlock() + + errorRate := load.Eps / load.RpsFractional + w.weightVal = load.RpsFractional / (load.CpuUtilization + errorRate*w.cfg.ErrorUtilizationPenalty) + if w.logger.V(2) { + w.logger.Infof("New weight for subchannel %v: %v", w.SubConn, w.weightVal) + } + + w.lastUpdated = internal.TimeNow() + if w.nonEmptySince == (time.Time{}) { + w.nonEmptySince = w.lastUpdated + } +} + +// updateConfig updates the parameters of the WRR policy and +// stops/starts/restarts the ORCA OOB listener. +func (w *weightedSubConn) updateConfig(cfg *lbConfig) { + w.mu.Lock() + oldCfg := w.cfg + w.cfg = cfg + w.mu.Unlock() + + newPeriod := cfg.OOBReportingPeriod + if cfg.EnableOOBLoadReport == oldCfg.EnableOOBLoadReport && + newPeriod == oldCfg.OOBReportingPeriod { + // Load reporting wasn't enabled before or after, or load reporting was + // enabled before and after, and had the same period. (Note that with + // load reporting disabled, OOBReportingPeriod is always 0.) + return + } + // (Optionally stop and) start the listener to use the new config's + // settings for OOB reporting. + + if w.stopORCAListener != nil { + w.stopORCAListener() + } + if !cfg.EnableOOBLoadReport { + w.stopORCAListener = nil + return + } + if w.logger.V(2) { + w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, newPeriod) + } + opts := orca.OOBListenerOptions{ReportInterval: newPeriod} + w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts) +} + +func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connectivity.State { + switch cs { + case connectivity.Idle: + // Always reconnect when idle. + w.SubConn.Connect() + case connectivity.Ready: + // If we transition back to READY state, reset nonEmptySince so that we + // apply the blackout period after we start receiving load data. Note + // that we cannot guarantee that we will never receive lingering + // callbacks for backend metric reports from the previous connection + // after the new connection has been established, but they should be + // masked by new backend metric reports from the new connection by the + // time the blackout period ends. + w.mu.Lock() + w.nonEmptySince = time.Time{} + w.mu.Unlock() + case connectivity.Shutdown: + if w.stopORCAListener != nil { + w.stopORCAListener() + } + } + + oldCS := w.connectivityState + + if oldCS == connectivity.TransientFailure && + (cs == connectivity.Connecting || cs == connectivity.Idle) { + // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or + // CONNECTING transitions to prevent the aggregated state from being + // always CONNECTING when many backends exist but are all down. + return oldCS + } + + w.connectivityState = cs + + return oldCS +} + +// weight returns the current effective weight of the subconn, taking into +// account the parameters. Returns 0 for blacked out or expired data, which +// will cause the backend weight to be treated as the mean of the weights of +// the other backends. 
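+// For example, if the ready backends report weights {0, 3, 5}, the scheduler
+// substitutes the mean of the non-zero weights (4) for the first backend.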
+func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration) float64 { + w.mu.Lock() + defer w.mu.Unlock() + // If the most recent update was longer ago than the expiration period, + // reset nonEmptySince so that we apply the blackout period again if we + // start getting data again in the future, and return 0. + if now.Sub(w.lastUpdated) >= weightExpirationPeriod { + w.nonEmptySince = time.Time{} + return 0 + } + // If we don't have at least blackoutPeriod worth of data, return 0. + if blackoutPeriod != 0 && (w.nonEmptySince == (time.Time{}) || now.Sub(w.nonEmptySince) < blackoutPeriod) { + return 0 + } + return w.weightVal +} diff --git a/balancer/weightedroundrobin/balancer_test.go b/balancer/weightedroundrobin/balancer_test.go new file mode 100644 index 00000000000..5dd62ebf872 --- /dev/null +++ b/balancer/weightedroundrobin/balancer_test.go @@ -0,0 +1,713 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package weightedroundrobin_test + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils/roundrobin" + "google.golang.org/grpc/orca" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/resolver" + + wrr "google.golang.org/grpc/balancer/weightedroundrobin" + iwrr "google.golang.org/grpc/balancer/weightedroundrobin/internal" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +const defaultTestTimeout = 10 * time.Second +const weightUpdatePeriod = 50 * time.Millisecond +const oobReportingInterval = 10 * time.Millisecond + +func init() { + iwrr.AllowAnyWeightUpdatePeriod = true +} + +func boolp(b bool) *bool { return &b } +func float64p(f float64) *float64 { return &f } +func durationp(d time.Duration) *time.Duration { return &d } + +var ( + perCallConfig = iwrr.LBConfig{ + EnableOOBLoadReport: boolp(false), + OOBReportingPeriod: durationp(5 * time.Millisecond), + BlackoutPeriod: durationp(0), + WeightExpirationPeriod: durationp(time.Minute), + WeightUpdatePeriod: durationp(weightUpdatePeriod), + ErrorUtilizationPenalty: float64p(0), + } + oobConfig = iwrr.LBConfig{ + EnableOOBLoadReport: boolp(true), + OOBReportingPeriod: durationp(5 * time.Millisecond), + BlackoutPeriod: durationp(0), + WeightExpirationPeriod: durationp(time.Minute), + WeightUpdatePeriod: durationp(weightUpdatePeriod), + ErrorUtilizationPenalty: float64p(0), + } +) + +type testServer struct { + *stubserver.StubServer + + oobMetrics orca.ServerMetricsRecorder // Attached to the OOB stream. + callMetrics orca.CallMetricsRecorder // Attached to per-call metrics. 
+} + +type reportType int + +const ( + reportNone reportType = iota + reportOOB + reportCall + reportBoth +) + +func startServer(t *testing.T, r reportType) *testServer { + t.Helper() + + smr := orca.NewServerMetricsRecorder() + cmr := orca.NewServerMetricsRecorder().(orca.CallMetricsRecorder) + + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if r := orca.CallMetricsRecorderFromContext(ctx); r != nil { + // Copy metrics from what the test set in cmr into r. + sm := cmr.(orca.ServerMetricsProvider).ServerMetrics() + r.SetCPUUtilization(sm.CPUUtilization) + r.SetQPS(sm.QPS) + r.SetEPS(sm.EPS) + } + return &testpb.Empty{}, nil + }, + } + + var sopts []grpc.ServerOption + if r == reportCall || r == reportBoth { + sopts = append(sopts, orca.CallMetricsServerOption(nil)) + } + + if r == reportOOB || r == reportBoth { + oso := orca.ServiceOptions{ + ServerMetricsProvider: smr, + MinReportingInterval: 10 * time.Millisecond, + } + internal.ORCAAllowAnyMinReportingInterval.(func(so *orca.ServiceOptions))(&oso) + sopts = append(sopts, stubserver.RegisterServiceServerOption(func(s *grpc.Server) { + if err := orca.Register(s, oso); err != nil { + t.Fatalf("Failed to register orca service: %v", err) + } + })) + } + + if err := ss.StartServer(sopts...); err != nil { + t.Fatalf("Error starting server: %v", err) + } + t.Cleanup(ss.Stop) + + return &testServer{ + StubServer: ss, + oobMetrics: smr, + callMetrics: cmr, + } +} + +func svcConfig(t *testing.T, wrrCfg iwrr.LBConfig) string { + t.Helper() + m, err := json.Marshal(wrrCfg) + if err != nil { + t.Fatalf("Error marshaling JSON %v: %v", wrrCfg, err) + } + sc := fmt.Sprintf(`{"loadBalancingConfig": [ {%q:%v} ] }`, wrr.Name, string(m)) + t.Logf("Marshaled service config: %v", sc) + return sc +} + +// Tests basic functionality with one address. With only one address, load +// reporting doesn't affect routing at all. +func (s) TestBalancer_OneAddress(t *testing.T) { + testCases := []struct { + rt reportType + cfg iwrr.LBConfig + }{ + {rt: reportNone, cfg: perCallConfig}, + {rt: reportCall, cfg: perCallConfig}, + {rt: reportOOB, cfg: oobConfig}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("reportType:%v", tc.rt), func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv := startServer(t, tc.rt) + + sc := svcConfig(t, tc.cfg) + if err := srv.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + + // Perform many RPCs to ensure the LB policy works with 1 address. + for i := 0; i < 100; i++ { + srv.callMetrics.SetQPS(float64(i)) + srv.oobMetrics.SetQPS(float64(i)) + if _, err := srv.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("Error from EmptyCall: %v", err) + } + time.Sleep(time.Millisecond) // Delay; test will run 100ms and should perform ~10 weight updates + } + }) + } +} + +// Tests two addresses with ORCA reporting disabled (should fall back to pure +// RR). 
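+// With no load reports every weight stays 0, and newScheduler falls back to
+// the plain rrScheduler whenever fewer than two weights are non-zero.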
+func (s) TestBalancer_TwoAddresses_ReportingDisabled(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	srv1 := startServer(t, reportNone)
+	srv2 := startServer(t, reportNone)
+
+	sc := svcConfig(t, perCallConfig)
+	if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
+		t.Fatalf("Error starting client: %v", err)
+	}
+	addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
+	srv1.R.UpdateState(resolver.State{Addresses: addrs})
+
+	// Perform many RPCs to ensure the LB policy works with 2 addresses.
+	for i := 0; i < 20; i++ {
+		if err := roundrobin.CheckRoundRobinRPCs(ctx, srv1.Client, addrs); err != nil {
+			t.Fatalf("Error checking round robin RPCs: %v", err)
+		}
+	}
+}
+
+// Tests two addresses with per-call ORCA reporting enabled. Checks the
+// backends are called in the appropriate ratios.
+func (s) TestBalancer_TwoAddresses_ReportingEnabledPerCall(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	srv1 := startServer(t, reportCall)
+	srv2 := startServer(t, reportCall)
+
+	// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
+	// disproportionately to srv2 (10:1).
+	srv1.callMetrics.SetQPS(10.0)
+	srv1.callMetrics.SetCPUUtilization(1.0)
+
+	srv2.callMetrics.SetQPS(10.0)
+	srv2.callMetrics.SetCPUUtilization(.1)
+
+	sc := svcConfig(t, perCallConfig)
+	if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
+		t.Fatalf("Error starting client: %v", err)
+	}
+	addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
+	srv1.R.UpdateState(resolver.State{Addresses: addrs})
+
+	// Call each backend once to ensure the weights have been received.
+	ensureReached(ctx, t, srv1.Client, 2)
+
+	// Wait for the weight update period to allow the new weights to be processed.
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
+}
+
+// Tests two addresses with OOB ORCA reporting enabled. Checks the backends
+// are called in the appropriate ratios.
+func (s) TestBalancer_TwoAddresses_ReportingEnabledOOB(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	srv1 := startServer(t, reportOOB)
+	srv2 := startServer(t, reportOOB)
+
+	// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
+	// disproportionately to srv2 (10:1).
+	srv1.oobMetrics.SetQPS(10.0)
+	srv1.oobMetrics.SetCPUUtilization(1.0)
+
+	srv2.oobMetrics.SetQPS(10.0)
+	srv2.oobMetrics.SetCPUUtilization(.1)
+
+	sc := svcConfig(t, oobConfig)
+	if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
+		t.Fatalf("Error starting client: %v", err)
+	}
+	addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
+	srv1.R.UpdateState(resolver.State{Addresses: addrs})
+
+	// Call each backend once to ensure the weights have been received.
+	ensureReached(ctx, t, srv1.Client, 2)
+
+	// Wait for the weight update period to allow the new weights to be processed.
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
+}
+
+// Tests two addresses with OOB ORCA reporting enabled, where the reports
+// change over time. Checks the backends are called in the appropriate ratios
+// before and after modifying the reports.
+func (s) TestBalancer_TwoAddresses_UpdateLoads(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportOOB) + srv2 := startServer(t, reportOOB) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). + srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(1.0) + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(.1) + + sc := svcConfig(t, oobConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) + + // Update the loads so srv2 is loaded and srv1 is not; ensure RPCs are + // routed disproportionately to srv1. + srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(.1) + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(1.0) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod + oobReportingInterval) + checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1}) +} + +// Tests two addresses with OOB ORCA reporting enabled, then with switching to +// per-call reporting. Checks the backends are called in the appropriate +// ratios before and after the change. +func (s) TestBalancer_TwoAddresses_OOBThenPerCall(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportBoth) + srv2 := startServer(t, reportBoth) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). + srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(1.0) + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(.1) + + // For per-call metrics (not used initially), srv2 reports that it is + // loaded and srv1 reports low load. After confirming OOB works, switch to + // per-call and confirm the new routing weights are applied. + srv1.callMetrics.SetQPS(10.0) + srv1.callMetrics.SetCPUUtilization(.1) + + srv2.callMetrics.SetQPS(10.0) + srv2.callMetrics.SetCPUUtilization(1.0) + + sc := svcConfig(t, oobConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) + + // Update to per-call weights. 
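+	// Pushing a new service config through the resolver causes the
+	// balancer's updateConfig to stop the OOB listeners; after that only
+	// the per-call (trailer) metrics drive the weights.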
+ c := svcConfig(t, perCallConfig) + parsedCfg := srv1.R.CC.ParseServiceConfig(c) + if parsedCfg.Err != nil { + panic(fmt.Sprintf("Error parsing config %q: %v", c, parsedCfg.Err)) + } + srv1.R.UpdateState(resolver.State{Addresses: addrs, ServiceConfig: parsedCfg}) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1}) +} + +// Tests two addresses with OOB ORCA reporting enabled and a non-zero error +// penalty applied. +func (s) TestBalancer_TwoAddresses_ErrorPenalty(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv1 := startServer(t, reportOOB) + srv2 := startServer(t, reportOOB) + + // srv1 starts loaded and srv2 starts without load; ensure RPCs are routed + // disproportionately to srv2 (10:1). EPS values are set (but ignored + // initially due to ErrorUtilizationPenalty=0). Later EUP will be updated + // to 0.9 which will cause the weights to be equal and RPCs to be routed + // 50/50. + srv1.oobMetrics.SetQPS(10.0) + srv1.oobMetrics.SetCPUUtilization(1.0) + srv1.oobMetrics.SetEPS(0) + // srv1 weight before: 10.0 / 1.0 = 10.0 + // srv1 weight after: 10.0 / 1.0 = 10.0 + + srv2.oobMetrics.SetQPS(10.0) + srv2.oobMetrics.SetCPUUtilization(.1) + srv2.oobMetrics.SetEPS(10.0) + // srv2 weight before: 10.0 / 0.1 = 100.0 + // srv2 weight after: 10.0 / 1.0 = 10.0 + + sc := svcConfig(t, oobConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}} + srv1.R.UpdateState(resolver.State{Addresses: addrs}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}) + + // Update to include an error penalty in the weights. + newCfg := oobConfig + newCfg.ErrorUtilizationPenalty = float64p(0.9) + c := svcConfig(t, newCfg) + parsedCfg := srv1.R.CC.ParseServiceConfig(c) + if parsedCfg.Err != nil { + panic(fmt.Sprintf("Error parsing config %q: %v", c, parsedCfg.Err)) + } + srv1.R.UpdateState(resolver.State{Addresses: addrs, ServiceConfig: parsedCfg}) + + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod + oobReportingInterval) + checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1}) +} + +// Tests that the blackout period causes backends to use 0 as their weight +// (meaning to use the average weight) until the blackout period elapses. 
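+// While every ready backend reports weight 0, newScheduler falls back to
+// round robin, so RPCs are split evenly until the blackout period ends.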
+func (s) TestBalancer_TwoAddresses_BlackoutPeriod(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	var mu sync.Mutex
+	start := time.Now()
+	now := start
+	setNow := func(t time.Time) {
+		mu.Lock()
+		defer mu.Unlock()
+		now = t
+	}
+	iwrr.TimeNow = func() time.Time {
+		mu.Lock()
+		defer mu.Unlock()
+		return now
+	}
+	t.Cleanup(func() { iwrr.TimeNow = time.Now })
+
+	testCases := []struct {
+		blackoutPeriodCfg *time.Duration
+		blackoutPeriod    time.Duration
+	}{{
+		blackoutPeriodCfg: durationp(time.Second),
+		blackoutPeriod:    time.Second,
+	}, {
+		blackoutPeriodCfg: nil,
+		blackoutPeriod:    10 * time.Second, // the default
+	}}
+	for _, tc := range testCases {
+		setNow(start)
+		srv1 := startServer(t, reportOOB)
+		srv2 := startServer(t, reportOOB)
+
+		// srv1 starts loaded and srv2 starts without load; ensure RPCs are
+		// routed disproportionately to srv2 (10:1).
+		srv1.oobMetrics.SetQPS(10.0)
+		srv1.oobMetrics.SetCPUUtilization(1.0)
+
+		srv2.oobMetrics.SetQPS(10.0)
+		srv2.oobMetrics.SetCPUUtilization(.1)
+
+		cfg := oobConfig
+		cfg.BlackoutPeriod = tc.blackoutPeriodCfg
+		sc := svcConfig(t, cfg)
+		if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
+			t.Fatalf("Error starting client: %v", err)
+		}
+		addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
+		srv1.R.UpdateState(resolver.State{Addresses: addrs})
+
+		// Call each backend once to ensure the weights have been received.
+		ensureReached(ctx, t, srv1.Client, 2)
+
+		// Wait for the weight update period to allow the new weights to be processed.
+		time.Sleep(weightUpdatePeriod)
+		// While the blackout period is in effect we should route roughly 50/50.
+		checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1})
+
+		// Advance time to right before the blackout period ends and the
+		// weights should still be zero.
+		setNow(start.Add(tc.blackoutPeriod - time.Nanosecond))
+		// Wait for the weight update period to allow the new weights to be processed.
+		time.Sleep(weightUpdatePeriod)
+		checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1})
+
+		// Advance time to right after the blackout period ends and the
+		// weights should now activate.
+		setNow(start.Add(tc.blackoutPeriod))
+		// Wait for the weight update period to allow the new weights to be processed.
+		time.Sleep(weightUpdatePeriod)
+		checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
+	}
+}
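+
+// To make the blackout timeline concrete (illustrative numbers, not part of
+// this change): with blackoutPeriod=1s and the first OOB report arriving at
+// t0, weight() returns 0 for any now < t0+1s because nonEmptySince is too
+// recent, so the scheduler falls back to even distribution; once
+// now >= t0+1s it returns qps/utilization and the 10:1 split appears.
+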
+// Tests that the weight expiration period causes backends to use 0 as their
+// weight (meaning to use the average weight) once the expiration period
+// elapses.
+func (s) TestBalancer_TwoAddresses_WeightExpiration(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	var mu sync.Mutex
+	start := time.Now()
+	now := start
+	setNow := func(t time.Time) {
+		mu.Lock()
+		defer mu.Unlock()
+		now = t
+	}
+	iwrr.TimeNow = func() time.Time {
+		mu.Lock()
+		defer mu.Unlock()
+		return now
+	}
+	t.Cleanup(func() { iwrr.TimeNow = time.Now })
+
+	srv1 := startServer(t, reportBoth)
+	srv2 := startServer(t, reportBoth)
+
+	// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
+	// disproportionately to srv2 (10:1). The OOB reporting interval is set to
+	// one minute, so no fresh reports arrive while the fake clock is advanced
+	// past the weight expiration period (also one minute); routing therefore
+	// goes to 50/50 once the weights expire.
+	srv1.oobMetrics.SetQPS(10.0)
+	srv1.oobMetrics.SetCPUUtilization(1.0)
+
+	srv2.oobMetrics.SetQPS(10.0)
+	srv2.oobMetrics.SetCPUUtilization(.1)
+
+	cfg := oobConfig
+	cfg.OOBReportingPeriod = durationp(time.Minute)
+	sc := svcConfig(t, cfg)
+	if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
+		t.Fatalf("Error starting client: %v", err)
+	}
+	addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
+	srv1.R.UpdateState(resolver.State{Addresses: addrs})
+
+	// Call each backend once to ensure the weights have been received.
+	ensureReached(ctx, t, srv1.Client, 2)
+
+	// Wait for the weight update period to allow the new weights to be processed.
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
+
+	// Advance what time.Now returns to the weight expiration time minus 1s to
+	// ensure all weights are still honored.
+	setNow(start.Add(*cfg.WeightExpirationPeriod - time.Second))
+
+	// Wait for the weight update period to allow the new weights to be processed.
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
+
+	// Advance what time.Now returns to the weight expiration time plus 1s to
+	// ensure all weights expired and addresses are routed evenly.
+	setNow(start.Add(*cfg.WeightExpirationPeriod + time.Second))
+
+	// Wait for the weight update period so the expired weights are processed.
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1})
+}
+// Tests logic surrounding subchannel management.
+func (s) TestBalancer_AddressesChanging(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	srv1 := startServer(t, reportBoth)
+	srv2 := startServer(t, reportBoth)
+	srv3 := startServer(t, reportBoth)
+	srv4 := startServer(t, reportBoth)
+
+	// srv1: weight 10
+	srv1.oobMetrics.SetQPS(10.0)
+	srv1.oobMetrics.SetCPUUtilization(1.0)
+	// srv2: weight 100
+	srv2.oobMetrics.SetQPS(10.0)
+	srv2.oobMetrics.SetCPUUtilization(.1)
+	// srv3: weight 20
+	srv3.oobMetrics.SetQPS(20.0)
+	srv3.oobMetrics.SetCPUUtilization(1.0)
+	// srv4: weight 200
+	srv4.oobMetrics.SetQPS(20.0)
+	srv4.oobMetrics.SetCPUUtilization(.1)
+
+	sc := svcConfig(t, oobConfig)
+	if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
+		t.Fatalf("Error starting client: %v", err)
+	}
+	srv2.Client = srv1.Client
+	addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}, {Addr: srv3.Address}}
+	srv1.R.UpdateState(resolver.State{Addresses: addrs})
+
+	// Call each backend once to ensure the weights have been received.
+	ensureReached(ctx, t, srv1.Client, 3)
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}, srvWeight{srv3, 2})
+
+	// Add backend 4.
+	addrs = append(addrs, resolver.Address{Addr: srv4.Address})
+	srv1.R.UpdateState(resolver.State{Addresses: addrs})
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}, srvWeight{srv3, 2}, srvWeight{srv4, 20})
+
+	// Shut down backend 3. RPCs will no longer be routed to it.
+	srv3.Stop()
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}, srvWeight{srv4, 20})
+
+	// Remove addresses 2 and 3. RPCs will no longer be routed to 2 either.
+	addrs = []resolver.Address{{Addr: srv1.Address}, {Addr: srv4.Address}}
+	srv1.R.UpdateState(resolver.State{Addresses: addrs})
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv4, 20})
+
+	// Re-add 2 and remove the rest.
+	addrs = []resolver.Address{{Addr: srv2.Address}}
+	srv1.R.UpdateState(resolver.State{Addresses: addrs})
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv2, 10})
+
+	// Re-add 4.
+	addrs = append(addrs, resolver.Address{Addr: srv4.Address})
+	srv1.R.UpdateState(resolver.State{Addresses: addrs})
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv2, 10}, srvWeight{srv4, 20})
+}
+
+func ensureReached(ctx context.Context, t *testing.T, c testgrpc.TestServiceClient, n int) {
+	t.Helper()
+	reached := make(map[string]struct{})
+	for len(reached) != n {
+		var peer peer.Peer
+		if _, err := c.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil {
+			t.Fatalf("Error from EmptyCall: %v", err)
+		}
+		reached[peer.Addr.String()] = struct{}{}
+	}
+}
+
+type srvWeight struct {
+	srv *testServer
+	w   int
+}
+
+const rrIterations = 100
+
+// checkWeights does rrIterations RPCs and expects the different backends to
+// be routed in a ratio as determined by the srvWeights passed in. Allows for
+// some variance (+/- 2 RPCs per backend).
+func checkWeights(ctx context.Context, t *testing.T, sws ...srvWeight) {
+	t.Helper()
+
+	c := sws[0].srv.Client
+
+	// Replace the weights with approximate counts of RPCs wanted given the
+	// iterations performed.
+	weightSum := 0
+	for _, sw := range sws {
+		weightSum += sw.w
+	}
+	for i := range sws {
+		sws[i].w = rrIterations * sws[i].w / weightSum
+	}
+
+	for attempts := 0; attempts < 10; attempts++ {
+		serverCounts := make(map[string]int)
+		for i := 0; i < rrIterations; i++ {
+			var peer peer.Peer
+			if _, err := c.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil {
+				t.Fatalf("Error from EmptyCall: %v; timed out waiting for weighted RR behavior?", err)
+			}
+			serverCounts[peer.Addr.String()]++
+		}
+		if len(serverCounts) != len(sws) {
+			continue
+		}
+		success := true
+		for _, sw := range sws {
+			c := serverCounts[sw.srv.Address]
+			if c < sw.w-2 || c > sw.w+2 {
+				success = false
+				break
+			}
+		}
+		if success {
+			t.Logf("Passed iteration %v; counts: %v", attempts, serverCounts)
+			return
+		}
+		t.Logf("Failed iteration %v; counts: %v; want %+v", attempts, serverCounts, sws)
+		time.Sleep(5 * time.Millisecond)
+	}
+	t.Fatalf("Failed to route RPCs with proper ratio")
+}
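For intuition about the tolerance used above, here is a minimal standalone sketch (an editor's illustration, not part of this change) of the same normalization checkWeights performs before comparing counts:

```go
package main

import "fmt"

// expectedCounts mirrors checkWeights' rescaling: integer weights are
// converted into approximate RPC counts out of a fixed number of iterations.
func expectedCounts(weights []int, iterations int) []int {
	sum := 0
	for _, w := range weights {
		sum += w
	}
	out := make([]int, len(weights))
	for i, w := range weights {
		// Integer division, exactly as in the test, so the counts may sum
		// to slightly less than iterations.
		out[i] = iterations * w / sum
	}
	return out
}

func main() {
	// With weights 1 and 10 over 100 RPCs, the test accepts roughly
	// 9 and 90 calls to the two backends, each +/- 2.
	fmt.Println(expectedCounts([]int{1, 10}, 100)) // [9 90]
}
```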
diff --git a/balancer/weightedroundrobin/config.go b/balancer/weightedroundrobin/config.go
new file mode 100644
index 00000000000..caad18faa11
--- /dev/null
+++ b/balancer/weightedroundrobin/config.go
@@ -0,0 +1,60 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package weightedroundrobin
+
+import (
+	"time"
+
+	"google.golang.org/grpc/serviceconfig"
+)
+
+type lbConfig struct {
+	serviceconfig.LoadBalancingConfig `json:"-"`
+
+	// Whether to enable out-of-band utilization reporting collection from the
+	// endpoints. By default, per-request utilization reporting is used.
+	EnableOOBLoadReport bool `json:"enableOobLoadReport,omitempty"`
+
+	// Load reporting interval to request from the server. Note that the
+	// server may not provide reports as frequently as the client requests.
+	// Used only when enable_oob_load_report is true. Default is 10 seconds.
+	OOBReportingPeriod time.Duration `json:"oobReportingPeriod,omitempty"`
+
+	// A given endpoint must report load metrics continuously for at least
+	// this long before the endpoint weight will be used. This avoids churn
+	// when the set of endpoint addresses changes. Takes effect both
+	// immediately after we establish a connection to an endpoint and after
+	// weight_expiration_period has caused us to stop using the most recent
+	// load metrics. Default is 10 seconds.
+	BlackoutPeriod time.Duration `json:"blackoutPeriod,omitempty"`
+
+	// If a given endpoint has not reported load metrics in this long, then we
+	// stop using the reported weight. This ensures that we do not continue
+	// to use very stale weights. Once we stop using a stale value, if we
+	// later start seeing fresh reports again, the blackout_period applies.
+	// Defaults to 3 minutes.
+	WeightExpirationPeriod time.Duration `json:"weightExpirationPeriod,omitempty"`
+
+	// How often endpoint weights are recalculated. Default is 1 second.
+	WeightUpdatePeriod time.Duration `json:"weightUpdatePeriod,omitempty"`
+
+	// The multiplier used to adjust endpoint weights with the error rate
+	// calculated as eps/qps. Default is 1.0.
+	ErrorUtilizationPenalty float64 `json:"errorUtilizationPenalty,omitempty"`
+}
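As a usage reference, below is a hypothetical service config selecting this policy; the outer shape matches what svcConfig builds in the tests, and the field names come from the struct tags above. One caveat worth noting: because the fields above are plain time.Duration values, durations in this JSON are integer nanoseconds (as the test helpers produce), not the "10s" string form used by some other service config fields.

```go
package main

import "fmt"

func main() {
	// Hypothetical service config for the weighted_round_robin_experimental
	// policy. 500000000 is 500ms expressed in nanoseconds, since the
	// lbConfig duration fields unmarshal as time.Duration (an int64).
	const sc = `{
  "loadBalancingConfig": [{
    "weighted_round_robin_experimental": {
      "enableOobLoadReport": true,
      "oobReportingPeriod": 500000000,
      "errorUtilizationPenalty": 1.5
    }
  }]
}`
	fmt.Println(sc) // e.g. pass to grpc.WithDefaultServiceConfig(sc)
}
```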
diff --git a/balancer/weightedroundrobin/internal/internal.go b/balancer/weightedroundrobin/internal/internal.go
new file mode 100644
index 00000000000..d39830261b2
--- /dev/null
+++ b/balancer/weightedroundrobin/internal/internal.go
@@ -0,0 +1,44 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package internal allows for easier testing of the weightedroundrobin
+// package.
+package internal
+
+import (
+	"time"
+)
+
+// AllowAnyWeightUpdatePeriod permits any setting of WeightUpdatePeriod for
+// testing. Normally a minimum of 100ms is applied.
+var AllowAnyWeightUpdatePeriod bool
+
+// LBConfig allows tests to produce a JSON form of the config from the struct
+// instead of using a string.
+type LBConfig struct {
+	EnableOOBLoadReport     *bool          `json:"enableOobLoadReport,omitempty"`
+	OOBReportingPeriod      *time.Duration `json:"oobReportingPeriod,omitempty"`
+	BlackoutPeriod          *time.Duration `json:"blackoutPeriod,omitempty"`
+	WeightExpirationPeriod  *time.Duration `json:"weightExpirationPeriod,omitempty"`
+	WeightUpdatePeriod      *time.Duration `json:"weightUpdatePeriod,omitempty"`
+	ErrorUtilizationPenalty *float64       `json:"errorUtilizationPenalty,omitempty"`
+}
+
+// TimeNow can be overridden by tests to return a different value for the
+// current time.
+var TimeNow = time.Now
diff --git a/balancer/weightedroundrobin/logging.go b/balancer/weightedroundrobin/logging.go
new file mode 100644
index 00000000000..43184ca9ab9
--- /dev/null
+++ b/balancer/weightedroundrobin/logging.go
@@ -0,0 +1,34 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package weightedroundrobin
+
+import (
+	"fmt"
+
+	"google.golang.org/grpc/grpclog"
+	internalgrpclog "google.golang.org/grpc/internal/grpclog"
+)
+
+const prefix = "[%p] "
+
+var logger = grpclog.Component("weighted-round-robin")
+
+func prefixLogger(p *wrrBalancer) *internalgrpclog.PrefixLogger {
+	return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p))
+}
diff --git a/balancer/weightedroundrobin/scheduler.go b/balancer/weightedroundrobin/scheduler.go
new file mode 100644
index 00000000000..e19428112e1
--- /dev/null
+++ b/balancer/weightedroundrobin/scheduler.go
@@ -0,0 +1,138 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package weightedroundrobin
+
+import (
+	"math"
+)
+
+type scheduler interface {
+	nextIndex() int
+}
+
+// newScheduler uses scWeights to create a new scheduler for selecting
+// subconns in a picker. It returns a round robin implementation if at least
+// len(scWeights)-1 of the weights are zero or there is only a single subconn;
+// otherwise it returns an Earliest Deadline First (EDF) scheduler
+// implementation that selects the subchannels according to their weights.
+func newScheduler(scWeights []float64, inc func() uint32) scheduler {
+	n := len(scWeights)
+	if n == 0 {
+		return nil
+	}
+	if n == 1 {
+		return &rrScheduler{numSCs: 1, inc: inc}
+	}
+	sum := float64(0)
+	numZero := 0
+	max := float64(0)
+	for _, w := range scWeights {
+		sum += w
+		if w > max {
+			max = w
+		}
+		if w == 0 {
+			numZero++
+		}
+	}
+	if numZero >= n-1 {
+		return &rrScheduler{numSCs: uint32(n), inc: inc}
+	}
+	unscaledMean := sum / float64(n-numZero)
+	scalingFactor := maxWeight / max
+	mean := uint16(math.Round(scalingFactor * unscaledMean))
+
+	weights := make([]uint16, n)
+	allEqual := true
+	for i, w := range scWeights {
+		if w == 0 {
+			// Backends with weight = 0 use the mean.
+			weights[i] = mean
+		} else {
+			scaledWeight := uint16(math.Round(scalingFactor * w))
+			weights[i] = scaledWeight
+			if scaledWeight != mean {
+				allEqual = false
+			}
+		}
+	}
+
+	if allEqual {
+		return &rrScheduler{numSCs: uint32(n), inc: inc}
+	}
+
+	logger.Infof("using edf scheduler with weights: %v", weights)
+	return &edfScheduler{weights: weights, inc: inc}
+}
+
+const maxWeight = math.MaxUint16
+
+// edfScheduler implements EDF using the same algorithm as grpc-c++ here:
+//
+//	https://github.com/grpc/grpc/blob/master/src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc
+type edfScheduler struct {
+	inc     func() uint32
+	weights []uint16
+}
+
+// Returns the index in s.weights for the picker to choose.
+func (s *edfScheduler) nextIndex() int {
+	const offset = maxWeight / 2
+
+	for {
+		idx := uint64(s.inc())
+
+		// The sequence number (idx) is split in two: the lower %n gives the
+		// index of the backend, and the rest gives the number of times we've
+		// iterated through all backends. `generation` is used to
+		// deterministically decide whether we pick or skip the backend on
+		// this iteration, in proportion to the backend's weight.
+
+		backendIndex := idx % uint64(len(s.weights))
+		generation := idx / uint64(len(s.weights))
+		weight := uint64(s.weights[backendIndex])
+
+		// We pick a backend `weight` times per `maxWeight` generations. The
+		// multiply and modulus ~evenly spread out the picks for a given
+		// backend between different generations. The offset by
+		// `backendIndex` helps to reduce the chance of multiple consecutive
+		// non-picks: if we have two consecutive backends with an equal, say,
+		// 80% weight of the max, with no offset we would see 1/5 generations
+		// that skipped both.
+		// TODO(b/190488683): add test for offset efficacy.
+		mod := uint64(weight*generation+backendIndex*offset) % maxWeight
+
+		if mod < maxWeight-weight {
+			continue
+		}
+		return int(backendIndex)
+	}
+}
+
+// A simple RR scheduler to use for fallback when fewer than two backends have
+// non-zero weights, or all backends have the same weight, or when only one
+// subconn exists.
+type rrScheduler struct {
+	inc    func() uint32
+	numSCs uint32
+}
+
+func (s *rrScheduler) nextIndex() int {
+	idx := s.inc()
+	return int(idx % s.numSCs)
+}
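To make the stride arithmetic above concrete, here is a small self-contained sketch (an editor's illustration) that inlines the same skip/pick formula as edfScheduler.nextIndex, with a local counter instead of the picker's shared atomic (which the real picker additionally seeds at a random value):

```go
package main

import "fmt"

const maxWeight = 65535 // math.MaxUint16, matching scheduler.go

// pickSequence applies the same skip/pick rule as edfScheduler.nextIndex.
func pickSequence(weights []uint16, picks int) []int {
	const offset = maxWeight / 2
	var out []int
	for v := uint64(0); len(out) < picks; v++ {
		backend := v % uint64(len(weights))
		generation := v / uint64(len(weights))
		weight := uint64(weights[backend])
		if (weight*generation+backend*offset)%maxWeight < maxWeight-weight {
			continue // this backend is skipped in this generation
		}
		out = append(out, int(backend))
	}
	return out
}

func main() {
	// Backend 1 has twice the weight of backend 0, so it appears twice as
	// often in the sequence: [1 0 1 1 0 1 1 0 1]
	fmt.Println(pickSequence([]uint16{32768, 65535}, 9))
}
```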
diff --git a/balancer/weightedroundrobin/weightedroundrobin.go b/balancer/weightedroundrobin/weightedroundrobin.go
index 6fc4d1910e6..bb029f07c36 100644
--- a/balancer/weightedroundrobin/weightedroundrobin.go
+++ b/balancer/weightedroundrobin/weightedroundrobin.go
@@ -16,16 +16,21 @@
  *
  */
 
-// Package weightedroundrobin defines a weighted roundrobin balancer.
+// Package weightedroundrobin provides an implementation of the weighted round
+// robin LB policy, as defined in [gRFC A58].
+//
+// # Experimental
+//
+// Notice: This package is EXPERIMENTAL and may be changed or removed in a
+// later release.
+//
+// [gRFC A58]: https://github.com/grpc/proposal/blob/master/A58-client-side-weighted-round-robin-lb-policy.md
 package weightedroundrobin
 
 import (
 	"google.golang.org/grpc/resolver"
 )
 
-// Name is the name of weighted_round_robin balancer.
-const Name = "weighted_round_robin"
-
 // attributeKey is the type used as the key to store AddrInfo in the
 // BalancerAttributes field of resolver.Address.
 type attributeKey struct{}
@@ -44,11 +49,6 @@ func (a AddrInfo) Equal(o interface{}) bool {
 
 // SetAddrInfo returns a copy of addr in which the BalancerAttributes field is
 // updated with addrInfo.
-//
-// # Experimental
-//
-// Notice: This API is EXPERIMENTAL and may be changed or removed in a
-// later release.
 func SetAddrInfo(addr resolver.Address, addrInfo AddrInfo) resolver.Address {
 	addr.BalancerAttributes = addr.BalancerAttributes.WithValue(attributeKey{}, addrInfo)
 	return addr
@@ -56,11 +56,6 @@ func SetAddrInfo(addr resolver.Address, addrInfo AddrInfo) resolver.Address {
 
 // GetAddrInfo returns the AddrInfo stored in the BalancerAttributes field of
 // addr.
-//
-// # Experimental
-//
-// Notice: This API is EXPERIMENTAL and may be changed or removed in a
-// later release.
 func GetAddrInfo(addr resolver.Address) AddrInfo {
 	v := addr.BalancerAttributes.Value(attributeKey{})
 	ai, _ := v.(AddrInfo)
diff --git a/internal/grpcrand/grpcrand.go b/internal/grpcrand/grpcrand.go
index 517ea70642a..0b092cfbe15 100644
--- a/internal/grpcrand/grpcrand.go
+++ b/internal/grpcrand/grpcrand.go
@@ -72,3 +72,10 @@ func Uint64() uint64 {
 	defer mu.Unlock()
 	return r.Uint64()
 }
+
+// Uint32 implements rand.Uint32 on the grpcrand global source.
+func Uint32() uint32 {
+	mu.Lock()
+	defer mu.Unlock()
+	return r.Uint32()
+}
diff --git a/orca/producer.go b/orca/producer.go
index 3b7ed8b67d8..ce108aad65c 100644
--- a/orca/producer.go
+++ b/orca/producer.go
@@ -199,12 +199,13 @@ func (p *producer) run(ctx context.Context, done chan struct{}, interval time.Du
 		// Unimplemented; do not retry.
 		logger.Error("Server doesn't support ORCA OOB load reporting protocol; not listening for load reports.")
 		return
-	case status.Code(err) == codes.Unavailable:
-		// TODO: this code should ideally log an error, too, but for now we
-		// receive this code when shutting down the ClientConn. Once we
-		// can determine the state or ensure the producer is stopped before
-		// the stream ends, we can log an error when it's not a natural
-		// shutdown.
+	case status.Code(err) == codes.Unavailable, status.Code(err) == codes.Canceled:
+		// TODO: these codes should ideally log an error, too, but for now
+		// we receive them when shutting down the ClientConn (Unavailable
+		// if the stream hasn't started yet, and Canceled if it happens
+		// mid-stream). Once we can determine the state or ensure the
+		// producer is stopped before the stream ends, we can log an error
+		// when it's not a natural shutdown.
 	default:
 		// Log all other errors.
logger.Error("Received unexpected stream error:", err) diff --git a/xds/internal/balancer/clusterimpl/picker.go b/xds/internal/balancer/clusterimpl/picker.go index 360fc44c9e4..3f354424f28 100644 --- a/xds/internal/balancer/clusterimpl/picker.go +++ b/xds/internal/balancer/clusterimpl/picker.go @@ -160,7 +160,7 @@ func (d *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { d.loadStore.CallFinished(lIDStr, info.Err) load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport) - if !ok { + if !ok || load == nil { return } d.loadStore.CallServerLoad(lIDStr, serverLoadCPUName, load.CpuUtilization)
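A note on the clusterimpl change above: the added `load == nil` check guards against a typed-nil pointer stored in the `info.ServerLoad` interface, which a type assertion alone does not catch. A minimal demonstration of the Go semantics involved, using a local stand-in type rather than the real OrcaLoadReport:

```go
package main

import "fmt"

type loadReport struct{ cpuUtilization float64 }

func main() {
	var typedNil *loadReport               // a nil *loadReport
	var serverLoad interface{} = typedNil  // interface now holds a typed nil

	load, ok := serverLoad.(*loadReport)
	fmt.Println(ok, load == nil) // true true: ok alone does not imply non-nil
	// Reading load.cpuUtilization here would panic with a nil dereference,
	// which is exactly what the `!ok || load == nil` check prevents.
}
```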