diff --git a/orca/producer.go b/orca/producer.go index df29a8176c70..de0c6c0ebc1a 100644 --- a/orca/producer.go +++ b/orca/producer.go @@ -19,6 +19,7 @@ package orca import ( "context" "sync" + "sync/atomic" "time" "google.golang.org/grpc" @@ -40,17 +41,27 @@ type producerBuilder struct{} func (*producerBuilder) Build(cci interface{}) (balancer.Producer, func()) { ctx, cancel := context.WithCancel(context.Background()) p := &producer{ - client: v3orcaservicegrpc.NewOpenRcaServiceClient(cci.(grpc.ClientConnInterface)), - closed: grpcsync.NewEvent(), - intervals: make(map[time.Duration]int), - listeners: make(map[OOBListener]struct{}), - backoff: internal.DefaultBackoffFunc, - hasIntervals: make(chan struct{}), + client: v3orcaservicegrpc.NewOpenRcaServiceClient(cci.(grpc.ClientConnInterface)), + closed: grpcsync.NewEvent(), + intervals: make(map[time.Duration]int), + listeners: make(map[OOBListener]struct{}), + backoff: internal.DefaultBackoffFunc, + initialized: 0, } + // Take the mutex at creation time. This will allow the first listener to + // be registered before the run goroutine calls minInterval. + p.mu.Lock() go p.run(ctx) return p, func() { + // Signal the run goroutine to exit. cancel() - <-p.closed.Done() // Block until stream stopped. + // The stop function that unregisters a listener takes the lock. Give + // it up to allow the run goroutine to exit. + p.mu.Unlock() + // Block until stream stopped. + <-p.closed.Done() + // Re-take the lock so the stop function can safely unlock it. + p.mu.Lock() } } @@ -78,7 +89,13 @@ type OOBListenerOptions struct { func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOptions) (stop func()) { pr, close := sc.GetOrBuildProducer(producerBuilderSingleton) p := pr.(*producer) - p.registerListener(l, opts.ReportInterval) + if initialized := atomic.SwapInt32(&p.initialized, 1); initialized != 0 { + // If we're still initializing, the mutex is held. Otherwise, take the + // mutex now. + p.mu.Lock() + } + p.registerListenerLocked(l, opts.ReportInterval) + p.mu.Unlock() // TODO: When we can register for SubConn state updates, automatically call // stop() on SHUTDOWN. @@ -86,8 +103,10 @@ func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOpt // If stop is called multiple times, prevent it from having any effect on // subsequent calls. return grpcsync.OnceFunc(func() { - p.unregisterListener(l, opts.ReportInterval) - close() + p.mu.Lock() + p.unregisterListenerLocked(l, opts.ReportInterval) + close() // If this is the final producer instance, stops run(). + p.mu.Unlock() }) } @@ -101,38 +120,28 @@ type producer struct { // reports a result. backoff func(int) time.Duration - mu sync.Mutex - intervals map[time.Duration]int // map from interval time to count of listeners requesting that time - listeners map[OOBListener]struct{} // set of registered listeners - hasIntervals chan struct{} // created when intervals is empty; closed and nilled when non-empty. + initialized int32 // Set to 1 when initialized. Accessed atomically. + + mu sync.Mutex + intervals map[time.Duration]int // map from interval time to count of listeners requesting that time + listeners map[OOBListener]struct{} // set of registered listeners } // registerListener adds the listener and its requested report interval to the -// producer. -func (p *producer) registerListener(l OOBListener, interval time.Duration) { - p.mu.Lock() - defer p.mu.Unlock() - if p.hasIntervals != nil { - close(p.hasIntervals) - p.hasIntervals = nil - } +// producer. p.mu must be held. +func (p *producer) registerListenerLocked(l OOBListener, interval time.Duration) { p.listeners[l] = struct{}{} p.intervals[interval]++ } // registerListener removes the listener and its requested report interval to -// the producer. -func (p *producer) unregisterListener(l OOBListener, interval time.Duration) { - p.mu.Lock() - defer p.mu.Unlock() +// the producer. p.mu must be held. +func (p *producer) unregisterListenerLocked(l OOBListener, interval time.Duration) { delete(p.listeners, l) p.intervals[interval]-- if p.intervals[interval] == 0 { delete(p.intervals, interval) } - if len(p.intervals) == 0 { - p.hasIntervals = make(chan struct{}) - } } // minInterval returns the smallest key in p.intervals. If p.intervals is @@ -140,29 +149,17 @@ func (p *producer) unregisterListener(l OOBListener, interval time.Duration) { // producer closure; it is the caller's duty to determine if 0 is a valid value // or if the producer was closed. func (p *producer) minInterval() time.Duration { - for !p.closed.HasFired() { - p.mu.Lock() - if len(p.intervals) == 0 { - ch := p.hasIntervals - p.mu.Unlock() - select { - case <-p.closed.Done(): - case <-ch: - } - continue - } - var min time.Duration - first := true - for t := range p.intervals { - if t < min || first { - min = t - first = false - } + p.mu.Lock() + defer p.mu.Unlock() + var min time.Duration + first := true + for t := range p.intervals { + if t < min || first { + min = t + first = false } - p.mu.Unlock() - return min } - return 0 + return min } // run manages the ORCA OOB stream on the subchannel. diff --git a/orca/service_test.go b/orca/service_test.go index 715d53241c71..89e4c38a8e72 100644 --- a/orca/service_test.go +++ b/orca/service_test.go @@ -73,13 +73,13 @@ func (t *testServiceImpl) EmptyCall(context.Context, *testpb.Empty) (*testpb.Emp return &testpb.Empty{}, nil } -// Test_E2E_CustomBackendMetrics_OutOfBand tests the injection of out-of-band +// TestE2E_CustomBackendMetrics_OutOfBand tests the injection of out-of-band // custom backend metrics from the server application, and verifies that // expected load reports are received at the client. // // TODO: Change this test to use the client API, when ready, to read the // out-of-band metrics pushed by the server. -func (s) Test_E2E_CustomBackendMetrics_OutOfBand(t *testing.T) { +func (s) TestE2E_CustomBackendMetrics_OutOfBand(t *testing.T) { lis, err := testutils.LocalTCPListener() if err != nil { t.Fatal(err) @@ -123,7 +123,7 @@ func (s) Test_E2E_CustomBackendMetrics_OutOfBand(t *testing.T) { errCh <- fmt.Errorf("UnaryCall failed: %v", err) return } - time.Sleep(10 * time.Millisecond) + time.Sleep(2 * time.Millisecond) } errCh <- nil }()