Skip to content

Commit a0cfc84

Browse files
committed
test: fix environment variable tests to verify actual values
The tests in test_translate_stdio_endpoint.py were not properly verifying environment variables. They sent variable names via stdin instead of command line arguments, and never read subprocess output to verify the values were actually set. Changes: - Add _read_output() helper to read subprocess output via pubsub - Update test script to accept env var names as command line arguments - Rewrite tests to read and verify actual environment variable values - Remove unnecessary os.environ.copy() in test_multiple_env_vars All tests now properly validate that environment variables are correctly passed to and set in the subprocess. Signed-off-by: Jonathan Springer <[email protected]> Signed-off-by: Jonathan Springer <[email protected]> Clean up async processing and RuntimeWarnings Signed-off-by: Jonathan Springer <[email protected]> Let's use jq where we can get away with it in tests Signed-off-by: Jonathan Springer <[email protected]> Fix a flake8 documentation find. Signed-off-by: Jonathan Springer <[email protected]> Pylint fixes Signed-off-by: Jonathan Springer <[email protected]>
1 parent fa92384 commit a0cfc84

File tree

3 files changed

+173
-164
lines changed

3 files changed

+173
-164
lines changed

mcpgateway/translate.py

Lines changed: 39 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
23
'''Location: ./mcpgateway/translate.py
34
Copyright 2025
45
SPDX-License-Identifier: Apache-2.0
@@ -123,7 +124,7 @@
123124
import shlex
124125
import signal
125126
import sys
126-
from typing import Any, AsyncIterator, cast, Dict, List, Optional, Sequence, Tuple
127+
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
127128
from urllib.parse import urlencode
128129
import uuid
129130

@@ -146,6 +147,7 @@
146147
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
147148
from starlette.applications import Starlette
148149
from starlette.routing import Route
150+
from starlette.types import Receive, Scope, Send
149151

150152
# First-Party
151153
from mcpgateway.services.logging_service import LoggingService
@@ -380,6 +382,10 @@ async def start(self, additional_env_vars: Optional[Dict[str, str]] = None) -> N
380382
>>> asyncio.run(test_start()) # doctest: +SKIP
381383
True
382384
"""
385+
# Stop existing subprocess before starting a new one
386+
if self._proc is not None:
387+
await self.stop()
388+
383389
LOGGER.info(f"Starting stdio subprocess: {self._cmd}")
384390

385391
# Build environment from base + configured + additional
@@ -388,12 +394,18 @@ async def start(self, additional_env_vars: Optional[Dict[str, str]] = None) -> N
388394
if additional_env_vars:
389395
env.update(additional_env_vars)
390396

397+
# System-critical environment variables that must never be cleared
398+
system_critical_vars = {"PATH", "HOME", "TMPDIR", "TEMP", "TMP", "USER", "LOGNAME", "SHELL", "LANG", "LC_ALL", "LC_CTYPE", "PYTHONHOME", "PYTHONPATH"}
399+
391400
# Clear any mapped env vars that weren't provided in headers to avoid inheritance
392401
if self._header_mappings:
393402
for env_var_name in self._header_mappings.values():
394-
if env_var_name not in (additional_env_vars or {}):
395-
env[env_var_name] = ""
403+
if env_var_name not in (additional_env_vars or {}) and env_var_name not in system_critical_vars:
404+
# Delete the variable instead of setting to empty string to avoid
405+
# breaking subprocess initialization
406+
env.pop(env_var_name, None)
396407

408+
LOGGER.debug(f"Subprocess environment variables: {list(env.keys())}")
397409
self._proc = await asyncio.create_subprocess_exec(
398410
*shlex.split(self._cmd),
399411
stdin=asyncio.subprocess.PIPE,
@@ -406,6 +418,8 @@ async def start(self, additional_env_vars: Optional[Dict[str, str]] = None) -> N
406418
if not self._proc.stdin or not self._proc.stdout:
407419
raise RuntimeError(f"Failed to create subprocess with stdin/stdout pipes for command: {self._cmd}")
408420

421+
LOGGER.debug("Subprocess started successfully")
422+
409423
self._stdin = self._proc.stdin
410424
self._pump_task = asyncio.create_task(self._pump_stdout())
411425

@@ -677,7 +691,7 @@ def _build_fastapi(
677691
# Add CORS middleware if origins specified
678692
if cors_origins:
679693
app.add_middleware(
680-
cast("type", CORSMiddleware),
694+
CORSMiddleware,
681695
allow_origins=cors_origins,
682696
allow_credentials=True,
683697
allow_methods=["*"],
@@ -1073,7 +1087,7 @@ async def _run_stdio_to_sse(
10731087
log_level=log_level,
10741088
lifespan="off",
10751089
)
1076-
server = uvicorn.Server(config)
1090+
uvicorn_server = uvicorn.Server(config)
10771091

10781092
shutting_down = asyncio.Event() # 🔄 make shutdown idempotent
10791093

@@ -1103,24 +1117,15 @@ async def _shutdown() -> None:
11031117
await stdio.stop()
11041118
# Graceful shutdown by setting the shutdown event
11051119
# Use getattr to safely access should_exit attribute
1106-
setattr(server, "should_exit", getattr(server, "should_exit", False) or True)
1120+
setattr(uvicorn_server, "should_exit", getattr(uvicorn_server, "should_exit", False) or True)
11071121

11081122
loop = asyncio.get_running_loop()
11091123
for sig in (signal.SIGINT, signal.SIGTERM):
11101124
with suppress(NotImplementedError): # Windows lacks add_signal_handler
1111-
1112-
def shutdown_handler(*args): # pylint: disable=unused-argument
1113-
"""Handle shutdown signal by creating shutdown task.
1114-
1115-
Args:
1116-
*args: Signal handler arguments (unused).
1117-
"""
1118-
asyncio.create_task(_shutdown())
1119-
1120-
loop.add_signal_handler(sig, shutdown_handler)
1125+
loop.add_signal_handler(sig, lambda *_: asyncio.create_task(_shutdown()))
11211126

11221127
LOGGER.info(f"Bridge ready → http://{host}:{port}{sse_path}")
1123-
await server.serve()
1128+
await uvicorn_server.serve()
11241129
await _shutdown() # final cleanup
11251130

11261131

@@ -1377,7 +1382,7 @@ async def _run_stdio_to_streamable_http(
13771382
LOGGER.info(f"Starting stdio to streamable HTTP bridge for command: {cmd}")
13781383

13791384
# Create a simple MCP server that will proxy to stdio subprocess
1380-
server = MCPServer(name="stdio-proxy")
1385+
mcp_server = MCPServer(name="stdio-proxy")
13811386

13821387
# Create subprocess for stdio communication
13831388
process = await asyncio.create_subprocess_exec(
@@ -1392,13 +1397,13 @@ async def _run_stdio_to_streamable_http(
13921397

13931398
# Set up the streamable HTTP session manager with the server
13941399
session_manager = StreamableHTTPSessionManager(
1395-
app=server,
1400+
app=mcp_server,
13961401
stateless=stateless,
13971402
json_response=json_response,
13981403
)
13991404

14001405
# Create Starlette app to host the streamable HTTP endpoint
1401-
async def handle_mcp(request) -> None:
1406+
async def handle_mcp(request: Request) -> None:
14021407
"""Handle MCP requests via streamable HTTP.
14031408
14041409
Args:
@@ -1418,8 +1423,8 @@ async def handle_mcp(request) -> None:
14181423
>>> asyncio.run(test_handle())
14191424
True
14201425
"""
1421-
# The session manager handles all the protocol details
1422-
await session_manager.handle_request(request.scope, request.receive, request.send)
1426+
# The session manager handles all the protocol details - #TODO: I don't like accessing _send directly
1427+
await session_manager.handle_request(request.scope, request.receive, request._send) # pylint: disable=W0212
14231428

14241429
routes = [
14251430
Route("/mcp", handle_mcp, methods=["GET", "POST"]),
@@ -1430,12 +1435,8 @@ async def handle_mcp(request) -> None:
14301435

14311436
# Add CORS middleware if specified
14321437
if cors:
1433-
# Import here to avoid unnecessary dependency when CORS not used
1434-
# Third-Party
1435-
from starlette.middleware.cors import CORSMiddleware as StarletteCORS # pylint: disable=import-outside-toplevel
1436-
14371438
app.add_middleware(
1438-
cast("type", StarletteCORS),
1439+
CORSMiddleware,
14391440
allow_origins=cors,
14401441
allow_credentials=True,
14411442
allow_methods=["*"],
@@ -1450,7 +1451,7 @@ async def handle_mcp(request) -> None:
14501451
log_level=log_level,
14511452
lifespan="off",
14521453
)
1453-
server = uvicorn.Server(config)
1454+
uvicorn_server = uvicorn.Server(config)
14541455

14551456
shutting_down = asyncio.Event()
14561457

@@ -1466,21 +1467,12 @@ async def _shutdown() -> None:
14661467
await asyncio.wait_for(process.wait(), 5)
14671468
# Graceful shutdown by setting the shutdown event
14681469
# Use getattr to safely access should_exit attribute
1469-
setattr(server, "should_exit", getattr(server, "should_exit", False) or True)
1470+
setattr(uvicorn_server, "should_exit", getattr(uvicorn_server, "should_exit", False) or True)
14701471

14711472
loop = asyncio.get_running_loop()
14721473
for sig in (signal.SIGINT, signal.SIGTERM):
14731474
with suppress(NotImplementedError): # Windows lacks add_signal_handler
1474-
1475-
def shutdown_handler(*args): # pylint: disable=unused-argument
1476-
"""Handle shutdown signal by creating shutdown task.
1477-
1478-
Args:
1479-
*args: Signal handler arguments (unused).
1480-
"""
1481-
asyncio.create_task(_shutdown())
1482-
1483-
loop.add_signal_handler(sig, shutdown_handler)
1475+
loop.add_signal_handler(sig, lambda *_: asyncio.create_task(_shutdown()))
14841476

14851477
# Pump messages between stdio and HTTP
14861478
async def pump_stdio_to_http() -> None:
@@ -1537,7 +1529,7 @@ async def pump_http_to_stdio(data: str) -> None:
15371529

15381530
try:
15391531
LOGGER.info(f"Streamable HTTP bridge ready → http://{host}:{port}/mcp")
1540-
await server.serve()
1532+
await uvicorn_server.serve()
15411533
finally:
15421534
pump_task.cancel()
15431535
await _shutdown()
@@ -1816,7 +1808,7 @@ async def _run_multi_protocol_server( # pylint: disable=too-many-positional-arg
18161808
# Add CORS middleware if specified
18171809
if cors:
18181810
app.add_middleware(
1819-
cast("type", CORSMiddleware),
1811+
CORSMiddleware,
18201812
allow_origins=cors,
18211813
allow_credentials=True,
18221814
allow_methods=["*"],
@@ -2060,7 +2052,7 @@ async def mcp_post(request: Request) -> Response:
20602052
return PlainTextResponse("accepted", status_code=status.HTTP_202_ACCEPTED)
20612053

20622054
# ASGI wrapper to route GET/other /mcp scopes to streamable_manager.handle_request
2063-
async def mcp_asgi_wrapper(scope, receive, send):
2055+
async def mcp_asgi_wrapper(scope: Scope, receive: Receive, send: Send) -> None:
20642056
"""
20652057
ASGI middleware that intercepts HTTP requests to the `/mcp` endpoint.
20662058
@@ -2069,9 +2061,9 @@ async def mcp_asgi_wrapper(scope, receive, send):
20692061
passed to the original FastAPI application.
20702062
20712063
Args:
2072-
scope (dict): The ASGI scope dictionary containing request metadata.
2073-
receive (Callable): An awaitable that yields incoming ASGI events.
2074-
send (Callable): An awaitable used to send ASGI events.
2064+
scope (Scope): The ASGI scope dictionary containing request metadata.
2065+
receive (Receive): An awaitable that yields incoming ASGI events.
2066+
send (Send): An awaitable used to send ASGI events.
20752067
"""
20762068
if scope.get("type") == "http" and scope.get("path") == "/mcp" and streamable_manager:
20772069
# Let StreamableHTTPSessionManager handle session-oriented streaming
@@ -2082,7 +2074,7 @@ async def mcp_asgi_wrapper(scope, receive, send):
20822074
await original_app(scope, receive, send)
20832075

20842076
# Replace the app used by uvicorn with the ASGI wrapper
2085-
app = mcp_asgi_wrapper
2077+
app = mcp_asgi_wrapper # type: ignore[assignment]
20862078

20872079
# ---------------------- Server lifecycle ----------------------
20882080
config = uvicorn.Config(
@@ -2112,16 +2104,7 @@ async def _shutdown() -> None:
21122104
loop = asyncio.get_running_loop()
21132105
for sig in (signal.SIGINT, signal.SIGTERM):
21142106
with suppress(NotImplementedError):
2115-
2116-
def shutdown_handler(*args): # pylint: disable=unused-argument
2117-
"""Handle shutdown signal by creating shutdown task.
2118-
2119-
Args:
2120-
*args: Signal handler arguments (unused).
2121-
"""
2122-
asyncio.create_task(_shutdown())
2123-
2124-
loop.add_signal_handler(sig, shutdown_handler)
2107+
loop.add_signal_handler(sig, lambda *_: asyncio.create_task(_shutdown()))
21252108

21262109
# If we have a streamable manager, start its context so it can accept ASGI /mcp
21272110
if streamable_manager:

tests/unit/mcpgateway/test_translate.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -958,8 +958,12 @@ def _fake_main(argv=None):
958958
assert executed == ["main_called"]
959959

960960

961+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
961962
def test_main_function_stdio(monkeypatch, translate):
962-
"""Test main() function with --stdio argument."""
963+
"""Test main() function with --stdio argument.
964+
965+
Note: This test closes coroutines which may generate RuntimeWarnings during garbage collection.
966+
"""
963967
executed: list[str] = []
964968

965969
async def _fake_stdio_runner(*args):
@@ -982,8 +986,12 @@ def _fake_asyncio_run(coro):
982986
assert "asyncio_run" in executed
983987

984988

989+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
985990
def test_main_function_sse(monkeypatch, translate):
986-
"""Test main() function with --sse argument."""
991+
"""Test main() function with --sse argument.
992+
993+
Note: This test closes coroutines which may generate RuntimeWarnings during garbage collection.
994+
"""
987995
executed: list[str] = []
988996

989997
async def _fake_sse_runner(*args):
@@ -1003,8 +1011,13 @@ def _fake_asyncio_run(coro):
10031011
assert "asyncio_run" in executed
10041012

10051013

1014+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
10061015
def test_main_function_keyboard_interrupt(monkeypatch, translate, capsys):
1007-
"""Test main() function handles KeyboardInterrupt gracefully."""
1016+
"""Test main() function handles KeyboardInterrupt gracefully.
1017+
1018+
Note: This test raises KeyboardInterrupt which prevents the coroutine from being awaited,
1019+
resulting in a RuntimeWarning during garbage collection. This is expected behavior.
1020+
"""
10081021

10091022
def _raise_keyboard_interrupt(*args):
10101023
raise KeyboardInterrupt()
@@ -1019,8 +1032,13 @@ def _raise_keyboard_interrupt(*args):
10191032
assert captured.out == "\n" # Should print newline to restore shell prompt
10201033

10211034

1035+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
10221036
def test_main_function_not_implemented_error(monkeypatch, translate, capsys):
1023-
"""Test main() function handles NotImplementedError."""
1037+
"""Test main() function handles NotImplementedError.
1038+
1039+
Note: This test raises NotImplementedError which prevents the coroutine from being awaited,
1040+
resulting in a RuntimeWarning during garbage collection. This is expected behavior.
1041+
"""
10241042

10251043
# def _raise_not_implemented(coro, *a, **kw):
10261044
# # close the coroutine if the autouse fixture didn't remove it
@@ -1405,29 +1423,9 @@ def __init__(self, routes=None):
14051423
def add_middleware(self, middleware_class, **kwargs):
14061424
calls.append(f"add_middleware_{middleware_class.__name__}")
14071425

1408-
# Mock Starlette CORS middleware import
1409-
class MockCORSMiddleware:
1410-
def __init__(self, **kwargs):
1411-
pass
1412-
1413-
# Mock the import path for CORS middleware
1414-
# Standard
1415-
import types
1416-
1417-
cors_module = types.ModuleType("cors")
1418-
cors_module.CORSMiddleware = MockCORSMiddleware
1419-
middleware_module = types.ModuleType("middleware")
1420-
middleware_module.cors = cors_module
1421-
starlette_module = types.ModuleType("starlette")
1422-
starlette_module.middleware = middleware_module
1423-
14241426
# Standard
14251427
import sys
14261428

1427-
sys.modules["starlette"] = starlette_module
1428-
sys.modules["starlette.middleware"] = middleware_module
1429-
sys.modules["starlette.middleware.cors"] = cors_module
1430-
14311429
class MockTask:
14321430
def cancel(self):
14331431
pass
@@ -1470,7 +1468,7 @@ async def mock_shutdown():
14701468
await translate._run_stdio_to_streamable_http("echo test", 8000, "info", cors=["http://example.com"])
14711469

14721470
# Verify CORS middleware was added (using our Mock class name)
1473-
assert "add_middleware_MockCORSMiddleware" in calls
1471+
assert "add_middleware_CORSMiddleware" in calls
14741472
finally:
14751473
# Clean up sys.modules to avoid affecting other tests
14761474
sys.modules.pop("starlette", None)

0 commit comments

Comments
 (0)