Skip to content

Allow WWW-Authenticate header to be accessed by client #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"@modelcontextprotocol/inspector-cli": "^0.15.0",
"@modelcontextprotocol/inspector-client": "^0.15.0",
"@modelcontextprotocol/inspector-server": "^0.15.0",
"@modelcontextprotocol/sdk": "^1.13.1",
"@modelcontextprotocol/sdk": "^1.15.0",
"concurrently": "^9.0.1",
"open": "^10.1.0",
"shell-quote": "^1.8.2",
Expand Down
150 changes: 101 additions & 49 deletions server/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,30 @@ const getHttpHeaders = (
const app = express();
app.use(cors());
app.use((req, res, next) => {
res.header("Access-Control-Expose-Headers", "mcp-session-id");
res.header("Access-Control-Expose-Headers", [
"mcp-session-id",
"WWW-Authenticate",
]);
next();
});

const maybeSetAuthHeader = (res: express.Response, header?: string) => {
if (header) {
res.setHeader("WWW-Authenticate", header);
}
};

const setAuthHeaderFromError = (res: express.Response, error: unknown) => {
if (
error &&
typeof error === "object" &&
"authHeader" in error &&
typeof error.authHeader === "string"
) {
maybeSetAuthHeader(res, error.authHeader);
}
};

const webAppTransports: Map<string, Transport> = new Map<string, Transport>(); // Web app transports by web app sessionId
const serverTransports: Map<string, Transport> = new Map<string, Transport>(); // Server Transports by web app sessionId

Expand Down Expand Up @@ -171,66 +191,89 @@ const authMiddleware = (
next();
};

const createTransport = async (req: express.Request): Promise<Transport> => {
const createTransport = async (
req: express.Request,
): Promise<{ transport: Transport; authHeader?: string }> => {
const query = req.query;
console.log("Query parameters:", JSON.stringify(query));

const originalFetch = globalThis.fetch;
let authHeader: string | undefined;

const interceptingFetch = async (
...args: Parameters<typeof fetch>
): Promise<Response> => {
const response = await originalFetch(...args);
if (response.status === 401 && response.headers.has("WWW-Authenticate")) {
authHeader = response.headers.get("WWW-Authenticate") ?? undefined;
}
return response;
};

const transportType = query.transportType as string;

if (transportType === "stdio") {
const command = query.command as string;
const origArgs = shellParseArgs(query.args as string) as string[];
const queryEnv = query.env ? JSON.parse(query.env as string) : {};
const env = { ...process.env, ...defaultEnvironment, ...queryEnv };
try {
if (transportType === "stdio") {
const command = query.command as string;
const origArgs = shellParseArgs(query.args as string) as string[];
const queryEnv = query.env ? JSON.parse(query.env as string) : {};
const env = { ...process.env, ...defaultEnvironment, ...queryEnv };

const { cmd, args } = findActualExecutable(command, origArgs);
const { cmd, args } = findActualExecutable(command, origArgs);

console.log(`STDIO transport: command=${cmd}, args=${args}`);
console.log(`STDIO transport: command=${cmd}, args=${args}`);

const transport = new StdioClientTransport({
command: cmd,
args,
env,
stderr: "pipe",
});
const transport = new StdioClientTransport({
command: cmd,
args,
env,
stderr: "pipe",
});

await transport.start();
return transport;
} else if (transportType === "sse") {
const url = query.url as string;
await transport.start();
return { transport, authHeader };
} else if (transportType === "sse") {
const url = query.url as string;

const headers = getHttpHeaders(req, transportType);
const headers = getHttpHeaders(req, transportType);

console.log(
`SSE transport: url=${url}, headers=${JSON.stringify(headers)}`,
);
console.log(
`SSE transport: url=${url}, headers=${JSON.stringify(headers)}`,
);

const transport = new SSEClientTransport(new URL(url), {
eventSourceInit: {
fetch: (url, init) => fetch(url, { ...init, headers }),
},
requestInit: {
headers,
},
});
await transport.start();
return transport;
} else if (transportType === "streamable-http") {
const headers = getHttpHeaders(req, transportType);

const transport = new StreamableHTTPClientTransport(
new URL(query.url as string),
{
const transport = new SSEClientTransport(new URL(url), {
requestInit: {
headers,
},
},
);
await transport.start();
return transport;
} else {
console.error(`Invalid transport type: ${transportType}`);
throw new Error("Invalid transport type specified");
fetch: (url, init) => interceptingFetch(url, { ...init, headers }),
});
await transport.start();
return { transport, authHeader };
} else if (transportType === "streamable-http") {
const headers = getHttpHeaders(req, transportType);

const transport = new StreamableHTTPClientTransport(
new URL(query.url as string),
{
requestInit: {
headers,
},
fetch: (url, init) => interceptingFetch(url, { ...init, headers }),
},
);
await transport.start();
return { transport, authHeader };
} else {
console.error(`Invalid transport type: ${transportType}`);
throw new Error("Invalid transport type specified");
}
} catch (error) {
if (error && typeof error === "object") {
(error as { authHeader?: string }).authHeader = authHeader;
}
throw error;
} finally {
// nothing to clean up
}
};

Expand Down Expand Up @@ -269,8 +312,11 @@ app.post(
try {
console.log("New StreamableHttp connection request");
try {
serverTransport = await createTransport(req);
const { transport, authHeader } = await createTransport(req);
serverTransport = transport;
maybeSetAuthHeader(res, authHeader);
} catch (error) {
setAuthHeaderFromError(res, error);
if (error instanceof SseError && error.code === 401) {
console.error(
"Received 401 Unauthorized from MCP server:",
Expand Down Expand Up @@ -374,9 +420,12 @@ app.get(
console.log("New STDIO connection request");
let serverTransport: Transport | undefined;
try {
serverTransport = await createTransport(req);
const { transport, authHeader } = await createTransport(req);
serverTransport = transport;
console.log("Created server transport");
maybeSetAuthHeader(res, authHeader);
} catch (error) {
setAuthHeaderFromError(res, error);
if (error instanceof SseError && error.code === 401) {
console.error(
"Received 401 Unauthorized from MCP server. Authentication failure.",
Expand Down Expand Up @@ -443,8 +492,11 @@ app.get(
);
let serverTransport: Transport | undefined;
try {
serverTransport = await createTransport(req);
const { transport, authHeader } = await createTransport(req);
serverTransport = transport;
maybeSetAuthHeader(res, authHeader);
} catch (error) {
setAuthHeaderFromError(res, error);
if (error instanceof SseError && error.code === 401) {
console.error(
"Received 401 Unauthorized from MCP server. Authentication failure.",
Expand Down