From 540eaf9399158508be7a0df54776e78f91d8811b Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Wed, 3 May 2023 11:06:28 -0700 Subject: [PATCH] take 4 --- orca/producer.go | 112 ++++++++++++++++++++---------------------- orca/producer_test.go | 12 ++--- orca/service_test.go | 15 +++--- 3 files changed, 66 insertions(+), 73 deletions(-) diff --git a/orca/producer.go b/orca/producer.go index df29a8176c70..227baeb01ddf 100644 --- a/orca/producer.go +++ b/orca/producer.go @@ -38,20 +38,13 @@ type producerBuilder struct{} // Build constructs and returns a producer and its cleanup function func (*producerBuilder) Build(cci interface{}) (balancer.Producer, func()) { - ctx, cancel := context.WithCancel(context.Background()) p := &producer{ - client: v3orcaservicegrpc.NewOpenRcaServiceClient(cci.(grpc.ClientConnInterface)), - closed: grpcsync.NewEvent(), - intervals: make(map[time.Duration]int), - listeners: make(map[OOBListener]struct{}), - backoff: internal.DefaultBackoffFunc, - hasIntervals: make(chan struct{}), - } - go p.run(ctx) - return p, func() { - cancel() - <-p.closed.Done() // Block until stream stopped. + client: v3orcaservicegrpc.NewOpenRcaServiceClient(cci.(grpc.ClientConnInterface)), + intervals: make(map[time.Duration]int), + listeners: make(map[OOBListener]struct{}), + backoff: internal.DefaultBackoffFunc, } + return p, func() {} } var producerBuilderSingleton = &producerBuilder{} @@ -78,6 +71,7 @@ type OOBListenerOptions struct { func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOptions) (stop func()) { pr, close := sc.GetOrBuildProducer(producerBuilderSingleton) p := pr.(*producer) + p.registerListener(l, opts.ReportInterval) // TODO: When we can register for SubConn state updates, automatically call @@ -94,17 +88,18 @@ func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOpt type producer struct { client v3orcaservicegrpc.OpenRcaServiceClient - closed *grpcsync.Event // fired when closure completes // backoff is called between stream attempts to determine how long to delay // to avoid overloading a server experiencing problems. The attempt count // is incremented when stream errors occur and is reset when the stream // reports a result. backoff func(int) time.Duration - mu sync.Mutex - intervals map[time.Duration]int // map from interval time to count of listeners requesting that time - listeners map[OOBListener]struct{} // set of registered listeners - hasIntervals chan struct{} // created when intervals is empty; closed and nilled when non-empty. + mu sync.Mutex + intervals map[time.Duration]int // map from interval time to count of listeners requesting that time + listeners map[OOBListener]struct{} // set of registered listeners + minInterval time.Duration + stop func() // stops the current run goroutine + stopped chan struct{} // closed when the run goroutine exits } // registerListener adds the listener and its requested report interval to the @@ -112,12 +107,13 @@ type producer struct { func (p *producer) registerListener(l OOBListener, interval time.Duration) { p.mu.Lock() defer p.mu.Unlock() - if p.hasIntervals != nil { - close(p.hasIntervals) - p.hasIntervals = nil - } + p.listeners[l] = struct{}{} p.intervals[interval]++ + if len(p.listeners) == 1 || interval < p.minInterval { + p.minInterval = interval + p.updateRunLocked() + } } // registerListener removes the listener and its requested report interval to @@ -125,49 +121,52 @@ func (p *producer) registerListener(l OOBListener, interval time.Duration) { func (p *producer) unregisterListener(l OOBListener, interval time.Duration) { p.mu.Lock() defer p.mu.Unlock() + delete(p.listeners, l) p.intervals[interval]-- if p.intervals[interval] == 0 { delete(p.intervals, interval) - } - if len(p.intervals) == 0 { - p.hasIntervals = make(chan struct{}) + + if p.minInterval == interval { + p.recomputeMinInterval() + p.updateRunLocked() + } } } -// minInterval returns the smallest key in p.intervals. If p.intervals is -// empty, blocks until it is non-empty or the producer is closed. Returns 0 on -// producer closure; it is the caller's duty to determine if 0 is a valid value -// or if the producer was closed. -func (p *producer) minInterval() time.Duration { - for !p.closed.HasFired() { - p.mu.Lock() - if len(p.intervals) == 0 { - ch := p.hasIntervals - p.mu.Unlock() - select { - case <-p.closed.Done(): - case <-ch: - } - continue - } - var min time.Duration - first := true - for t := range p.intervals { - if t < min || first { - min = t - first = false - } +// recomputeMinInterval sets p.minInterval to the minimum key's value in +// p.intervals. +func (p *producer) recomputeMinInterval() { + first := true + for interval := range p.intervals { + if first || interval < p.minInterval { + p.minInterval = interval + first = false } - p.mu.Unlock() - return min } - return 0 +} + +// updateRunLocked is called whenever the run goroutine needs to be started / +// stopped / restarted due to: 1. the initial listener being registered, 2. the +// final listener being unregistered, or 3. the minimum registered interval +// changing. +func (p *producer) updateRunLocked() { + if p.stop != nil { + p.stop() + <-p.stopped + p.stop = nil + } + if len(p.listeners) > 0 { + var ctx context.Context + ctx, p.stop = context.WithCancel(context.Background()) + p.stopped = make(chan struct{}) + go p.run(ctx, p.minInterval) + } } // run manages the ORCA OOB stream on the subchannel. -func (p *producer) run(ctx context.Context) { - defer p.closed.Fire() +func (p *producer) run(ctx context.Context, interval time.Duration) { + defer close(p.stopped) backoffAttempt := 0 backoffTimer := time.NewTimer(0) @@ -178,7 +177,7 @@ func (p *producer) run(ctx context.Context) { return } - resetBackoff, err := p.runStream(ctx) + resetBackoff, err := p.runStream(ctx, interval) if resetBackoff { backoffTimer.Reset(0) @@ -215,8 +214,7 @@ func (p *producer) run(ctx context.Context) { // runStream runs a single stream on the subchannel and returns the resulting // error, if any, and whether or not the run loop should reset the backoff // timer to zero or advance it. -func (p *producer) runStream(ctx context.Context) (resetBackoff bool, err error) { - interval := p.minInterval() +func (p *producer) runStream(ctx context.Context, interval time.Duration) (resetBackoff bool, err error) { streamCtx, cancel := context.WithCancel(ctx) defer cancel() stream, err := p.client.StreamCoreMetrics(streamCtx, &v3orcaservicepb.OrcaLoadReportRequest{ @@ -237,9 +235,5 @@ func (p *producer) runStream(ctx context.Context) (resetBackoff bool, err error) l.OnLoadReport(report) } p.mu.Unlock() - if interval != p.minInterval() { - // restart stream to use new interval - return true, nil - } } } diff --git a/orca/producer_test.go b/orca/producer_test.go index f15317995dec..054698d0f38d 100644 --- a/orca/producer_test.go +++ b/orca/producer_test.go @@ -519,12 +519,11 @@ func (s) TestProducerMultipleListeners(t *testing.T) { checkReports(2, 1, 0) // Register listener 3 with a more frequent interval; stream is recreated - // with this interval after the next report is received. The first report - // will go to all three listeners. + // with this interval. The next report will go to all three listeners. oobLis3.cleanup = orca.RegisterOOBListener(li.sc, oobLis3, lisOpts3) + awaitRequest(reportInterval3) fake.respCh <- loadReportWant checkReports(3, 2, 1) - awaitRequest(reportInterval3) // Another report without a change in listeners should go to all three listeners. fake.respCh <- loadReportWant @@ -536,13 +535,12 @@ func (s) TestProducerMultipleListeners(t *testing.T) { fake.respCh <- loadReportWant checkReports(5, 3, 3) - // Stop listener 3. This makes the interval longer, with stream recreation - // delayed until the next report is received. Reports should only go to - // listener 1 now. + // Stop listener 3. This makes the interval longer. Reports should only + // go to listener 1 now. oobLis3.Stop() + awaitRequest(reportInterval1) fake.respCh <- loadReportWant checkReports(6, 3, 3) - awaitRequest(reportInterval1) // Another report without a change in listeners should go to the first listener. fake.respCh <- loadReportWant checkReports(7, 3, 3) diff --git a/orca/service_test.go b/orca/service_test.go index 715d53241c71..a95033de3d42 100644 --- a/orca/service_test.go +++ b/orca/service_test.go @@ -73,20 +73,20 @@ func (t *testServiceImpl) EmptyCall(context.Context, *testpb.Empty) (*testpb.Emp return &testpb.Empty{}, nil } -// Test_E2E_CustomBackendMetrics_OutOfBand tests the injection of out-of-band +// TestE2E_CustomBackendMetrics_OutOfBand tests the injection of out-of-band // custom backend metrics from the server application, and verifies that // expected load reports are received at the client. // // TODO: Change this test to use the client API, when ready, to read the // out-of-band metrics pushed by the server. -func (s) Test_E2E_CustomBackendMetrics_OutOfBand(t *testing.T) { +func (s) TestE2E_CustomBackendMetrics_OutOfBand(t *testing.T) { lis, err := testutils.LocalTCPListener() if err != nil { t.Fatal(err) } // Override the min reporting interval in the internal package. - const shortReportingInterval = 100 * time.Millisecond + const shortReportingInterval = 10 * time.Millisecond opts := orca.ServiceOptions{MinReportingInterval: shortReportingInterval} internal.AllowAnyMinReportingInterval.(func(*orca.ServiceOptions))(&opts) @@ -110,20 +110,21 @@ func (s) Test_E2E_CustomBackendMetrics_OutOfBand(t *testing.T) { } defer cc.Close() - // Spawn a goroutine which sends 100 unary RPCs to the test server. This + // Spawn a goroutine which sends 20 unary RPCs to the test server. This // will trigger the injection of custom backend metrics from the // testServiceImpl. + const numRequests = 20 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() testStub := testgrpc.NewTestServiceClient(cc) errCh := make(chan error, 1) go func() { - for i := 0; i < 100; i++ { + for i := 0; i < numRequests; i++ { if _, err := testStub.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { errCh <- fmt.Errorf("UnaryCall failed: %v", err) return } - time.Sleep(10 * time.Millisecond) + time.Sleep(time.Millisecond) } errCh <- nil }() @@ -151,7 +152,7 @@ func (s) Test_E2E_CustomBackendMetrics_OutOfBand(t *testing.T) { wantProto := &v3orcapb.OrcaLoadReport{ CpuUtilization: 50.0, MemUtilization: 99.0, - Utilization: map[string]float64{requestsMetricKey: 100.0}, + Utilization: map[string]float64{requestsMetricKey: numRequests}, } gotProto, err := stream.Recv() if err != nil {