Skip to content

Commit

Permalink
feat: allow extending toEqual (fix vitest-dev#2875)
Browse files Browse the repository at this point in the history
  • Loading branch information
tigranmk committed Jan 8, 2024
1 parent a73c1c2 commit a0f502e
Show file tree
Hide file tree
Showing 12 changed files with 272 additions and 40 deletions.
67 changes: 67 additions & 0 deletions docs/api/expect.md
Original file line number Diff line number Diff line change
Expand Up @@ -1405,3 +1405,70 @@ Don't forget to include the ambient declaration file in your `tsconfig.json`.
:::tip
If you want to know more, checkout [guide on extending matchers](/guide/extending-matchers).
:::

## expect.addEqualityTesters <Badge type="info">1.2.0+</Badge>

- **Type:** `(tester: Array<Tester>) => void`

You can use this method to define custom matchers to test if two object equals are equal. It is compatible with Jest's `expect.extend`.

```ts
class AnagramComparator {
public word: string

constructor(word: string) {
this.word = word
}

equals(other: AnagramComparator): boolean {
const cleanStr1 = this.word.replace(/ /g, '').toLowerCase()
const cleanStr2 = other.word.replace(/ /g, '').toLowerCase()

const sortedStr1 = cleanStr1.split('').sort().join('')
const sortedStr2 = cleanStr2.split('').sort().join('')

return sortedStr1 === sortedStr2
}
}

function createAnagramComparator(word: string) {
return new AnagramComparator(word)
}

function isAnagramComparator(a: unknown): a is AnagramComparator {
return a instanceof AnagramComparator
}

const areObjectsEqual: Tester = (
a: unknown,
b: unknown,
): boolean | undefined => {
const isAAnagramComparator = isAnagramComparator(a)
const isBAnagramComparator = isAnagramComparator(b)

if (isAAnagramComparator && isBAnagramComparator)
return a.equals(b)

else if (isAAnagramComparator === isBAnagramComparator)
return undefined

else
return false
}

function* toIterator<T>(array: Array<T>): Iterator<T> {
for (const obj of array)
yield obj
}

const customObject1 = createAnagramComparator('listen')
const customObject2 = createAnagramComparator('silent')

expect.addEqualityTesters([areObjectsEqual])
```

```ts
test('objects are equal with different amount', () => {
expect(customObject1).toEqual(customObject2)
})
```
3 changes: 2 additions & 1 deletion examples/vitesse/src/auto-import.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ declare global {
declare global {
// @ts-ignore
export type { Component, ComponentPublicInstance, ComputedRef, InjectionKey, PropType, Ref, VNode } from 'vue'
}
export type { Component, ComponentPublicInstance, ComputedRef, ExtractDefaultPropTypes, ExtractPropTypes, ExtractPublicPropTypes, InjectionKey, PropType, Ref, VNode, WritableComputedRef } from 'vue'
}
1 change: 1 addition & 0 deletions packages/expect/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ export * from './constants'
export * from './types'
export { getState, setState } from './state'
export { JestChaiExpect } from './jest-expect'
export { addCustomEqualityTesters } from './jest-matcher-utils'
export { JestExtend } from './jest-extend'
export { setupColors } from '@vitest/utils'
10 changes: 6 additions & 4 deletions packages/expect/src/jest-asymmetric-matchers.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { ChaiPlugin, MatcherState } from './types'
import { GLOBAL_EXPECT } from './constants'
import { getState } from './state'
import { diff, getMatcherUtils, stringify } from './jest-matcher-utils'
import { diff, getCustomEqualityTesters, getMatcherUtils, stringify } from './jest-matcher-utils'

import { equals, isA, iterableEquality, pluralize, subsetEquality } from './jest-utils'

Expand All @@ -26,7 +26,7 @@ export abstract class AsymmetricMatcher<
...getState(expect || (globalThis as any)[GLOBAL_EXPECT]),
equals,
isNot: this.inverse,
customTesters: [],
customTesters: getCustomEqualityTesters(),
utils: {
...getMatcherUtils(),
diff,
Expand Down Expand Up @@ -116,8 +116,9 @@ export class ObjectContaining extends AsymmetricMatcher<Record<string, unknown>>

let result = true

const matcherContext = this.getMatcherContext()
for (const property in this.sample) {
if (!this.hasProperty(other, property) || !equals(this.sample[property], other[property])) {
if (!this.hasProperty(other, property) || !equals(this.sample[property], other[property], matcherContext.customTesters)) {
result = false
break
}
Expand Down Expand Up @@ -149,11 +150,12 @@ export class ArrayContaining<T = unknown> extends AsymmetricMatcher<Array<T>> {
)
}

const matcherContext = this.getMatcherContext()
const result
= this.sample.length === 0
|| (Array.isArray(other)
&& this.sample.every(item =>
other.some(another => equals(item, another)),
other.some(another => equals(item, another, matcherContext.customTesters)),
))

return this.inverse ? !result : result
Expand Down
15 changes: 9 additions & 6 deletions packages/expect/src/jest-expect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import type { Test } from '@vitest/runner'
import type { Assertion, ChaiPlugin } from './types'
import { arrayBufferEquality, generateToBeMessage, iterableEquality, equals as jestEquals, sparseArrayEquality, subsetEquality, typeEquality } from './jest-utils'
import type { AsymmetricMatcher } from './jest-asymmetric-matchers'
import { diff, stringify } from './jest-matcher-utils'
import { diff, getCustomEqualityTesters, stringify } from './jest-matcher-utils'
import { JEST_MATCHERS_OBJECT } from './constants'
import { recordAsyncExpect, wrapSoft } from './utils'

Expand All @@ -23,6 +23,7 @@ declare class DOMTokenList {
export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
const { AssertionError } = chai
const c = () => getColors()
const customTesters = getCustomEqualityTesters()

function def(name: keyof Assertion | (keyof Assertion)[], fn: ((this: Chai.AssertionStatic & Assertion, ...args: any[]) => any)) {
const addMethod = (n: keyof Assertion) => {
Expand Down Expand Up @@ -80,7 +81,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
const equal = jestEquals(
actual,
expected,
[iterableEquality],
[...customTesters, iterableEquality],
)

return this.assert(
Expand All @@ -98,6 +99,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
obj,
expected,
[
...customTesters,
iterableEquality,
typeEquality,
sparseArrayEquality,
Expand Down Expand Up @@ -125,6 +127,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
actual,
expected,
[
...customTesters,
iterableEquality,
typeEquality,
sparseArrayEquality,
Expand All @@ -140,7 +143,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
const toEqualPass = jestEquals(
actual,
expected,
[iterableEquality],
[...customTesters, iterableEquality],
)

if (toEqualPass)
Expand All @@ -159,7 +162,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
def('toMatchObject', function (expected) {
const actual = this._obj
return this.assert(
jestEquals(actual, expected, [iterableEquality, subsetEquality]),
jestEquals(actual, expected, [...customTesters, iterableEquality, subsetEquality]),
'expected #{this} to match object #{exp}',
'expected #{this} to not match object #{exp}',
expected,
Expand Down Expand Up @@ -208,7 +211,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
def('toContainEqual', function (expected) {
const obj = utils.flag(this, 'object')
const index = Array.from(obj).findIndex((item) => {
return jestEquals(item, expected)
return jestEquals(item, expected, customTesters)
})

this.assert(
Expand Down Expand Up @@ -339,7 +342,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
return utils.getPathInfo(actual, propertyName)
}
const { value, exists } = getValue()
const pass = exists && (args.length === 1 || jestEquals(expected, value))
const pass = exists && (args.length === 1 || jestEquals(expected, value, customTesters))

const valueString = args.length === 1 ? '' : ` with value ${utils.objDisplay(expected)}`

Expand Down
5 changes: 2 additions & 3 deletions packages/expect/src/jest-extend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { ASYMMETRIC_MATCHERS_OBJECT, JEST_MATCHERS_OBJECT } from './constants'
import { AsymmetricMatcher } from './jest-asymmetric-matchers'
import { getState } from './state'

import { diff, getMatcherUtils, stringify } from './jest-matcher-utils'
import { diff, getCustomEqualityTesters, getMatcherUtils, stringify } from './jest-matcher-utils'

import {
equals,
Expand All @@ -33,8 +33,7 @@ function getMatcherState(assertion: Chai.AssertionStatic & Chai.Assertion, expec

const matcherState: MatcherState = {
...getState(expect),
// TODO: implement via expect.addEqualityTesters
customTesters: [],
customTesters: getCustomEqualityTesters(),
isNot,
utils: jestUtils,
promise,
Expand Down
23 changes: 21 additions & 2 deletions packages/expect/src/jest-matcher-utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { getColors, stringify } from '@vitest/utils'
import type { MatcherHintOptions } from './types'
import { getColors, getType, stringify } from '@vitest/utils'
import type { MatcherHintOptions, Tester } from './types'
import { JEST_MATCHERS_OBJECT } from './constants'

export { diff } from '@vitest/utils/diff'
export { stringify }
Expand Down Expand Up @@ -101,3 +102,21 @@ export function getMatcherUtils() {
printExpected,
}
}

export function addCustomEqualityTesters(newTesters: Array<Tester>): void {
if (!Array.isArray(newTesters)) {
throw new TypeError(
`expect.customEqualityTesters: Must be set to an array of Testers. Was given "${getType(
newTesters,
)}"`,
)
}

(globalThis as any)[JEST_MATCHERS_OBJECT].customEqualityTesters.push(
...newTesters,
)
}

export function getCustomEqualityTesters(): Array<Tester> {
return (globalThis as any)[JEST_MATCHERS_OBJECT].customEqualityTesters
}
52 changes: 32 additions & 20 deletions packages/expect/src/jest-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

import { isObject } from '@vitest/utils'
import type { Tester } from './types'
import type { Tester, TesterContext } from './types'

// Extracted out of jasmine 2.5.2
export function equals(
Expand Down Expand Up @@ -87,8 +87,9 @@ function eq(
if (asymmetricResult !== undefined)
return asymmetricResult

const testerContext: TesterContext = { equals }
for (let i = 0; i < customTesters.length; i++) {
const customTesterResult = customTesters[i](a, b)
const customTesterResult = customTesters[i].call(testerContext, a, b, customTesters)
if (customTesterResult !== undefined)
return customTesterResult
}
Expand Down Expand Up @@ -298,7 +299,7 @@ function hasIterator(object: any) {
return !!(object != null && object[IteratorSymbol])
}

export function iterableEquality(a: any, b: any, aStack: Array<any> = [], bStack: Array<any> = []): boolean | undefined {
export function iterableEquality(a: any, b: any, customTesters: Array<Tester> = [], aStack: Array<any> = [], bStack: Array<any> = []): boolean | undefined {
if (
typeof a !== 'object'
|| typeof b !== 'object'
Expand All @@ -324,7 +325,20 @@ export function iterableEquality(a: any, b: any, aStack: Array<any> = [], bStack
aStack.push(a)
bStack.push(b)

const iterableEqualityWithStack = (a: any, b: any) => iterableEquality(a, b, [...aStack], [...bStack])
const filteredCustomTesters: Array<Tester> = [
...customTesters.filter(t => t !== iterableEquality),
iterableEqualityWithStack,
]

function iterableEqualityWithStack(a: any, b: any) {
return iterableEquality(
a,
b,
[...filteredCustomTesters],
[...aStack],
[...bStack],
)
}

if (a.size !== undefined) {
if (a.size !== b.size) {
Expand All @@ -336,7 +350,7 @@ export function iterableEquality(a: any, b: any, aStack: Array<any> = [], bStack
if (!b.has(aValue)) {
let has = false
for (const bValue of b) {
const isEqual = equals(aValue, bValue, [iterableEqualityWithStack])
const isEqual = equals(aValue, bValue, filteredCustomTesters)
if (isEqual === true)
has = true
}
Expand All @@ -357,20 +371,16 @@ export function iterableEquality(a: any, b: any, aStack: Array<any> = [], bStack
for (const aEntry of a) {
if (
!b.has(aEntry[0])
|| !equals(aEntry[1], b.get(aEntry[0]), [iterableEqualityWithStack])
|| !equals(aEntry[1], b.get(aEntry[0]), filteredCustomTesters)
) {
let has = false
for (const bEntry of b) {
const matchedKey = equals(aEntry[0], bEntry[0], [
iterableEqualityWithStack,
])
const matchedKey = equals(aEntry[0], bEntry[0], filteredCustomTesters)

let matchedValue = false
if (matchedKey === true) {
matchedValue = equals(aEntry[1], bEntry[1], [
iterableEqualityWithStack,
])
}
if (matchedKey === true)
matchedValue = equals(aEntry[1], bEntry[1], filteredCustomTesters)

if (matchedValue === true)
has = true
}
Expand All @@ -394,7 +404,7 @@ export function iterableEquality(a: any, b: any, aStack: Array<any> = [], bStack
const nextB = bIterator.next()
if (
nextB.done
|| !equals(aValue, nextB.value, [iterableEqualityWithStack])
|| !equals(aValue, nextB.value, filteredCustomTesters)
)
return false
}
Expand Down Expand Up @@ -430,7 +440,8 @@ function isObjectWithKeys(a: any) {
&& !(a instanceof Date)
}

export function subsetEquality(object: unknown, subset: unknown): boolean | undefined {
export function subsetEquality(object: unknown, subset: unknown, customTesters: Array<Tester> = []): boolean | undefined {
const filteredCustomTesters = customTesters.filter(t => t !== subsetEquality)
// subsetEquality needs to keep track of the references
// it has already visited to avoid infinite loops in case
// there are circular references in the subset passed to it.
Expand All @@ -443,15 +454,15 @@ export function subsetEquality(object: unknown, subset: unknown): boolean | unde
return Object.keys(subset).every((key) => {
if (isObjectWithKeys(subset[key])) {
if (seenReferences.has(subset[key]))
return equals(object[key], subset[key], [iterableEquality])
return equals(object[key], subset[key], filteredCustomTesters)

seenReferences.set(subset[key], true)
}
const result
= object != null
&& hasPropertyInObject(object, key)
&& equals(object[key], subset[key], [
iterableEquality,
...filteredCustomTesters,
subsetEqualityWithContext(seenReferences),
])
// The main goal of using seenReference is to avoid circular node on tree.
Expand Down Expand Up @@ -504,15 +515,16 @@ export function arrayBufferEquality(a: unknown, b: unknown): boolean | undefined
return true
}

export function sparseArrayEquality(a: unknown, b: unknown): boolean | undefined {
export function sparseArrayEquality(a: unknown, b: unknown, customTesters: Array<Tester> = []): boolean | undefined {
if (!Array.isArray(a) || !Array.isArray(b))
return undefined

// A sparse array [, , 1] will have keys ["2"] whereas [undefined, undefined, 1] will have keys ["0", "1", "2"]
const aKeys = Object.keys(a)
const bKeys = Object.keys(b)
const filteredCustomTesters = customTesters.filter(t => t !== sparseArrayEquality)
return (
equals(a, b, [iterableEquality, typeEquality], true) && equals(aKeys, bKeys)
equals(a, b, filteredCustomTesters, true) && equals(aKeys, bKeys)
)
}

Expand Down

0 comments on commit a0f502e

Please sign in to comment.