Skip to content

Commit a7c8208

Browse files
authoredJun 12, 2024··
Allow SDK to handle speculative WFT with command events (#1509)
Allow SDK to handle speculative WFT with command events
1 parent d4ff1f6 commit a7c8208

4 files changed

+216
-30
lines changed
 

‎internal/internal_task_handlers.go

+48-26
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,10 @@ type (
115115
isWorkflowCompleted bool
116116
result *commonpb.Payloads
117117
err error
118-
118+
// previousStartedEventID is the event ID of the workflow task started event of the previous workflow task.
119119
previousStartedEventID int64
120+
// lastHandledEventID is the event ID of the last event that the workflow state machine processed.
121+
lastHandledEventID int64
120122

121123
newCommands []*commandpb.Command
122124
newMessages []*protocolpb.Message
@@ -170,18 +172,19 @@ type (
170172

171173
// history wrapper method to help information about events.
172174
history struct {
173-
workflowTask *workflowTask
174-
eventsHandler *workflowExecutionEventHandlerImpl
175-
loadedEvents []*historypb.HistoryEvent
176-
currentIndex int
177-
nextEventID int64 // next expected eventID for sanity
178-
lastEventID int64 // last expected eventID, zero indicates read until end of stream
179-
next []*historypb.HistoryEvent
180-
nextMessages []*protocolpb.Message
181-
nextFlags []sdkFlag
182-
binaryChecksum string
183-
sdkVersion string
184-
sdkName string
175+
workflowTask *workflowTask
176+
eventsHandler *workflowExecutionEventHandlerImpl
177+
loadedEvents []*historypb.HistoryEvent
178+
currentIndex int
179+
nextEventID int64 // next expected eventID for sanity
180+
lastEventID int64 // last expected eventID, zero indicates read until end of stream
181+
lastHandledEventID int64 // last event ID that was processed
182+
next []*historypb.HistoryEvent
183+
nextMessages []*protocolpb.Message
184+
nextFlags []sdkFlag
185+
binaryChecksum string
186+
sdkVersion string
187+
sdkName string
185188
}
186189

187190
workflowTaskHeartbeatError struct {
@@ -219,13 +222,14 @@ type (
219222
}
220223
)
221224

222-
func newHistory(task *workflowTask, eventsHandler *workflowExecutionEventHandlerImpl) *history {
225+
func newHistory(lastHandledEventID int64, task *workflowTask, eventsHandler *workflowExecutionEventHandlerImpl) *history {
223226
result := &history{
224-
workflowTask: task,
225-
eventsHandler: eventsHandler,
226-
loadedEvents: task.task.History.Events,
227-
currentIndex: 0,
228-
lastEventID: task.task.GetStartedEventId(),
227+
workflowTask: task,
228+
eventsHandler: eventsHandler,
229+
loadedEvents: task.task.History.Events,
230+
currentIndex: 0,
231+
lastEventID: task.task.GetStartedEventId(),
232+
lastHandledEventID: lastHandledEventID,
229233
}
230234
if len(result.loadedEvents) > 0 {
231235
result.nextEventID = result.loadedEvents[0].GetEventId()
@@ -454,6 +458,11 @@ OrderEvents:
454458
}
455459

456460
eh.nextEventID++
461+
if eventID <= eh.lastHandledEventID {
462+
eh.currentIndex++
463+
continue
464+
}
465+
eh.lastHandledEventID = eventID
457466

458467
switch event.GetEventType() {
459468
case enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED:
@@ -583,16 +592,15 @@ func (w *workflowExecutionContextImpl) Unlock(err error) {
583592
defer w.mutex.Unlock()
584593
if err != nil || w.err != nil || w.isWorkflowCompleted ||
585594
(w.wth.cache.MaxWorkflowCacheSize() <= 0 && !w.hasPendingLocalActivityWork()) {
586-
// TODO: in case of closed, it asumes the close command always succeed. need server side change to return
595+
// TODO: in case of closed, it assumes the close command always succeed. need server side change to return
587596
// error to indicate the close failure case. This should be rare case. For now, always remove the cache, and
588597
// if the close command failed, the next command will have to rebuild the state.
589598
if w.wth.cache.getWorkflowCache().Exist(w.workflowInfo.WorkflowExecution.RunID) {
590599
w.wth.cache.removeWorkflowContext(w.workflowInfo.WorkflowExecution.RunID)
591600
w.cached = false
592-
} else {
593-
// sticky is disabled, manually clear the workflow state.
594-
w.clearState()
595601
}
602+
// Clear the state so other tasks waiting on the context know it should be discarded.
603+
w.clearState()
596604
} else if !w.cached {
597605
// Clear the state if we never cached the workflow so coroutines can be
598606
// exited
@@ -638,6 +646,7 @@ func (w *workflowExecutionContextImpl) clearState() {
638646
w.result = nil
639647
w.err = nil
640648
w.previousStartedEventID = 0
649+
w.lastHandledEventID = 0
641650
w.newCommands = nil
642651
w.newMessages = nil
643652

@@ -755,10 +764,10 @@ func (wth *workflowTaskHandlerImpl) GetOrCreateWorkflowContext(
755764
// Verify the cached state is current and for the correct worker
756765
if workflowContext != nil {
757766
workflowContext.Lock()
758-
if task.Query != nil && !isFullHistory && wth == workflowContext.wth {
767+
if task.Query != nil && !isFullHistory && wth == workflowContext.wth && !workflowContext.IsDestroyed() {
759768
// query task and we have a valid cached state
760769
metricsHandler.Counter(metrics.StickyCacheHit).Inc(1)
761-
} else if history.Events[0].GetEventId() == workflowContext.previousStartedEventID+1 && wth == workflowContext.wth {
770+
} else if history.Events[0].GetEventId() == workflowContext.previousStartedEventID+1 && wth == workflowContext.wth && !workflowContext.IsDestroyed() {
762771
// non query task and we have a valid cached state
763772
metricsHandler.Counter(metrics.StickyCacheHit).Inc(1)
764773
} else {
@@ -989,7 +998,14 @@ func (w *workflowExecutionContextImpl) ProcessWorkflowTask(workflowTask *workflo
989998
w.SetCurrentTask(task)
990999

9911000
eventHandler := w.getEventHandler()
992-
reorderedHistory := newHistory(workflowTask, eventHandler)
1001+
reorderedHistory := newHistory(w.lastHandledEventID, workflowTask, eventHandler)
1002+
defer func() {
1003+
// After processing the workflow task, update the last handled event ID
1004+
// to the last event ID in the history. We do this regardless of whether the workflow task
1005+
// was successfully processed or not. This is because a failed workflow task will cause the
1006+
// cache to be evicted and the next workflow task will start from the beginning of the history.
1007+
w.lastHandledEventID = reorderedHistory.lastHandledEventID
1008+
}()
9931009
var replayOutbox []outboxEntry
9941010
var replayCommands []*commandpb.Command
9951011
var respondEvents []*historypb.HistoryEvent
@@ -1400,6 +1416,12 @@ func (w *workflowExecutionContextImpl) SetCurrentTask(task *workflowservice.Poll
14001416
}
14011417

14021418
func (w *workflowExecutionContextImpl) SetPreviousStartedEventID(eventID int64) {
1419+
// We must reset the last event we handled to be after the last WFT we really completed
1420+
// + any command events (since the SDK "processed" those when it emitted the commands). This
1421+
// is also equal to what we just processed in the speculative task, minus two, since we
1422+
// would've just handled the most recent WFT started event, and we need to drop that & the
1423+
// schedule event just before it.
1424+
w.lastHandledEventID = w.lastHandledEventID - 2
14031425
w.previousStartedEventID = eventID
14041426
}
14051427

‎internal/internal_task_handlers_interfaces_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommands() {
187187

188188
workflowTask := &workflowTask{task: task, historyIterator: historyIterator}
189189

190-
eh := newHistory(workflowTask, nil)
190+
eh := newHistory(0, workflowTask, nil)
191191

192192
nextTask, err := eh.nextTask()
193193

@@ -232,7 +232,7 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommandsSdkFlags() {
232232

233233
workflowTask := &workflowTask{task: task, historyIterator: historyIterator}
234234

235-
eh := newHistory(workflowTask, nil)
235+
eh := newHistory(0, workflowTask, nil)
236236

237237
nextTask, err := eh.nextTask()
238238

@@ -301,7 +301,7 @@ func (s *PollLayerInterfacesTestSuite) TestMessageCommands() {
301301

302302
workflowTask := &workflowTask{task: task, historyIterator: historyIterator}
303303

304-
eh := newHistory(workflowTask, nil)
304+
eh := newHistory(0, workflowTask, nil)
305305

306306
nextTask, err := eh.nextTask()
307307
s.NoError(err)
@@ -370,7 +370,7 @@ func (s *PollLayerInterfacesTestSuite) TestEmptyPages() {
370370
}
371371

372372
workflowTask := &workflowTask{task: task, historyIterator: historyIterator}
373-
eh := newHistory(workflowTask, nil)
373+
eh := newHistory(0, workflowTask, nil)
374374

375375
type result struct {
376376
events []*historypb.HistoryEvent

‎internal/internal_task_handlers_test.go

+9
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,15 @@ func createTestUpsertWorkflowSearchAttributesForChangeVersion(eventID int64, wor
426426
}
427427
}
428428

429+
func createTestProtocolMessageUpdateRequest(ID string, eventID int64, request *updatepb.Request) *protocolpb.Message {
430+
return &protocolpb.Message{
431+
Id: uuid.New(),
432+
ProtocolInstanceId: ID,
433+
SequencingId: &protocolpb.Message_EventId{EventId: eventID},
434+
Body: protocol.MustMarshalAny(request),
435+
}
436+
}
437+
429438
func createWorkflowTask(
430439
events []*historypb.HistoryEvent,
431440
previousStartEventID int64,

‎internal/internal_task_pollers_test.go

+155
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,19 @@ import (
3030
"errors"
3131
"sync/atomic"
3232
"testing"
33+
"time"
3334

3435
"github.com/golang/mock/gomock"
3536
"github.com/stretchr/testify/require"
3637
commonpb "go.temporal.io/api/common/v1"
3738
historypb "go.temporal.io/api/history/v1"
39+
protocolpb "go.temporal.io/api/protocol/v1"
3840
taskqueuepb "go.temporal.io/api/taskqueue/v1"
41+
"go.temporal.io/api/update/v1"
3942
"go.temporal.io/api/workflowservice/v1"
4043
"go.temporal.io/api/workflowservicemock/v1"
4144
"google.golang.org/grpc"
45+
"google.golang.org/protobuf/types/known/durationpb"
4246
)
4347

4448
type countingTaskHandler struct {
@@ -222,3 +226,154 @@ func TestWFTCorruption(t *testing.T) {
222226
// Workflow should not be in cache
223227
require.Nil(t, cache.getWorkflowContext(runID))
224228
}
229+
230+
func TestWFTReset(t *testing.T) {
231+
cache := NewWorkerCache()
232+
params := workerExecutionParameters{
233+
cache: cache,
234+
}
235+
ensureRequiredParams(&params)
236+
wfType := commonpb.WorkflowType{Name: t.Name() + "-workflow-type"}
237+
reg := newRegistry()
238+
reg.RegisterWorkflowWithOptions(func(ctx Context) error {
239+
_ = SetUpdateHandler(ctx, "update", func(ctx Context) error {
240+
return nil
241+
}, UpdateHandlerOptions{
242+
Validator: func(ctx Context) error {
243+
return errors.New("rejecting for test")
244+
},
245+
})
246+
_ = Sleep(ctx, time.Second)
247+
return Sleep(ctx, time.Second)
248+
}, RegisterWorkflowOptions{
249+
Name: wfType.Name,
250+
})
251+
var (
252+
taskQueue = taskqueuepb.TaskQueue{Name: t.Name() + "task-queue"}
253+
history0 = historypb.History{Events: []*historypb.HistoryEvent{
254+
createTestEventWorkflowExecutionStarted(1, &historypb.WorkflowExecutionStartedEventAttributes{
255+
TaskQueue: &taskQueue,
256+
}),
257+
createTestEventWorkflowTaskScheduled(2, &historypb.WorkflowTaskScheduledEventAttributes{
258+
TaskQueue: &taskQueue,
259+
StartToCloseTimeout: &durationpb.Duration{Seconds: 10},
260+
Attempt: 1,
261+
}),
262+
createTestEventWorkflowTaskStarted(3),
263+
createTestEventWorkflowTaskCompleted(4, &historypb.WorkflowTaskCompletedEventAttributes{
264+
ScheduledEventId: 2,
265+
StartedEventId: 3,
266+
}),
267+
createTestEventTimerStarted(5, 5),
268+
createTestEventWorkflowTaskScheduled(6, &historypb.WorkflowTaskScheduledEventAttributes{
269+
TaskQueue: &taskQueue,
270+
StartToCloseTimeout: &durationpb.Duration{Seconds: 10},
271+
Attempt: 1,
272+
}),
273+
createTestEventWorkflowTaskStarted(7),
274+
}}
275+
messages = []*protocolpb.Message{
276+
createTestProtocolMessageUpdateRequest("test-update", 6, &update.Request{
277+
Meta: &update.Meta{
278+
UpdateId: "test-update",
279+
},
280+
Input: &update.Input{
281+
Name: "update",
282+
},
283+
}),
284+
}
285+
history1 = historypb.History{Events: []*historypb.HistoryEvent{
286+
createTestEventWorkflowTaskCompleted(4, &historypb.WorkflowTaskCompletedEventAttributes{
287+
ScheduledEventId: 2,
288+
StartedEventId: 3,
289+
}),
290+
createTestEventTimerStarted(5, 5),
291+
createTestEventWorkflowTaskScheduled(6, &historypb.WorkflowTaskScheduledEventAttributes{
292+
TaskQueue: &taskQueue,
293+
StartToCloseTimeout: &durationpb.Duration{Seconds: 10},
294+
Attempt: 1,
295+
}),
296+
createTestEventWorkflowTaskStarted(7),
297+
}}
298+
history2 = historypb.History{Events: []*historypb.HistoryEvent{
299+
createTestEventWorkflowTaskCompleted(4, &historypb.WorkflowTaskCompletedEventAttributes{
300+
ScheduledEventId: 2,
301+
StartedEventId: 3,
302+
}),
303+
createTestEventTimerStarted(5, 5),
304+
createTestEventTimerFired(6, 5),
305+
createTestEventWorkflowTaskScheduled(7, &historypb.WorkflowTaskScheduledEventAttributes{
306+
TaskQueue: &taskQueue,
307+
StartToCloseTimeout: &durationpb.Duration{Seconds: 10},
308+
Attempt: 1,
309+
}),
310+
createTestEventWorkflowTaskStarted(8),
311+
}}
312+
runID = t.Name() + "-run-id"
313+
wfID = t.Name() + "-workflow-id"
314+
wfe = commonpb.WorkflowExecution{RunId: runID, WorkflowId: wfID}
315+
ctrl = gomock.NewController(t)
316+
client = workflowservicemock.NewMockWorkflowServiceClient(ctrl)
317+
innerTaskHandler = newWorkflowTaskHandler(params, nil, reg)
318+
taskHandler = &countingTaskHandler{WorkflowTaskHandler: innerTaskHandler}
319+
contextManager = taskHandler
320+
pollResp0 = workflowservice.PollWorkflowTaskQueueResponse{
321+
Attempt: 1,
322+
WorkflowExecution: &wfe,
323+
WorkflowType: &wfType,
324+
History: &history0,
325+
Messages: messages,
326+
PreviousStartedEventId: 3,
327+
}
328+
task0 = workflowTask{task: &pollResp0}
329+
pollResp1 = workflowservice.PollWorkflowTaskQueueResponse{
330+
Attempt: 1,
331+
WorkflowExecution: &wfe,
332+
WorkflowType: &wfType,
333+
History: &history1,
334+
PreviousStartedEventId: 3,
335+
}
336+
task1 = workflowTask{task: &pollResp1}
337+
pollResp2 = workflowservice.PollWorkflowTaskQueueResponse{
338+
Attempt: 1,
339+
WorkflowExecution: &wfe,
340+
WorkflowType: &wfType,
341+
History: &history2,
342+
PreviousStartedEventId: 3,
343+
}
344+
task2 = workflowTask{task: &pollResp2}
345+
)
346+
347+
// Return a workflow task to reset the workflow to a previous state
348+
client.EXPECT().RespondWorkflowTaskCompleted(gomock.Any(), gomock.Any()).
349+
Return(&workflowservice.RespondWorkflowTaskCompletedResponse{
350+
ResetHistoryEventId: 3,
351+
}, nil).Times(3)
352+
// Return a workflow task to complete the workflow
353+
client.EXPECT().RespondWorkflowTaskCompleted(gomock.Any(), gomock.Any()).
354+
Return(&workflowservice.RespondWorkflowTaskCompletedResponse{}, nil)
355+
356+
poller := newWorkflowTaskPoller(taskHandler, contextManager, client, params)
357+
// Send a full history as part of the speculative WFT
358+
require.NoError(t, poller.processWorkflowTask(&task0))
359+
originalCachedExecution := cache.getWorkflowContext(runID)
360+
require.NotNil(t, originalCachedExecution)
361+
require.Equal(t, int64(3), originalCachedExecution.previousStartedEventID)
362+
require.Equal(t, int64(5), originalCachedExecution.lastHandledEventID)
363+
// Send some fake speculative WFTs to ensure the workflow is reset properly
364+
require.NoError(t, poller.processWorkflowTask(&task1))
365+
cachedExecution := cache.getWorkflowContext(runID)
366+
require.True(t, originalCachedExecution == cachedExecution)
367+
require.Equal(t, int64(3), cachedExecution.previousStartedEventID)
368+
require.Equal(t, int64(5), cachedExecution.lastHandledEventID)
369+
require.NoError(t, poller.processWorkflowTask(&task1))
370+
cachedExecution = cache.getWorkflowContext(runID)
371+
// Check the cached execution is the same as the original
372+
require.True(t, originalCachedExecution == cachedExecution)
373+
require.Equal(t, int64(3), cachedExecution.previousStartedEventID)
374+
require.Equal(t, int64(5), cachedExecution.lastHandledEventID)
375+
// Send a real WFT with new events
376+
require.NoError(t, poller.processWorkflowTask(&task2))
377+
cachedExecution = cache.getWorkflowContext(runID)
378+
require.True(t, originalCachedExecution == cachedExecution)
379+
}

0 commit comments

Comments
 (0)
Please sign in to comment.