orca: fix race at producer startup #6245

Merged: 2 commits, May 3, 2023

77 changes: 48 additions & 29 deletions orca/producer.go
@@ -38,19 +38,13 @@ type producerBuilder struct{}
 
 // Build constructs and returns a producer and its cleanup function
 func (*producerBuilder) Build(cci interface{}) (balancer.Producer, func()) {
-	ctx, cancel := context.WithCancel(context.Background())
 	p := &producer{
 		client:    v3orcaservicegrpc.NewOpenRcaServiceClient(cci.(grpc.ClientConnInterface)),
-		closed:    grpcsync.NewEvent(),
 		intervals: make(map[time.Duration]int),
 		listeners: make(map[OOBListener]struct{}),
 		backoff:   internal.DefaultBackoffFunc,
 	}
-	go p.run(ctx)
-	return p, func() {
-		cancel()
-		<-p.closed.Done() // Block until stream stopped.
-	}
+	return p, func() {}
 }
 
 var producerBuilderSingleton = &producerBuilder{}
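
A note on the bug itself: before this change, Build launched the run goroutine at construction time, before RegisterOOBListener had registered the first listener, so the stream could be created using whatever minInterval() returned for an empty interval set, i.e. the zero duration. That reading of the race is inferred from the diff rather than stated in it. Below is a minimal, self-contained sketch of the racy startup pattern; it is illustrative only, and none of its names are grpc-go code.

package main

import (
	"fmt"
	"sync"
	"time"
)

// eagerProducer is an illustrative stand-in for the old producer, which
// started its run goroutine at construction time.
type eagerProducer struct {
	mu        sync.Mutex
	intervals map[time.Duration]int
}

// minInterval mirrors the old producer.minInterval: the smallest registered
// interval, which is the zero duration when nothing has registered yet.
func (p *eagerProducer) minInterval() time.Duration {
	p.mu.Lock()
	defer p.mu.Unlock()
	var min time.Duration
	first := true
	for t := range p.intervals {
		if first || t < min {
			min = t
			first = false
		}
	}
	return min
}

func main() {
	p := &eagerProducer{intervals: make(map[time.Duration]int)}
	// Old design: the stream goroutine starts at construction time...
	go func() {
		fmt.Println("stream would start with interval:", p.minInterval())
	}()
	// ...and can read minInterval before this registration lands, so the
	// printed interval is 0s or 1s depending on scheduling.
	p.mu.Lock()
	p.intervals[time.Second]++
	p.mu.Unlock()
	time.Sleep(10 * time.Millisecond) // let the goroutine print before exit
}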
@@ -77,6 +71,7 @@ type OOBListenerOptions struct {
 func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOptions) (stop func()) {
 	pr, close := sc.GetOrBuildProducer(producerBuilderSingleton)
 	p := pr.(*producer)
+
 	p.registerListener(l, opts.ReportInterval)
 
 	// TODO: When we can register for SubConn state updates, automatically call
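
For context, here is a hedged usage sketch of this API from an LB policy's point of view. loadListener and watchSubConn are hypothetical names; RegisterOOBListener, OOBListenerOptions, and OnLoadReport are the real surface shown in this diff.

package example

import (
	"log"
	"time"

	v3orcapb "github.com/cncf/xds/go/xds/data/orca/v3"
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/orca"
)

// loadListener is a hypothetical OOBListener implementation; the interface
// requires only OnLoadReport.
type loadListener struct{}

func (loadListener) OnLoadReport(r *v3orcapb.OrcaLoadReport) {
	log.Printf("backend CPU utilization: %v", r.GetCpuUtilization())
}

// watchSubConn registers the listener on a SubConn. The returned stop
// function unregisters it, and with this PR the OOB stream itself is torn
// down once the last listener is gone.
func watchSubConn(sc balancer.SubConn) (stop func()) {
	return orca.RegisterOOBListener(sc, loadListener{}, orca.OOBListenerOptions{
		ReportInterval: 10 * time.Second,
	})
}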
@@ -93,57 +88,86 @@ func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOpt
 type producer struct {
 	client v3orcaservicegrpc.OpenRcaServiceClient
 
-	closed *grpcsync.Event // fired when closure completes
 	// backoff is called between stream attempts to determine how long to delay
 	// to avoid overloading a server experiencing problems. The attempt count
 	// is incremented when stream errors occur and is reset when the stream
 	// reports a result.
 	backoff func(int) time.Duration
 
-	mu        sync.Mutex
-	intervals map[time.Duration]int    // map from interval time to count of listeners requesting that time
-	listeners map[OOBListener]struct{} // set of registered listeners
+	mu          sync.Mutex
+	intervals   map[time.Duration]int    // map from interval time to count of listeners requesting that time
+	listeners   map[OOBListener]struct{} // set of registered listeners
+	minInterval time.Duration
+	stop        func()        // stops the current run goroutine
+	stopped     chan struct{} // closed when the run goroutine exits
 }
 
 // registerListener adds the listener and its requested report interval to the
 // producer.
 func (p *producer) registerListener(l OOBListener, interval time.Duration) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
+
 	p.listeners[l] = struct{}{}
 	p.intervals[interval]++
+	if len(p.listeners) == 1 || interval < p.minInterval {
+		p.minInterval = interval
+		p.updateRunLocked()
+	}
 }
 
 // unregisterListener removes the listener and its requested report interval
 // from the producer.
 func (p *producer) unregisterListener(l OOBListener, interval time.Duration) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
+
 	delete(p.listeners, l)
 	p.intervals[interval]--
 	if p.intervals[interval] == 0 {
 		delete(p.intervals, interval)
+
+		if p.minInterval == interval {
+			p.recomputeMinInterval()
+			p.updateRunLocked()
+		}
 	}
 }
 
-// minInterval returns the smallest key in p.intervals.
-func (p *producer) minInterval() time.Duration {
-	p.mu.Lock()
-	defer p.mu.Unlock()
-	var min time.Duration
+// recomputeMinInterval sets p.minInterval to the smallest key in
+// p.intervals.
+func (p *producer) recomputeMinInterval() {
 	first := true
-	for t := range p.intervals {
-		if t < min || first {
-			min = t
+	for interval := range p.intervals {
+		if first || interval < p.minInterval {
+			p.minInterval = interval
 			first = false
 		}
 	}
-	return min
 }

[Review comment, Contributor] Nit: s/recomputeMinInterval/recomputeMinIntervalLocked/ ?

+// updateRunLocked is called whenever the run goroutine needs to be started /
+// stopped / restarted due to: 1. the initial listener being registered, 2. the
+// final listener being unregistered, or 3. the minimum registered interval
+// changing.
+func (p *producer) updateRunLocked() {
+	if p.stop != nil {
+		p.stop()
+		<-p.stopped
+		p.stop = nil
+	}
+	if len(p.listeners) > 0 {
+		var ctx context.Context
+		ctx, p.stop = context.WithCancel(context.Background())
+		p.stopped = make(chan struct{})
+		go p.run(ctx, p.minInterval)
+	}
+}
+
 // run manages the ORCA OOB stream on the subchannel.
-func (p *producer) run(ctx context.Context) {
-	defer p.closed.Fire()
+func (p *producer) run(ctx context.Context, interval time.Duration) {
+	defer close(p.stopped)
 
 	backoffAttempt := 0
 	backoffTimer := time.NewTimer(0)
 	for ctx.Err() == nil {
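
The stop-then-wait sequence in updateRunLocked is the core of the new lifecycle: cancel the current goroutine's context, block on the stopped channel until the goroutine has actually exited, and only then start a successor, so two run goroutines can never overlap. Here is a self-contained sketch of that idiom with the ORCA specifics stripped out; all names are illustrative.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// runner is an illustrative reduction of the producer's stop/stopped pair;
// none of these names are grpc-go API.
type runner struct {
	mu      sync.Mutex
	stop    context.CancelFunc
	stopped chan struct{}
}

// restart stops any current goroutine, waits for it to fully exit, and only
// then starts a replacement, so two instances of work can never overlap.
func (r *runner) restart(work func(ctx context.Context)) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.stop != nil {
		r.stop()    // signal the old goroutine to exit...
		<-r.stopped // ...and block until it actually has.
		r.stop = nil
	}
	var ctx context.Context
	ctx, r.stop = context.WithCancel(context.Background())
	done := make(chan struct{})
	r.stopped = done
	go func() {
		defer close(done)
		work(ctx)
	}()
}

func main() {
	var r runner
	r.restart(func(ctx context.Context) {
		<-ctx.Done()
		fmt.Println("first run exited")
	})
	// This second restart cannot begin until the first goroutine has exited.
	r.restart(func(ctx context.Context) { <-ctx.Done() })
	time.Sleep(10 * time.Millisecond)
}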
@@ -153,7 +177,7 @@ func (p *producer) run(ctx context.Context) {
 			return
 		}
 
-		resetBackoff, err := p.runStream(ctx)
+		resetBackoff, err := p.runStream(ctx, interval)
 
 		if resetBackoff {
 			backoffTimer.Reset(0)
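
The backoff field consulted by this loop has the signature func(int) time.Duration (see the struct above). grpc-go wires in internal.DefaultBackoffFunc; the sketch below only illustrates the general shape such a function takes (doubling with a cap, no jitter) and is not that implementation.

package example

import "time"

// backoff illustrates the shape the producer's backoff field expects: given
// the number of consecutive failed attempts, return how long to wait before
// the next one.
func backoff(attempt int) time.Duration {
	const base, ceiling = time.Second, 2 * time.Minute
	d := base
	for i := 0; i < attempt; i++ {
		d *= 2
		if d >= ceiling {
			return ceiling
		}
	}
	return d
}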
@@ -190,8 +214,7 @@ func (p *producer) run(ctx context.Context) {
 // runStream runs a single stream on the subchannel and returns the resulting
 // error, if any, and whether or not the run loop should reset the backoff
 // timer to zero or advance it.
-func (p *producer) runStream(ctx context.Context) (resetBackoff bool, err error) {
-	interval := p.minInterval()
+func (p *producer) runStream(ctx context.Context, interval time.Duration) (resetBackoff bool, err error) {
 	streamCtx, cancel := context.WithCancel(ctx)
 	defer cancel()
 	stream, err := p.client.StreamCoreMetrics(streamCtx, &v3orcaservicepb.OrcaLoadReportRequest{
@@ -212,9 +235,5 @@ func (p *producer) runStream(ctx context.Context) (resetBackoff bool, err error)
 			l.OnLoadReport(report)
 		}
 		p.mu.Unlock()
-		if interval != p.minInterval() {
-			// restart stream to use new interval
-			return true, nil
-		}
 	}
 }
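
Net effect on runStream: the report interval is now fixed for the lifetime of a single run invocation and arrives as a parameter, so the per-report poll of p.minInterval() removed above is no longer needed. When the minimum registered interval changes, registerListener and unregisterListener call updateRunLocked, which restarts the goroutine with the new interval. That shift, from restarting on the next report to restarting at registration time, is what the test updates below reflect.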
12 changes: 5 additions & 7 deletions orca/producer_test.go
@@ -519,12 +519,11 @@ func (s) TestProducerMultipleListeners(t *testing.T) {
 	checkReports(2, 1, 0)
 
 	// Register listener 3 with a more frequent interval; stream is recreated
-	// with this interval after the next report is received. The first report
-	// will go to all three listeners.
+	// with this interval. The next report will go to all three listeners.
 	oobLis3.cleanup = orca.RegisterOOBListener(li.sc, oobLis3, lisOpts3)
+	awaitRequest(reportInterval3)
 	fake.respCh <- loadReportWant
 	checkReports(3, 2, 1)
-	awaitRequest(reportInterval3)
 
 	// Another report without a change in listeners should go to all three listeners.
 	fake.respCh <- loadReportWant
@@ -536,13 +535,12 @@ func (s) TestProducerMultipleListeners(t *testing.T) {
 	fake.respCh <- loadReportWant
 	checkReports(5, 3, 3)
 
-	// Stop listener 3. This makes the interval longer, with stream recreation
-	// delayed until the next report is received. Reports should only go to
-	// listener 1 now.
+	// Stop listener 3. This makes the interval longer. Reports should only
+	// go to listener 1 now.
 	oobLis3.Stop()
+	awaitRequest(reportInterval1)
 	fake.respCh <- loadReportWant
 	checkReports(6, 3, 3)
-	awaitRequest(reportInterval1)
 	// Another report without a change in listeners should go to the first listener.
 	fake.respCh <- loadReportWant
 	checkReports(7, 3, 3)
11 changes: 6 additions & 5 deletions orca/service_test.go
@@ -86,7 +86,7 @@ func (s) TestE2E_CustomBackendMetrics_OutOfBand(t *testing.T) {
 	}
 
 	// Override the min reporting interval in the internal package.
-	const shortReportingInterval = 100 * time.Millisecond
+	const shortReportingInterval = 10 * time.Millisecond
 	smr := orca.NewServerMetricsRecorder()
 	opts := orca.ServiceOptions{MinReportingInterval: shortReportingInterval, ServerMetricsProvider: smr}
 	internal.AllowAnyMinReportingInterval.(func(*orca.ServiceOptions))(&opts)
@@ -110,20 +110,21 @@ func (s) TestE2E_CustomBackendMetrics_OutOfBand(t *testing.T) {
 	}
 	defer cc.Close()
 
-	// Spawn a goroutine which sends 100 unary RPCs to the test server. This
+	// Spawn a goroutine which sends 20 unary RPCs to the test server. This
 	// will trigger the injection of custom backend metrics from the
 	// testServiceImpl.
+	const numRequests = 20
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
 	testStub := testgrpc.NewTestServiceClient(cc)
 	errCh := make(chan error, 1)
 	go func() {
-		for i := 0; i < 100; i++ {
+		for i := 0; i < numRequests; i++ {
 			if _, err := testStub.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
 				errCh <- fmt.Errorf("UnaryCall failed: %v", err)
 				return
 			}
-			time.Sleep(10 * time.Millisecond)
+			time.Sleep(time.Millisecond)
 		}
 		errCh <- nil
 	}()
@@ -151,7 +152,7 @@ func (s) TestE2E_CustomBackendMetrics_OutOfBand(t *testing.T) {
 	wantProto := &v3orcapb.OrcaLoadReport{
 		CpuUtilization: 50.0,
 		MemUtilization: 99.0,
-		Utilization:    map[string]float64{requestsMetricKey: 100.0},
+		Utilization:    map[string]float64{requestsMetricKey: numRequests},
 	}
 	gotProto, err := stream.Recv()
 	if err != nil {