@@ -6,23 +6,23 @@ import type {
6
6
Expression ,
7
7
Identifier ,
8
8
ImportDeclaration ,
9
- ImportExpression ,
10
9
VariableDeclaration ,
11
10
} from 'estree'
12
11
import type { SourceMap } from 'magic-string'
12
+ import type { RollupAstNode } from 'rollup'
13
13
import type { Plugin , Rollup } from 'vite'
14
14
import type { Node , Positioned } from './esmWalker'
15
15
import { findNodeAround } from 'acorn-walk'
16
16
import MagicString from 'magic-string'
17
17
import { createFilter } from 'vite'
18
- import { esmWalker , getArbitraryModuleIdentifier } from './esmWalker'
18
+ import { esmWalker } from './esmWalker'
19
19
20
20
interface HoistMocksOptions {
21
21
/**
22
22
* List of modules that should always be imported before compiler hints.
23
- * @default [ 'vitest']
23
+ * @default 'vitest'
24
24
*/
25
- hoistedModules ?: string [ ]
25
+ hoistedModule ?: string
26
26
/**
27
27
* @default ["vi", "vitest"]
28
28
*/
@@ -106,11 +106,14 @@ function isIdentifier(node: any): node is Positioned<Identifier> {
106
106
return node . type === 'Identifier'
107
107
}
108
108
109
- function getBetterEnd ( code : string , node : Node ) {
109
+ function getNodeTail ( code : string , node : Node ) {
110
110
let end = node . end
111
111
if ( code [ node . end ] === ';' ) {
112
112
end += 1
113
113
}
114
+ if ( code [ node . end ] === '\n' ) {
115
+ return end + 1
116
+ }
114
117
if ( code [ node . end + 1 ] === '\n' ) {
115
118
end += 1
116
119
}
@@ -160,48 +163,43 @@ export function hoistMocks(
160
163
dynamicImportMockMethodNames = [ 'mock' , 'unmock' , 'doMock' , 'doUnmock' ] ,
161
164
hoistedMethodNames = [ 'hoisted' ] ,
162
165
utilsObjectNames = [ 'vi' , 'vitest' ] ,
163
- hoistedModules = [ 'vitest' ] ,
166
+ hoistedModule = 'vitest' ,
164
167
} = options
165
168
166
- const hoistIndex = code . match ( hashbangRE ) ?. [ 0 ] . length ?? 0
169
+ // hoist at the start of the file, after the hashbang
170
+ let hoistIndex = hashbangRE . exec ( code ) ?. [ 0 ] . length ?? 0
167
171
168
172
let hoistedModuleImported = false
169
173
170
174
let uid = 0
171
175
const idToImportMap = new Map < string , string > ( )
172
176
177
+ const imports : {
178
+ node : RollupAstNode < ImportDeclaration >
179
+ id : string
180
+ } [ ] = [ ]
181
+
173
182
// this will transform import statements into dynamic ones, if there are imports
174
183
// it will keep the import as is, if we don't need to mock anything
175
184
// in browser environment it will wrap the module value with "vitest_wrap_module" function
176
185
// that returns a proxy to the module so that named exports can be mocked
177
- const transformImportDeclaration = ( node : ImportDeclaration ) => {
178
- const source = node . source . value as string
179
-
180
- const importId = `__vi_import_${ uid ++ } __`
181
- const hasSpecifiers = node . specifiers . length > 0
182
- const code = hasSpecifiers
183
- ? `const ${ importId } = await import('${ source } ')\n`
184
- : `await import('${ source } ')\n`
185
- return {
186
- code,
187
- id : importId ,
188
- }
189
- }
190
-
191
- function defineImport ( node : Positioned < ImportDeclaration > ) {
186
+ function defineImport (
187
+ importNode : ImportDeclaration & {
188
+ start : number
189
+ end : number
190
+ } ,
191
+ ) {
192
+ const source = importNode . source . value as string
192
193
// always hoist vitest import to top of the file, so
193
194
// "vi" helpers can access it
194
- if ( hoistedModules . includes ( node . source . value as string ) ) {
195
+ if ( hoistedModule === source ) {
195
196
hoistedModuleImported = true
196
197
return
197
198
}
199
+ const importId = `__vi_import_${ uid ++ } __`
200
+ imports . push ( { id : importId , node : importNode } )
198
201
199
- const declaration = transformImportDeclaration ( node )
200
- if ( ! declaration ) {
201
- return null
202
- }
203
- s . appendLeft ( hoistIndex , declaration . code )
204
- return declaration . id
202
+ return importId
205
203
}
206
204
207
205
// 1. check all import statements and record id -> importName map
@@ -214,13 +212,20 @@ export function hoistMocks(
214
212
if ( ! importId ) {
215
213
continue
216
214
}
217
- s . remove ( node . start , getBetterEnd ( code , node ) )
218
215
for ( const spec of node . specifiers ) {
219
216
if ( spec . type === 'ImportSpecifier' ) {
220
- idToImportMap . set (
221
- spec . local . name ,
222
- `${ importId } .${ getArbitraryModuleIdentifier ( spec . imported ) } ` ,
223
- )
217
+ if ( spec . imported . type === 'Identifier' ) {
218
+ idToImportMap . set (
219
+ spec . local . name ,
220
+ `${ importId } .${ spec . imported . name } ` ,
221
+ )
222
+ }
223
+ else {
224
+ idToImportMap . set (
225
+ spec . local . name ,
226
+ `${ importId } [${ JSON . stringify ( spec . imported . value as string ) } ]` ,
227
+ )
228
+ }
224
229
}
225
230
else if ( spec . type === 'ImportDefaultSpecifier' ) {
226
231
idToImportMap . set ( spec . local . name , `${ importId } .default` )
@@ -235,7 +240,7 @@ export function hoistMocks(
235
240
236
241
const declaredConst = new Set < string > ( )
237
242
const hoistedNodes : Positioned <
238
- CallExpression | VariableDeclaration | AwaitExpression
243
+ CallExpression | VariableDeclaration | AwaitExpression
239
244
> [ ] = [ ]
240
245
241
246
function createSyntaxError ( node : Positioned < Node > , message : string ) {
@@ -300,6 +305,8 @@ export function hoistMocks(
300
305
}
301
306
}
302
307
308
+ const usedUtilityExports = new Set < string > ( )
309
+
303
310
esmWalker ( ast , {
304
311
onIdentifier ( id , info , parentStack ) {
305
312
const binding = idToImportMap . get ( id . name )
@@ -333,6 +340,7 @@ export function hoistMocks(
333
340
&& isIdentifier ( node . callee . property )
334
341
) {
335
342
const methodName = node . callee . property . name
343
+ usedUtilityExports . add ( node . callee . object . name )
336
344
337
345
if ( hoistableMockMethodNames . includes ( methodName ) ) {
338
346
const method = `${ node . callee . object . name } .${ methodName } `
@@ -347,6 +355,35 @@ export function hoistMocks(
347
355
`Cannot export the result of "${ method } ". Remove export declaration because "${ method } " doesn\'t return anything.` ,
348
356
)
349
357
}
358
+ // rewrite vi.mock(import('..')) into vi.mock('..')
359
+ if (
360
+ node . type === 'CallExpression'
361
+ && node . callee . type === 'MemberExpression'
362
+ && dynamicImportMockMethodNames . includes ( ( node . callee . property as Identifier ) . name )
363
+ ) {
364
+ const moduleInfo = node . arguments [ 0 ] as Positioned < Expression >
365
+ // vi.mock(import('./path')) -> vi.mock('./path')
366
+ if ( moduleInfo . type === 'ImportExpression' ) {
367
+ const source = moduleInfo . source as Positioned < Expression >
368
+ s . overwrite (
369
+ moduleInfo . start ,
370
+ moduleInfo . end ,
371
+ s . slice ( source . start , source . end ) ,
372
+ )
373
+ }
374
+ // vi.mock(await import('./path')) -> vi.mock('./path')
375
+ if (
376
+ moduleInfo . type === 'AwaitExpression'
377
+ && moduleInfo . argument . type === 'ImportExpression'
378
+ ) {
379
+ const source = moduleInfo . argument . source as Positioned < Expression >
380
+ s . overwrite (
381
+ moduleInfo . start ,
382
+ moduleInfo . end ,
383
+ s . slice ( source . start , source . end ) ,
384
+ )
385
+ }
386
+ }
350
387
hoistedNodes . push ( node )
351
388
}
352
389
// vi.doMock(import('./path')) -> vi.doMock('./path')
@@ -394,9 +431,8 @@ export function hoistMocks(
394
431
'AwaitExpression' ,
395
432
) ?. node as Positioned < AwaitExpression > | undefined
396
433
// hoist "await vi.hoisted(async () => {})" or "vi.hoisted(() => {})"
397
- hoistedNodes . push (
398
- awaitedExpression ?. argument === node ? awaitedExpression : node ,
399
- )
434
+ const moveNode = awaitedExpression ?. argument === node ? awaitedExpression : node
435
+ hoistedNodes . push ( moveNode )
400
436
}
401
437
}
402
438
}
@@ -446,24 +482,6 @@ export function hoistMocks(
446
482
)
447
483
}
448
484
449
- function rewriteMockDynamicImport (
450
- nodeCode : string ,
451
- moduleInfo : Positioned < ImportExpression > ,
452
- expressionStart : number ,
453
- expressionEnd : number ,
454
- mockStart : number ,
455
- ) {
456
- const source = moduleInfo . source as Positioned < Expression >
457
- const importPath = s . slice ( source . start , source . end )
458
- const nodeCodeStart = expressionStart - mockStart
459
- const nodeCodeEnd = expressionEnd - mockStart
460
- return (
461
- nodeCode . slice ( 0 , nodeCodeStart )
462
- + importPath
463
- + nodeCode . slice ( nodeCodeEnd )
464
- )
465
- }
466
-
467
485
// validate hoistedNodes doesn't have nodes inside other nodes
468
486
for ( let i = 0 ; i < hoistedNodes . length ; i ++ ) {
469
487
const node = hoistedNodes [ i ]
@@ -479,61 +497,55 @@ export function hoistMocks(
479
497
}
480
498
}
481
499
482
- // Wait for imports to be hoisted and then hoist the mocks
483
- const hoistedCode = hoistedNodes
484
- . map ( ( node ) => {
485
- const end = getBetterEnd ( code , node )
486
- /**
487
- * In the following case, we need to change the `user` to user: __vi_import_x__.user
488
- * So we should get the latest code from `s`.
489
- *
490
- * import user from './user'
491
- * vi.mock('./mock.js', () => ({ getSession: vi.fn().mockImplementation(() => ({ user })) }))
492
- */
493
- let nodeCode = s . slice ( node . start , end )
494
-
495
- // rewrite vi.mock(import('..')) into vi.mock('..')
496
- if (
497
- node . type === 'CallExpression'
498
- && node . callee . type === 'MemberExpression'
499
- && dynamicImportMockMethodNames . includes ( ( node . callee . property as Identifier ) . name )
500
- ) {
501
- const moduleInfo = node . arguments [ 0 ] as Positioned < Expression >
502
- // vi.mock(import('./path')) -> vi.mock('./path')
503
- if ( moduleInfo . type === 'ImportExpression' ) {
504
- nodeCode = rewriteMockDynamicImport (
505
- nodeCode ,
506
- moduleInfo ,
507
- moduleInfo . start ,
508
- moduleInfo . end ,
509
- node . start ,
510
- )
511
- }
512
- // vi.mock(await import('./path')) -> vi.mock('./path')
513
- if (
514
- moduleInfo . type === 'AwaitExpression'
515
- && moduleInfo . argument . type === 'ImportExpression'
516
- ) {
517
- nodeCode = rewriteMockDynamicImport (
518
- nodeCode ,
519
- moduleInfo . argument as Positioned < ImportExpression > ,
520
- moduleInfo . start ,
521
- moduleInfo . end ,
522
- node . start ,
523
- )
524
- }
525
- }
500
+ // hoist vi.mock/vi.hoisted
501
+ for ( const node of hoistedNodes ) {
502
+ const end = getNodeTail ( code , node )
503
+ if ( hoistIndex === end ) {
504
+ hoistIndex = end
505
+ }
506
+ // don't hoist into itself if it's already at the top
507
+ else if ( hoistIndex !== node . start ) {
508
+ s . move ( node . start , end , hoistIndex )
509
+ }
510
+ }
526
511
527
- s . remove ( node . start , end )
528
- return `${ nodeCode } ${ nodeCode . endsWith ( '\n' ) ? '' : '\n' } `
529
- } )
530
- . join ( '' )
512
+ // hoist actual dynamic imports last so they are inserted after all hoisted mocks
513
+ for ( const { node : importNode , id : importId } of imports ) {
514
+ const source = importNode . source . value as string
531
515
532
- if ( hoistedCode || hoistedModuleImported ) {
533
- s . prepend (
534
- ( ! hoistedModuleImported && hoistedCode ? API_NOT_FOUND_CHECK ( utilsObjectNames ) : '' )
535
- + hoistedCode ,
516
+ s . update (
517
+ importNode . start ,
518
+ importNode . end ,
519
+ `const ${ importId } = await import(${ JSON . stringify (
520
+ source ,
521
+ ) } );\n`,
536
522
)
523
+
524
+ if ( importNode . start === hoistIndex ) {
525
+ // no need to hoist, but update hoistIndex to keep the order
526
+ hoistIndex = importNode . end
527
+ }
528
+ else {
529
+ // There will be an error if the module is called before it is imported,
530
+ // so the module import statement is hoisted to the top
531
+ s . move ( importNode . start , importNode . end , hoistIndex )
532
+ }
533
+ }
534
+
535
+ if ( ! hoistedModuleImported && hoistedNodes . length ) {
536
+ const utilityImports = [ ...usedUtilityExports ]
537
+ // "vi" or "vitest" is imported from a module other than "vitest"
538
+ if ( utilityImports . some ( name => idToImportMap . has ( name ) ) ) {
539
+ s . prepend ( API_NOT_FOUND_CHECK ( utilityImports ) )
540
+ }
541
+ // if "vi" or "vitest" are not imported at all, import them
542
+ else if ( utilityImports . length ) {
543
+ s . prepend (
544
+ `import { ${ [ ...usedUtilityExports ] . join ( ', ' ) } } from ${ JSON . stringify (
545
+ hoistedModule ,
546
+ ) } \n`,
547
+ )
548
+ }
537
549
}
538
550
539
551
return {
0 commit comments