diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 3cb4e8a3..2d116344 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -262,6 +262,38 @@ describe("SSEClientTransport", () => { expect(lastServerRequest.headers.authorization).toBe(authToken); }); + it("uses custom fetch implementation from options", async () => { + const authToken = "Bearer custom-token"; + + const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set("Authorization", authToken); + return fetch(url.toString(), { ...init, headers }); + }); + + transport = new SSEClientTransport(resourceBaseUrl, { + fetch: fetchWithAuth, + }); + + await transport.start(); + + expect(lastServerRequest.headers.authorization).toBe(authToken); + + // Send a message to verify fetchWithAuth used for POST as well + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + expect(fetchWithAuth).toHaveBeenCalledTimes(2); + expect(lastServerRequest.method).toBe("POST"); + expect(lastServerRequest.headers.authorization).toBe(authToken); + }); + it("passes custom headers to fetch requests", async () => { const customHeaders = { Authorization: "Bearer test-token", diff --git a/src/client/sse.ts b/src/client/sse.ts index 2546d508..faffecc4 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -1,5 +1,5 @@ import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource"; -import { Transport } from "../shared/transport.js"; +import { Transport, FetchLike } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; @@ -47,6 +47,11 @@ export type SSEClientTransportOptions = { * Customizes recurring POST requests to the server. */ requestInit?: RequestInit; + + /** + * Custom fetch implementation used for all network requests. + */ + fetch?: FetchLike; }; /** @@ -62,6 +67,7 @@ export class SSEClientTransport implements Transport { private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; + private _fetch?: FetchLike; private _protocolVersion?: string; onclose?: () => void; @@ -77,6 +83,7 @@ export class SSEClientTransport implements Transport { this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; + this._fetch = opts?.fetch; } private async _authThenStart(): Promise { @@ -117,7 +124,7 @@ export class SSEClientTransport implements Transport { } private _startOrAuth(): Promise { - const fetchImpl = (this?._eventSourceInit?.fetch || fetch) as typeof fetch +const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch return new Promise((resolve, reject) => { this._eventSource = new EventSource( this._url.href, @@ -242,7 +249,7 @@ export class SSEClientTransport implements Transport { signal: this._abortController?.signal, }; - const response = await fetch(this._endpoint, init); +const response = await (this._fetch ?? fetch)(this._endpoint, init); if (!response.ok) { if (response.status === 401 && this._authProvider) { diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 11dfe7d4..dcd76528 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,4 +1,4 @@ -import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js"; +import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions, StartSSEOptions } from "./streamableHttp.js"; import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { JSONRPCMessage } from "../types.js"; @@ -443,6 +443,35 @@ describe("StreamableHTTPClientTransport", () => { expect(errorSpy).toHaveBeenCalled(); }); + it("uses custom fetch implementation", async () => { + const authToken = "Bearer custom-token"; + + const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set("Authorization", authToken); + return (global.fetch as jest.Mock)(url, { ...init, headers }); + }); + + (global.fetch as jest.Mock) + .mockResolvedValueOnce( + new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } }) + ) + .mockResolvedValueOnce(new Response(null, { status: 202 })); + + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { fetch: fetchWithAuth }); + + await transport.start(); + await (transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise })._startOrAuthSse({}); + + await transport.send({ jsonrpc: "2.0", method: "test", params: {}, id: "1" } as JSONRPCMessage); + + expect(fetchWithAuth).toHaveBeenCalled(); + for (const call of (global.fetch as jest.Mock).mock.calls) { + const headers = call[1].headers as Headers; + expect(headers.get("Authorization")).toBe(authToken); + } + }); + it("should always send specified custom headers", async () => { const requestInit = { @@ -530,7 +559,7 @@ describe("StreamableHTTPClientTransport", () => { // Second retry - should double (2^1 * 100 = 200) expect(getDelay(1)).toBe(200); - // Third retry - should double again (2^2 * 100 = 400) + // Third retry - should double again (2^2 * 100 = 400) expect(getDelay(2)).toBe(400); // Fourth retry - should double again (2^3 * 100 = 800) diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 73078442..b81f1a5d 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -1,4 +1,4 @@ -import { Transport } from "../shared/transport.js"; +import { Transport, FetchLike } from "../shared/transport.js"; import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { EventSourceParserStream } from "eventsource-parser/stream"; @@ -23,7 +23,7 @@ export class StreamableHTTPError extends Error { /** * Options for starting or authenticating an SSE connection */ -interface StartSSEOptions { +export interface StartSSEOptions { /** * The resumption token used to continue long-running requests that were interrupted. * @@ -99,6 +99,11 @@ export type StreamableHTTPClientTransportOptions = { */ requestInit?: RequestInit; + /** + * Custom fetch implementation used for all network requests. + */ + fetch?: FetchLike; + /** * Options to configure the reconnection behavior. */ @@ -122,6 +127,7 @@ export class StreamableHTTPClientTransport implements Transport { private _resourceMetadataUrl?: URL; private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; + private _fetch?: FetchLike; private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; private _protocolVersion?: string; @@ -138,6 +144,7 @@ export class StreamableHTTPClientTransport implements Transport { this._resourceMetadataUrl = undefined; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; + this._fetch = opts?.fetch; this._sessionId = opts?.sessionId; this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; } @@ -200,7 +207,7 @@ export class StreamableHTTPClientTransport implements Transport { headers.set("last-event-id", resumptionToken); } - const response = await fetch(this._url, { +const response = await (this._fetch ?? fetch)(this._url, { method: "GET", headers, signal: this._abortController?.signal, @@ -251,15 +258,15 @@ export class StreamableHTTPClientTransport implements Transport { private _normalizeHeaders(headers: HeadersInit | undefined): Record { if (!headers) return {}; - + if (headers instanceof Headers) { return Object.fromEntries(headers.entries()); } - + if (Array.isArray(headers)) { return Object.fromEntries(headers); } - + return { ...headers as Record }; } @@ -414,7 +421,7 @@ export class StreamableHTTPClientTransport implements Transport { signal: this._abortController?.signal, }; - const response = await fetch(this._url, init); +const response = await (this._fetch ?? fetch)(this._url, init); // Handle session ID received during initialization const sessionId = response.headers.get("mcp-session-id"); @@ -520,7 +527,7 @@ export class StreamableHTTPClientTransport implements Transport { signal: this._abortController?.signal, }; - const response = await fetch(this._url, init); +const response = await (this._fetch ?? fetch)(this._url, init); // We specifically handle 405 as a valid response according to the spec, // meaning the server does not support explicit session termination diff --git a/src/shared/transport.ts b/src/shared/transport.ts index 96b291fa..386b6bae 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -1,10 +1,12 @@ import { JSONRPCMessage, MessageExtraInfo, RequestId } from "../types.js"; +export type FetchLike = (url: string | URL, init?: RequestInit) => Promise; + /** * Options for sending a JSON-RPC message. */ export type TransportSendOptions = { - /** + /** * If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with. */ relatedRequestId?: RequestId; @@ -38,7 +40,7 @@ export interface Transport { /** * Sends a JSON-RPC message (request or response). - * + * * If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with. */ send(message: JSONRPCMessage, options?: TransportSendOptions): Promise; @@ -64,9 +66,9 @@ export interface Transport { /** * Callback for when a message (request or response) is received over the connection. - * + * * Includes the requestInfo and authInfo if the transport is authenticated. - * + * * The requestInfo can be used to get the original request information (headers, etc.) */ onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;