@@ -34,7 +34,6 @@ import {
34
34
collectFields ,
35
35
createGraphQLError ,
36
36
fakePromise ,
37
- getAbortPromise ,
38
37
getArgumentValues ,
39
38
getDefinedRootType ,
40
39
GraphQLResolveInfo ,
@@ -52,11 +51,10 @@ import {
52
51
Path ,
53
52
pathToArray ,
54
53
promiseReduce ,
55
- registerAbortSignalListener ,
56
54
} from '@graphql-tools/utils' ;
57
55
import { TypedDocumentNode } from '@graphql-typed-document-node/core' ;
58
56
import { DisposableSymbols } from '@whatwg-node/disposablestack' ;
59
- import { handleMaybePromise } from '@whatwg-node/promise-helpers' ;
57
+ import { createDeferredPromise , handleMaybePromise } from '@whatwg-node/promise-helpers' ;
60
58
import { coerceError } from './coerceError.js' ;
61
59
import { flattenAsyncIterable } from './flattenAsyncIterable.js' ;
62
60
import { invariant } from './invariant.js' ;
@@ -127,6 +125,8 @@ export interface ExecutionContext<TVariables = any, TContext = any> {
127
125
errors : Array < GraphQLError > ;
128
126
subsequentPayloads : Set < AsyncPayloadRecord > ;
129
127
signal ?: AbortSignal ;
128
+ onSignalAbort ?( handler : ( ) => void ) : void ;
129
+ signalPromise ?: Promise < never > ;
130
130
}
131
131
132
132
export interface FormattedExecutionResult <
@@ -421,6 +421,8 @@ export function buildExecutionContext<TData = any, TVariables = any, TContext =
421
421
signal,
422
422
} = args ;
423
423
424
+ signal ?. throwIfAborted ( ) ;
425
+
424
426
// If the schema used for execution is invalid, throw an error.
425
427
assertValidSchema ( schema ) ;
426
428
@@ -489,6 +491,31 @@ export function buildExecutionContext<TData = any, TVariables = any, TContext =
489
491
return coercedVariableValues . errors ;
490
492
}
491
493
494
+ signal ?. throwIfAborted ( ) ;
495
+
496
+ let onSignalAbort : ExecutionContext [ 'onSignalAbort' ] ;
497
+ let signalPromise : ExecutionContext [ 'signalPromise' ] ;
498
+
499
+ if ( signal ) {
500
+ const listeners = new Set < ( ) => void > ( ) ;
501
+ const signalDeferred = createDeferredPromise < never > ( ) ;
502
+ signalPromise = signalDeferred . promise ;
503
+ const sharedListener = ( ) => {
504
+ signalDeferred . reject ( signal . reason ) ;
505
+ signal . removeEventListener ( 'abort' , sharedListener ) ;
506
+ } ;
507
+ signal . addEventListener ( 'abort' , sharedListener , { once : true } ) ;
508
+ signalPromise . catch ( ( ) => {
509
+ for ( const listener of listeners ) {
510
+ listener ( ) ;
511
+ }
512
+ listeners . clear ( ) ;
513
+ } ) ;
514
+ onSignalAbort = handler => {
515
+ listeners . add ( handler ) ;
516
+ } ;
517
+ }
518
+
492
519
return {
493
520
schema,
494
521
fragments,
@@ -502,6 +529,8 @@ export function buildExecutionContext<TData = any, TVariables = any, TContext =
502
529
subsequentPayloads : new Set ( ) ,
503
530
errors : [ ] ,
504
531
signal,
532
+ onSignalAbort,
533
+ signalPromise,
505
534
} ;
506
535
}
507
536
@@ -626,7 +655,7 @@ function executeFields(
626
655
}
627
656
}
628
657
} catch ( error ) {
629
- if ( containsPromise ) {
658
+ if ( error !== exeContext . signal ?. reason && containsPromise ) {
630
659
// Ensure that any promises returned by other fields are handled, as they may also reject.
631
660
return handleMaybePromise (
632
661
( ) => promiseForObject ( results , exeContext . signal ) ,
@@ -649,7 +678,7 @@ function executeFields(
649
678
// Otherwise, results is a map from field name to the result of resolving that
650
679
// field, which is possibly a promise. Return a promise that will return this
651
680
// same map, but with any promises replaced with the values they resolved to.
652
- return promiseForObject ( results , exeContext . signal ) ;
681
+ return promiseForObject ( results , exeContext . signal , exeContext . signalPromise ) ;
653
682
}
654
683
655
684
/**
@@ -679,6 +708,7 @@ function executeField(
679
708
680
709
// Get the resolve function, regardless of if its result is normal or abrupt (error).
681
710
try {
711
+ exeContext . signal ?. throwIfAborted ( ) ;
682
712
// Build a JS object of arguments from the field.arguments AST, using the
683
713
// variables scope to fulfill any variable references.
684
714
// TODO: find a way to memoize, in case this field is within a List type.
@@ -973,8 +1003,9 @@ async function completeAsyncIteratorValue(
973
1003
iterator : AsyncIterator < unknown > ,
974
1004
asyncPayloadRecord ?: AsyncPayloadRecord ,
975
1005
) : Promise < ReadonlyArray < unknown > > {
976
- if ( exeContext . signal && iterator . return ) {
977
- registerAbortSignalListener ( exeContext . signal , ( ) => {
1006
+ exeContext . signal ?. throwIfAborted ( ) ;
1007
+ if ( iterator . return ) {
1008
+ exeContext . onSignalAbort ?.( ( ) => {
978
1009
iterator . return ?.( ) ;
979
1010
} ) ;
980
1011
}
@@ -1755,18 +1786,25 @@ function executeSubscription(exeContext: ExecutionContext): MaybePromise<AsyncIt
1755
1786
const result = resolveFn ( rootValue , args , contextValue , info ) ;
1756
1787
1757
1788
if ( isPromise ( result ) ) {
1758
- return result . then ( assertEventStream ) . then ( undefined , error => {
1759
- throw locatedError ( error , fieldNodes , pathToArray ( path ) ) ;
1760
- } ) ;
1789
+ return result
1790
+ . then ( result => assertEventStream ( result , exeContext . signal , exeContext . onSignalAbort ) )
1791
+ . then ( undefined , error => {
1792
+ throw locatedError ( error , fieldNodes , pathToArray ( path ) ) ;
1793
+ } ) ;
1761
1794
}
1762
1795
1763
- return assertEventStream ( result , exeContext . signal ) ;
1796
+ return assertEventStream ( result , exeContext . signal , exeContext . onSignalAbort ) ;
1764
1797
} catch ( error ) {
1765
1798
throw locatedError ( error , fieldNodes , pathToArray ( path ) ) ;
1766
1799
}
1767
1800
}
1768
1801
1769
- function assertEventStream ( result : unknown , signal ?: AbortSignal ) : AsyncIterable < unknown > {
1802
+ function assertEventStream (
1803
+ result : unknown ,
1804
+ signal ?: AbortSignal ,
1805
+ onSignalAbort ?: ( handler : ( ) => void ) => void ,
1806
+ ) : AsyncIterable < unknown > {
1807
+ signal ?. throwIfAborted ( ) ;
1770
1808
if ( result instanceof Error ) {
1771
1809
throw result ;
1772
1810
}
@@ -1777,13 +1815,13 @@ function assertEventStream(result: unknown, signal?: AbortSignal): AsyncIterable
1777
1815
'Subscription field must return Async Iterable. ' + `Received: ${ inspect ( result ) } .` ,
1778
1816
) ;
1779
1817
}
1780
- if ( signal ) {
1818
+ if ( onSignalAbort ) {
1781
1819
return {
1782
1820
[ Symbol . asyncIterator ] ( ) {
1783
1821
const asyncIterator = result [ Symbol . asyncIterator ] ( ) ;
1784
1822
1785
1823
if ( asyncIterator . return ) {
1786
- registerAbortSignalListener ( signal , ( ) => {
1824
+ onSignalAbort ?. ( ( ) => {
1787
1825
asyncIterator . return ?.( ) ;
1788
1826
} ) ;
1789
1827
}
@@ -2110,8 +2148,6 @@ function yieldSubsequentPayloads(
2110
2148
) : AsyncGenerator < SubsequentIncrementalExecutionResult , void , void > {
2111
2149
let isDone = false ;
2112
2150
2113
- const abortPromise = exeContext . signal ? getAbortPromise ( exeContext . signal ) : undefined ;
2114
-
2115
2151
async function next ( ) : Promise < IteratorResult < SubsequentIncrementalExecutionResult , void > > {
2116
2152
if ( isDone ) {
2117
2153
return { value : undefined , done : true } ;
@@ -2121,8 +2157,8 @@ function yieldSubsequentPayloads(
2121
2157
record => record . promise ,
2122
2158
) ;
2123
2159
2124
- if ( abortPromise ) {
2125
- await Promise . race ( [ abortPromise , ...subSequentPayloadPromises ] ) ;
2160
+ if ( exeContext . signalPromise ) {
2161
+ await Promise . race ( [ exeContext . signalPromise , ...subSequentPayloadPromises ] ) ;
2126
2162
} else {
2127
2163
await Promise . race ( subSequentPayloadPromises ) ;
2128
2164
}
0 commit comments