Skip to content

Commit 9d8a543

Browse files
committed
refactor(vitest): simplify poll assertion control flow
1 parent 22eff98 commit 9d8a543

2 files changed

Lines changed: 75 additions & 46 deletions

File tree

packages/utils/src/timers.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,21 @@ export function setSafeTimers(): void {
7272

7373
(globalThis as any)[SAFE_TIMERS_SYMBOL] = timers
7474
}
75+
76+
/**
77+
* Returns a promise that resolves after the specified duration.
78+
*
79+
* @param timeout - Delay in milliseconds
80+
* @param scheduler - Timer function to use, defaults to `setTimeout`. Useful for mocked timers.
81+
*
82+
* @example
83+
* await delay(100)
84+
*
85+
* @example
86+
* // With mocked timers
87+
* const { setTimeout } = getSafeTimers()
88+
* await delay(100, setTimeout)
89+
*/
90+
export function delay(timeout: number, scheduler: typeof setTimeout = setTimeout): Promise<void> {
91+
return new Promise(resolve => scheduler(resolve, timeout))
92+
}

packages/vitest/src/integrations/chai/poll.ts

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import type { Assertion, ExpectStatic } from '@vitest/expect'
22
import type { Test } from '@vitest/runner'
33
import { chai } from '@vitest/expect'
4-
import { getSafeTimers } from '@vitest/utils/timers'
4+
import { delay, getSafeTimers } from '@vitest/utils/timers'
55
import { getWorkerState } from '../../runtime/utils'
66

77
// these matchers are not supported because they don't make sense with poll
@@ -26,6 +26,25 @@ const unsupported = [
2626
// resolves
2727
]
2828

29+
/**
30+
* Attaches a `cause` property to the error if missing, copies the stack trace from the source, and throws.
31+
*
32+
* @param error - The error to throw
33+
* @param source - Error to copy the stack trace from
34+
*
35+
* @throws Always throws the provided error with an amended stack trace
36+
*/
37+
function throwWithCause(error: any, source: Error) {
38+
if (error.cause == null) {
39+
error.cause = new Error('Matcher did not succeed in time.')
40+
}
41+
42+
throw copyStackTrace(
43+
error,
44+
source,
45+
)
46+
}
47+
2948
export function createExpectPoll(expect: ExpectStatic): ExpectStatic['poll'] {
3049
return function poll(fn, options = {}) {
3150
const state = getWorkerState()
@@ -64,60 +83,52 @@ export function createExpectPoll(expect: ExpectStatic): ExpectStatic['poll'] {
6483

6584
return function (this: any, ...args: any[]) {
6685
const STACK_TRACE_ERROR = new Error('STACK_TRACE_ERROR')
67-
const promise = () => new Promise<void>((resolve, reject) => {
68-
let intervalId: any
69-
let timeoutId: any
70-
let lastError: any
86+
const promise = async () => {
7187
const { setTimeout, clearTimeout } = getSafeTimers()
72-
const rejectWithCause = (error: any) => {
73-
if (error.cause == null) {
74-
error.cause = new Error('Matcher did not succeed in time.')
75-
}
76-
reject(
77-
copyStackTrace(
78-
error,
79-
STACK_TRACE_ERROR,
80-
),
81-
)
82-
}
83-
const check = async () => {
84-
try {
85-
chai.util.flag(assertion, '_name', key)
86-
const obj = await fn()
87-
chai.util.flag(assertion, 'object', obj)
88+
89+
let executionPhase: 'fn' | 'assertion' = 'fn'
90+
let hasTimedOut = false
91+
let lastError: any
92+
93+
const timerId = setTimeout(() => {
94+
hasTimedOut = true
95+
}, timeout)
96+
97+
chai.util.flag(assertion, '_name', key)
98+
99+
try {
100+
while (true) {
101+
const isLastAttempt = hasTimedOut
102+
103+
if (isLastAttempt) {
104+
chai.util.flag(assertion, '_isLastPollAttempt', true)
105+
}
106+
88107
try {
89-
resolve(await assertionFunction.call(assertion, ...args))
108+
executionPhase = 'fn'
109+
const obj = await fn()
110+
chai.util.flag(assertion, 'object', obj)
111+
112+
executionPhase = 'assertion'
113+
const output = await assertionFunction.call(assertion, ...args)
114+
115+
return output
90116
}
91117
catch (err) {
92-
if (chai.util.flag(assertion, '_poll.assert_once')) {
93-
clearTimeout(intervalId)
94-
clearTimeout(timeoutId)
118+
lastError = err
95119

96-
rejectWithCause(err)
97-
}
98-
else {
99-
throw err
120+
if (isLastAttempt || (executionPhase === 'assertion' && chai.util.flag(assertion, '_poll.assert_once'))) {
121+
throwWithCause(lastError, STACK_TRACE_ERROR)
100122
}
101-
}
102-
clearTimeout(intervalId)
103-
clearTimeout(timeoutId)
104-
}
105-
catch (err) {
106-
lastError = err
107-
if (!chai.util.flag(assertion, '_isLastPollAttempt')) {
108-
intervalId = setTimeout(check, interval)
123+
124+
await delay(interval, setTimeout)
109125
}
110126
}
111127
}
112-
timeoutId = setTimeout(() => {
113-
clearTimeout(intervalId)
114-
chai.util.flag(assertion, '_isLastPollAttempt', true)
115-
check()
116-
.then(() => rejectWithCause(lastError))
117-
.catch(e => rejectWithCause(e))
118-
}, timeout)
119-
check()
120-
})
128+
finally {
129+
clearTimeout(timerId)
130+
}
131+
}
121132
let awaited = false
122133
test.onFinished ??= []
123134
test.onFinished.push(() => {

0 commit comments

Comments
 (0)