diff --git a/src/core/handlers/RequestHandler.ts b/src/core/handlers/RequestHandler.ts index 5c0c3ffa7..f9d34f384 100644 --- a/src/core/handlers/RequestHandler.ts +++ b/src/core/handlers/RequestHandler.ts @@ -1,6 +1,9 @@ -import { invariant } from 'outvariant' import { getCallFrame } from '../utils/internal/getCallFrame' -import { isIterable } from '../utils/internal/isIterable' +import { + AsyncIterable, + Iterable, + isIterable, +} from '../utils/internal/isIterable' import type { ResponseResolutionContext } from '../utils/executeHandlers' import type { MaybePromise } from '../typeUtils' import { StrictRequest, StrictResponse } from '..//HttpResponse' @@ -52,7 +55,12 @@ export type AsyncResponseResolverReturnType< ResponseBodyType extends DefaultBodyType, > = MaybePromise< | ResponseResolverReturnType - | Generator< + | Iterable< + MaybeAsyncResponseResolverReturnType, + MaybeAsyncResponseResolverReturnType, + MaybeAsyncResponseResolverReturnType + > + | AsyncIterable< MaybeAsyncResponseResolverReturnType, MaybeAsyncResponseResolverReturnType, MaybeAsyncResponseResolverReturnType @@ -117,12 +125,18 @@ export abstract class RequestHandler< public isUsed: boolean protected resolver: ResponseResolver - private resolverGenerator?: Generator< - MaybeAsyncResponseResolverReturnType, - MaybeAsyncResponseResolverReturnType, - MaybeAsyncResponseResolverReturnType - > - private resolverGeneratorResult?: Response | StrictResponse + private resolverIterator?: + | Iterator< + MaybeAsyncResponseResolverReturnType, + MaybeAsyncResponseResolverReturnType, + MaybeAsyncResponseResolverReturnType + > + | AsyncIterator< + MaybeAsyncResponseResolverReturnType, + MaybeAsyncResponseResolverReturnType, + MaybeAsyncResponseResolverReturnType + > + private resolverIteratorResult?: Response | StrictResponse private options?: HandlerOptions constructor(args: RequestHandlerArgs) { @@ -256,6 +270,9 @@ export abstract class RequestHandler< return null } + // Preemptively mark the handler as used. + // Generators will undo this because only when the resolver reaches the + // "done" state of the generator that it considers the handler used. this.isUsed = true // Create a response extraction wrapper around the resolver @@ -301,48 +318,38 @@ export abstract class RequestHandler< resolver: ResponseResolver, ): ResponseResolver { return async (info): Promise> => { - const result = this.resolverGenerator || (await resolver(info)) - - if (isIterable>(result)) { - // Immediately mark this handler as unused. - // Only when the generator is done, the handler will be - // considered used. - this.isUsed = false - - const { value, done } = result[Symbol.iterator]().next() - const nextResponse = await value - - if (done) { - this.isUsed = true + if (!this.resolverIterator) { + const result = await resolver(info) + if (!isIterable(result)) { + return result } + this.resolverIterator = + Symbol.iterator in result + ? result[Symbol.iterator]() + : result[Symbol.asyncIterator]() + } - // If the generator is done and there is no next value, - // return the previous generator's value. - if (!nextResponse && done) { - invariant( - this.resolverGeneratorResult, - 'Failed to returned a previously stored generator response: the value is not a valid Response.', - ) - - // Clone the previously stored response from the generator - // so that it could be read again. - return this.resolverGeneratorResult.clone() as StrictResponse - } + // Opt-out from marking this handler as used. + this.isUsed = false - if (!this.resolverGenerator) { - this.resolverGenerator = result - } + const { done, value } = await this.resolverIterator.next() + const nextResponse = await value - if (nextResponse) { - // Also clone the response before storing it - // so it could be read again. - this.resolverGeneratorResult = nextResponse?.clone() - } + if (nextResponse) { + this.resolverIteratorResult = nextResponse.clone() + } + + if (done) { + // A one-time generator resolver stops affecting the network + // only after it's been completely exhausted. + this.isUsed = true - return nextResponse + // Clone the previously stored response so it can be read + // when receiving it repeatedly from the "done" generator. + return this.resolverIteratorResult?.clone() } - return result + return nextResponse } } diff --git a/src/core/utils/internal/isIterable.ts b/src/core/utils/internal/isIterable.ts index c3ef63783..670f2b649 100644 --- a/src/core/utils/internal/isIterable.ts +++ b/src/core/utils/internal/isIterable.ts @@ -1,12 +1,32 @@ +/** + * This is the same as TypeScript's `Iterable`, but with all three type parameters. + * @todo Remove once TypeScript 5.6 is the minimum. + */ +export interface Iterable { + [Symbol.iterator](): Iterator +} + +/** + * This is the same as TypeScript's `AsyncIterable`, but with all three type parameters. + * @todo Remove once TypeScript 5.6 is the minimum. + */ +export interface AsyncIterable { + [Symbol.asyncIterator](): AsyncIterator +} + /** * Determines if the given function is an iterator. */ export function isIterable( fn: any, -): fn is Generator { +): fn is + | Iterable + | AsyncIterable { if (!fn) { return false } - return typeof (fn as Generator)[Symbol.iterator] == 'function' + return ( + Reflect.has(fn, Symbol.iterator) || Reflect.has(fn, Symbol.asyncIterator) + ) } diff --git a/test/node/rest-api/response/generator.test.ts b/test/node/rest-api/response/generator.test.ts new file mode 100644 index 000000000..27b3ec61e --- /dev/null +++ b/test/node/rest-api/response/generator.test.ts @@ -0,0 +1,109 @@ +/** + * @vitest-environment node + */ +import { http, HttpResponse, delay } from 'msw' +import { setupServer } from 'msw/node' + +const server = setupServer() + +async function fetchJson(input: string | URL | Request, init?: RequestInit) { + return fetch(input, init).then((response) => response.json()) +} + +beforeAll(() => { + server.listen() +}) + +afterEach(() => { + server.resetHandlers() +}) + +afterAll(() => { + server.close() +}) + +it('supports generator function as response resolver', async () => { + server.use( + http.get('https://example.com/weather', function* () { + let degree = 10 + + while (degree < 13) { + degree++ + yield HttpResponse.json(degree) + } + + degree++ + return HttpResponse.json(degree) + }), + ) + + // Must respond with yielded responses. + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(11) + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(12) + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(13) + // Must respond with the final "done" response. + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(14) + // Must keep responding with the final "done" response. + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(14) +}) + +it('supports async generator function as response resolver', async () => { + server.use( + http.get('https://example.com/weather', async function* () { + await delay(20) + + let degree = 10 + + while (degree < 13) { + degree++ + yield HttpResponse.json(degree) + } + + degree++ + return HttpResponse.json(degree) + }), + ) + + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(11) + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(12) + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(13) + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(14) + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(14) +}) + +it('supports generator function as one-time response resolver', async () => { + server.use( + http.get( + 'https://example.com/weather', + function* () { + let degree = 10 + + while (degree < 13) { + degree++ + yield HttpResponse.json(degree) + } + + degree++ + return HttpResponse.json(degree) + }, + { once: true }, + ), + http.get('*', () => { + return HttpResponse.json('fallback') + }), + ) + + // Must respond with the yielded incrementing responses. + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(11) + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(12) + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(13) + // Must respond with the "done" final response from the iterator. + await expect(fetchJson('https://example.com/weather')).resolves.toEqual(14) + // Must respond with the other handler since the generator one is used. + await expect(fetchJson('https://example.com/weather')).resolves.toEqual( + 'fallback', + ) + await expect(fetchJson('https://example.com/weather')).resolves.toEqual( + 'fallback', + ) +}) diff --git a/test/typings/http.test-d.ts b/test/typings/http.test-d.ts index f782db09f..f40397c1a 100644 --- a/test/typings/http.test-d.ts +++ b/test/typings/http.test-d.ts @@ -195,3 +195,21 @@ it('infers a narrower json response type', () => { return HttpResponse.json({ a: 1, b: 2 }) }) }) + +it('errors when returning non-Response data from resolver', () => { + http.get( + '/resource', + // @ts-expect-error + () => 123, + ) + http.get( + '/resource', + // @ts-expect-error + () => 'foo', + ) + http.get( + '/resource', + // @ts-expect-error + () => ({}), + ) +}) diff --git a/test/typings/resolver-generator.test-d.ts b/test/typings/resolver-generator.test-d.ts new file mode 100644 index 000000000..f69478ef1 --- /dev/null +++ b/test/typings/resolver-generator.test-d.ts @@ -0,0 +1,50 @@ +import { it } from 'vitest' +import { http, HttpResponse } from 'msw' + +it('supports generator function as response resolver', () => { + http.get('/', function* () { + yield HttpResponse.json({ value: 1 }) + yield HttpResponse.json({ value: 2 }) + return HttpResponse.json({ value: 3 }) + }) + + http.get('/', function* () { + yield HttpResponse.json({ value: 'one' }) + yield HttpResponse.json({ + // @ts-expect-error Expected string, got number. + value: 2, + }) + return HttpResponse.json({ value: 'three' }) + }) +}) + +it('supports async generator function as response resolver', () => { + http.get('/', async function* () { + yield HttpResponse.json({ value: 1 }) + yield HttpResponse.json({ value: 2 }) + return HttpResponse.json({ value: 3 }) + }) + + http.get('/', async function* () { + yield HttpResponse.json({ value: 'one' }) + yield HttpResponse.json({ + // @ts-expect-error Expected string, got number. + value: 2, + }) + return HttpResponse.json({ value: 'three' }) + }) +}) + +it('supports returning nothing from generator resolvers', () => { + http.get('/', function* () {}) + http.get('/', async function* () {}) +}) + +it('supports returning undefined from generator resolvers', () => { + http.get('/', function* () { + return undefined + }) + http.get('/', async function* () { + return undefined + }) +}) diff --git a/test/typings/vitest.config.mts b/test/typings/vitest.config.mts index a97c9070e..4f879ab8e 100644 --- a/test/typings/vitest.config.mts +++ b/test/typings/vitest.config.mts @@ -32,6 +32,9 @@ export default defineConfig({ const tsConfigPath = tsConfigPaths.find((path) => fs.existsSync(path), ) as string + + console.log('Using tsconfig at: %s', tsConfigPath) + return tsConfigPath })(), },