Fix Drain() infinite loop and add test for concurrent Next() calls #1525

Merged · 6 commits · Jan 15, 2024
Changes from 3 commits
46 changes: 28 additions & 18 deletions jetstream/pull.go
@@ -116,7 +116,6 @@ type (
closed uint32
draining uint32
done chan struct{}
drained chan struct{}
connStatusChanged chan nats.Status
fetchNext chan *pullRequest
consumeOpts *consumeOpts
@@ -476,7 +475,6 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error
id: consumeID,
consumer: p,
done: make(chan struct{}, 1),
drained: make(chan struct{}, 1),
msgs: msgs,
errs: make(chan error, 1),
fetchNext: make(chan *pullRequest, 1),
@@ -537,69 +535,79 @@ var (
)

func (s *pullSubscription) Next() (Msg, error) {
s.Lock()
Collaborator:

I don't think removing the lock from the entire Next() execution is a good idea. For example, it can cause problems with the StopAfter option, where executing Next() concurrently may deliver more messages to the caller than StopAfter specifies.

Holding the lock for the duration of Next() is challenging, and you're absolutely right that we need a way to unlock it for cleanup - therefore, I believe catching the closure of s.done is necessary. Based on your branch I came up with a different solution (it's a bit crude right now, as I just wanted to give an example; this would have to be cleaned up a bit): https://github.com/nats-io/nats.go/blob/fix-drain-in-messages/jetstream/pull.go#L537

Here's the gist of it:

  1. When s.done is closed, we unlock the mutex so that the subscription can be cleaned up properly and s.msgs can be closed. We need a way to conditionally unlock the mutex in the defer, hence the done bool (I really don't like that...).

  2. If we detect we are draining, we set done to true and continue. Subsequent iterations of the loop check the state of done and, if it's set, use a select statement that does not listen on s.done. The two select statements are identical except for the case <-s.done.

I extracted handleIncomingMessage() and handleError() methods to make it a bit more readable, but for now they're just copied from the select.
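For illustration, here is a minimal, self-contained sketch of that pattern - not the actual jetstream code, just its shape (the sub type, field names and error text are made up for the example):

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// sub is a stand-in for pullSubscription: msgs delivers messages,
// done signals that the iterator is being stopped or drained.
type sub struct {
	sync.Mutex
	msgs chan string
	done chan struct{}
}

var errClosed = errors.New("iterator closed")

func (s *sub) next() (string, error) {
	s.Lock()
	done := false
	defer func() {
		// Conditionally unlock: once we observed s.done we already
		// released the lock so cleanup could close s.msgs.
		if !done {
			s.Unlock()
		}
	}()
	for {
		if done {
			// Drain mode: no case for <-s.done, just consume what is left.
			msg, ok := <-s.msgs
			if !ok {
				return "", errClosed
			}
			return msg, nil
		}
		select {
		case <-s.done:
			// Release the lock so cleanup can proceed, then keep looping
			// in drain mode until s.msgs is closed.
			done = true
			s.Unlock()
		case msg, ok := <-s.msgs:
			if !ok {
				return "", errClosed
			}
			return msg, nil
		}
	}
}

func main() {
	s := &sub{msgs: make(chan string, 2), done: make(chan struct{})}
	s.msgs <- "a"
	s.msgs <- "b"
	close(s.done) // signal drain
	close(s.msgs) // cleanup closes the message channel
	for {
		m, err := s.next()
		if err != nil {
			fmt.Println("stopped:", err)
			return
		}
		fmt.Println("received:", m)
	}
}
```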

Collaborator:

Alternatively, as you mentioned, we could use a separate lock just to make sure Next() cannot be executed concurrently.
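A sketch of that alternative, with field names that are assumptions rather than the actual nats.go layout: a dedicated nextMu serializes Next() calls, while the state mutex is only held for short critical sections and stays available to cleanup paths.

```go
package sketch

import "sync"

// Illustrative only - not the real pullSubscription.
type pullSub struct {
	sync.Mutex           // guards shared state (counters, pending, ...)
	nextMu    sync.Mutex // serializes Next() so it cannot run concurrently
	msgs      chan string
	delivered int
}

func (s *pullSub) Next() (string, bool) {
	s.nextMu.Lock()
	defer s.nextMu.Unlock()

	// The blocking receive happens without holding the state mutex,
	// so cleanup can still take s.Lock() and close s.msgs.
	msg, ok := <-s.msgs
	if !ok {
		return "", false
	}

	// Touch shared counters only under the short-lived state lock.
	s.Lock()
	s.delivered++
	s.Unlock()
	return msg, true
}
```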

Contributor (author):

Sorry about that, I overlooked the StopAfter option. I think we should add a test to verify this behavior first, before we move on - the current test calls Next() sequentially. Once that test is in place we should be able to refactor the code without breaking things.

There are also more elegant solutions:

  1. Use a dedicated lock for the subscription (used in cleanup()) and another lock for the counter fields, or maybe for the rest of the fields.
  2. Use atomic values for the counter fields (a sketch of this follows below).

Which solution do you prefer?
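To make option 2 concrete, a small sketch of counters kept as atomics - names and fields are illustrative, not a proposal for the exact layout in pull.go:

```go
package sketch

import "sync/atomic"

// deliveryCounter tracks delivered messages with atomics so that
// concurrent Next() calls can update it without the subscription mutex.
type deliveryCounter struct {
	delivered uint64
	stopAfter uint64 // 0 means no limit
}

// reserve returns true if the caller may deliver one more message.
// Even with concurrent callers, at most stopAfter reservations succeed.
func (c *deliveryCounter) reserve() bool {
	limit := atomic.LoadUint64(&c.stopAfter)
	n := atomic.AddUint64(&c.delivered, 1)
	return limit == 0 || n <= limit
}
```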

Collaborator:

I think using a separate lock for accessing the subscription would be the preferable solution, since I would be hesitant to allow concurrent Next() calls - for concurrency, the suggested approach is to create a whole new MessagesContext() for the same consumer. A separate lock sounds like it could actually simplify some things, though, so that's nice.

Do you have time and would you like to tackle this? Or should I take over?

Contributor (author):

@piotrpio yes, you're right, this is a better solution.
I started working on a test that verifies concurrent Next() calls, so that any later changes won't break this behavior.
I'll see what I can do - feel free to do whatever is best on your end as well.
I'll keep you updated with what I come up with.
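For reference, a rough sketch of what such a test could look like, following the setup already used elsewhere in pull_test.go. The use of jetstream.StopAfter with Messages(), the concrete numbers, and the timing are assumptions; the sketch also needs the sync and sync/atomic imports in addition to those already in the file.

```go
t.Run("with concurrent Next calls and StopAfter", func(t *testing.T) {
	srv := RunBasicJetStreamServer()
	defer shutdownJSServerAndRemoveStorage(t, srv)

	nc, err := nats.Connect(srv.ClientURL())
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	defer nc.Close()

	js, err := jetstream.New(nc)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	stopAfter := 3
	it, err := c.Messages(jetstream.StopAfter(stopAfter))
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	defer it.Stop()

	publishTestMsgs(t, nc)

	// Several goroutines call Next() concurrently; the total number of
	// delivered messages must not exceed StopAfter.
	var wg sync.WaitGroup
	var delivered int64
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				msg, err := it.Next()
				if err != nil {
					return
				}
				msg.Ack()
				atomic.AddInt64(&delivered, 1)
			}
		}()
	}
	wg.Wait()

	if delivered != int64(stopAfter) {
		t.Fatalf("Expected %d delivered messages; got %d", stopAfter, delivered)
	}
})
```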

Contributor (author):

I think that in cleanup() there's no risk of race conditions, so it should be OK without holding the lock.
It only reads the subscription field, which is set by the methods that create the pullSubscription struct (pullConsumer.Consume, pullConsumer.Messages and pullConsumer.fetch), and the underlying Subscription has its own mutex.

But I don't know whether this is acceptable with respect to future changes to the code that might introduce race conditions.

Collaborator:

If it's OK right now (and it looks like it is) and it does not produce a race, I think we can go without the lock - if we need a locking mechanism in the future we can always add it. Just please (if you're working on it) add an appropriate comment on why the lock is not needed.

Thank you again for your contribution, it's extremely valuable.

Contributor (author):

OK, I will add the comment right now.
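Something along these lines - the cleanup() body below is abbreviated and partly assumed; only the comment is the point:

```go
// cleanup is called without holding s.Lock(). That is safe today because it
// only reads s.subscription, which is assigned once when the pullSubscription
// is created (pullConsumer.Consume, pullConsumer.Messages, pullConsumer.fetch)
// and never reassigned, and *nats.Subscription has its own internal mutex.
// Revisit this if cleanup ever starts touching other pullSubscription fields.
func (s *pullSubscription) cleanup() {
	if s.subscription == nil || !s.subscription.IsValid() {
		return
	}
	// ... unsubscribe and clear consumer-side state ...
	s.subscription.Unsubscribe()
}
```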

defer s.Unlock()
drainMode := atomic.LoadUint32(&s.draining) == 1
closed := atomic.LoadUint32(&s.closed) == 1
if closed && !drainMode {
return nil, ErrMsgIteratorClosed
}
hbMonitor := s.scheduleHeartbeatCheck(2 * s.consumeOpts.Heartbeat)

s.Lock()
consumeOpts := *s.consumeOpts
delivered := s.delivered
s.Unlock()

hbMonitor := s.scheduleHeartbeatCheck(2 * consumeOpts.Heartbeat)
defer func() {
if hbMonitor != nil {
hbMonitor.Stop()
}
}()

isConnected := true
if s.consumeOpts.StopAfter > 0 && s.delivered >= s.consumeOpts.StopAfter {
if consumeOpts.StopAfter > 0 && delivered >= consumeOpts.StopAfter {
s.Stop()
return nil, ErrMsgIteratorClosed
}

for {
s.Lock()
s.checkPending()
s.Unlock()

select {
case <-s.done:
drainMode := atomic.LoadUint32(&s.draining) == 1
if drainMode {
continue
}
return nil, ErrMsgIteratorClosed
case msg, ok := <-s.msgs:
if !ok {
// if msgs channel is closed, it means that subscription was either drained or stopped
s.Lock()
delete(s.consumer.subscriptions, s.id)
s.Unlock()
atomic.CompareAndSwapUint32(&s.draining, 1, 0)
return nil, ErrMsgIteratorClosed
}
if hbMonitor != nil {
hbMonitor.Reset(2 * s.consumeOpts.Heartbeat)
hbMonitor.Reset(2 * consumeOpts.Heartbeat)
}
userMsg, msgErr := checkMsg(msg)
if !userMsg {
// heartbeat message
if msgErr == nil {
continue
}
if err := s.handleStatusMsg(msg, msgErr); err != nil {
s.Lock()
err := s.handleStatusMsg(msg, msgErr)
s.Unlock()
if err != nil {
s.Stop()
return nil, err
}
continue
}
s.Lock()
s.decrementPendingMsgs(msg)
s.incrementDeliveredMsgs()
s.Unlock()
return s.consumer.jetStream.toJSMsg(msg), nil
case err := <-s.errs:
if errors.Is(err, ErrNoHeartbeat) {
s.Lock()
s.pending.msgCount = 0
s.pending.byteCount = 0
if s.consumeOpts.ReportMissingHeartbeats {
s.Unlock()
if consumeOpts.ReportMissingHeartbeats {
return nil, err
}
if hbMonitor != nil {
hbMonitor.Reset(2 * s.consumeOpts.Heartbeat)
hbMonitor.Reset(2 * consumeOpts.Heartbeat)
}
}
if errors.Is(err, errConnected) {
@@ -638,16 +646,18 @@ func (s *pullSubscription) Next() (Msg, error) {
return nil, err
}

s.Lock()
s.pending.msgCount = 0
s.pending.byteCount = 0
s.Unlock()
if hbMonitor != nil {
hbMonitor.Reset(2 * s.consumeOpts.Heartbeat)
hbMonitor.Reset(2 * consumeOpts.Heartbeat)
}
}
}
if errors.Is(err, errDisconnected) {
if hbMonitor != nil {
hbMonitor.Reset(2 * s.consumeOpts.Heartbeat)
hbMonitor.Reset(2 * consumeOpts.Heartbeat)
}
isConnected = false
}
112 changes: 61 additions & 51 deletions jetstream/test/pull_test.go
@@ -1293,69 +1293,79 @@ func TestPullConsumerMessages(t *testing.T) {
})

t.Run("with graceful shutdown", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)

nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
cases := map[string]func(jetstream.MessagesContext){
"stop": func(mc jetstream.MessagesContext) { mc.Stop() },
"drain": func(mc jetstream.MessagesContext) { mc.Drain() },
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
for name, unsubscribe := range cases {
t.Run(name, func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

it, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

publishTestMsgs(t, nc)
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

errs := make(chan error)
msgs := make([]jetstream.Msg, 0)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

go func() {
for {
msg, err := it.Next()
it, err := c.Messages()
if err != nil {
errs <- err
return
t.Fatalf("Unexpected error: %v", err)
}
msg.Ack()
msgs = append(msgs, msg)
}
}()

time.Sleep(10 * time.Millisecond)
it.Stop() // Next() should return ErrMsgIteratorClosed
publishTestMsgs(t, nc)

timeout := time.NewTimer(5 * time.Second)
errs := make(chan error)
msgs := make([]jetstream.Msg, 0)

select {
case <-timeout.C:
t.Fatal("Timed out waiting for Next() to return after Stop()")
case err := <-errs:
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Unexpected error: %v", err)
}
go func() {
for {
msg, err := it.Next()
if err != nil {
errs <- err
return
}
msg.Ack()
msgs = append(msgs, msg)
}
}()

if len(msgs) != len(testMsgs) {
t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
}
time.Sleep(10 * time.Millisecond)
unsubscribe(it) // Next() should return ErrMsgIteratorClosed

timer := time.NewTimer(5 * time.Second)
defer timer.Stop()

select {
case <-timer.C:
t.Fatal("Timed out waiting for Next() to return")
case err := <-errs:
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Unexpected error: %v", err)
}

if len(msgs) != len(testMsgs) {
t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
}
}
})
}
})
