Skip to content

Commit 208f13b

Browse files
authored
Merge branch 'main' into fix/session-manager-resilience
2 parents 61f3dc2 + d0443a1 commit 208f13b

File tree

18 files changed

+865
-35
lines changed

18 files changed

+865
-35
lines changed

examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def main(
4141
app = Server("mcp-streamable-http-stateless-demo")
4242

4343
@app.call_tool()
44-
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
44+
async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]:
4545
ctx = app.request_context
4646
interval = arguments.get("interval", 1.0)
4747
count = arguments.get("count", 5)

examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def main(
4545
app = Server("mcp-streamable-http-demo")
4646

4747
@app.call_tool()
48-
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
48+
async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]:
4949
ctx = app.request_context
5050
interval = arguments.get("interval", 1.0)
5151
count = arguments.get("count", 5)

examples/servers/simple-tool/mcp_simple_tool/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
async def fetch_website(
99
url: str,
10-
) -> list[types.Content]:
10+
) -> list[types.ContentBlock]:
1111
headers = {
1212
"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
1313
}
@@ -29,7 +29,7 @@ def main(port: int, transport: str) -> int:
2929
app = Server("mcp-website-fetcher")
3030

3131
@app.call_tool()
32-
async def fetch_tool(name: str, arguments: dict) -> list[types.Content]:
32+
async def fetch_tool(name: str, arguments: dict) -> list[types.ContentBlock]:
3333
if name != "fetch":
3434
raise ValueError(f"Unknown tool: {name}")
3535
if "url" not in arguments:

src/mcp/server/fastmcp/prompts/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
import pydantic_core
88
from pydantic import BaseModel, Field, TypeAdapter, validate_call
99

10-
from mcp.types import Content, TextContent
10+
from mcp.types import ContentBlock, TextContent
1111

1212

1313
class Message(BaseModel):
1414
"""Base class for all prompt messages."""
1515

1616
role: Literal["user", "assistant"]
17-
content: Content
17+
content: ContentBlock
1818

19-
def __init__(self, content: str | Content, **kwargs: Any):
19+
def __init__(self, content: str | ContentBlock, **kwargs: Any):
2020
if isinstance(content, str):
2121
content = TextContent(type="text", text=content)
2222
super().__init__(content=content, **kwargs)
@@ -27,7 +27,7 @@ class UserMessage(Message):
2727

2828
role: Literal["user", "assistant"] = "user"
2929

30-
def __init__(self, content: str | Content, **kwargs: Any):
30+
def __init__(self, content: str | ContentBlock, **kwargs: Any):
3131
super().__init__(content=content, **kwargs)
3232

3333

@@ -36,7 +36,7 @@ class AssistantMessage(Message):
3636

3737
role: Literal["user", "assistant"] = "assistant"
3838

39-
def __init__(self, content: str | Content, **kwargs: Any):
39+
def __init__(self, content: str | ContentBlock, **kwargs: Any):
4040
super().__init__(content=content, **kwargs)
4141

4242

src/mcp/server/fastmcp/server.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,11 @@
5050
from mcp.server.stdio import stdio_server
5151
from mcp.server.streamable_http import EventStore
5252
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
53+
from mcp.server.transport_security import TransportSecuritySettings
5354
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
5455
from mcp.types import (
5556
AnyFunction,
56-
Content,
57+
ContentBlock,
5758
GetPromptResult,
5859
TextContent,
5960
ToolAnnotations,
@@ -118,6 +119,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
118119

119120
auth: AuthSettings | None = None
120121

122+
# Transport security settings (DNS rebinding protection)
123+
transport_security: TransportSecuritySettings | None = None
124+
121125

122126
def lifespan_wrapper(
123127
app: FastMCP,
@@ -256,7 +260,7 @@ def get_context(self) -> Context[ServerSession, object, Request]:
256260
request_context = None
257261
return Context(request_context=request_context, fastmcp=self)
258262

259-
async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Content]:
263+
async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock]:
260264
"""Call a tool by name with arguments."""
261265
context = self.get_context()
262266
result = await self._tool_manager.call_tool(name, arguments, context=context)
@@ -674,6 +678,7 @@ def sse_app(self, mount_path: str | None = None) -> Starlette:
674678

675679
sse = SseServerTransport(
676680
normalized_message_endpoint,
681+
security_settings=self.settings.transport_security,
677682
)
678683

679684
async def handle_sse(scope: Scope, receive: Receive, send: Send):
@@ -779,6 +784,7 @@ def streamable_http_app(self) -> Starlette:
779784
event_store=self._event_store,
780785
json_response=self.settings.json_response,
781786
stateless=self.settings.stateless_http, # Use the stateless setting
787+
security_settings=self.settings.transport_security,
782788
)
783789

784790
# Create the ASGI handler
@@ -872,12 +878,12 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -
872878

873879
def _convert_to_content(
874880
result: Any,
875-
) -> Sequence[Content]:
881+
) -> Sequence[ContentBlock]:
876882
"""Convert a result to a sequence of content objects."""
877883
if result is None:
878884
return []
879885

880-
if isinstance(result, Content):
886+
if isinstance(result, ContentBlock):
881887
return [result]
882888

883889
if isinstance(result, Image):

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def call_tool(self):
384384
def decorator(
385385
func: Callable[
386386
...,
387-
Awaitable[Iterable[types.Content]],
387+
Awaitable[Iterable[types.ContentBlock]],
388388
],
389389
):
390390
logger.debug("Registering handler for CallToolRequest")

src/mcp/server/sse.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ async def handle_sse(request):
5252
from starlette.types import Receive, Scope, Send
5353

5454
import mcp.types as types
55+
from mcp.server.transport_security import (
56+
TransportSecurityMiddleware,
57+
TransportSecuritySettings,
58+
)
5559
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5660

5761
logger = logging.getLogger(__name__)
@@ -71,16 +75,22 @@ class SseServerTransport:
7175

7276
_endpoint: str
7377
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
78+
_security: TransportSecurityMiddleware
7479

75-
def __init__(self, endpoint: str) -> None:
80+
def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
7681
"""
7782
Creates a new SSE server transport, which will direct the client to POST
7883
messages to the relative or absolute URL given.
84+
85+
Args:
86+
endpoint: The relative or absolute URL for POST messages.
87+
security_settings: Optional security settings for DNS rebinding protection.
7988
"""
8089

8190
super().__init__()
8291
self._endpoint = endpoint
8392
self._read_stream_writers = {}
93+
self._security = TransportSecurityMiddleware(security_settings)
8494
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
8595

8696
@asynccontextmanager
@@ -89,6 +99,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8999
logger.error("connect_sse received non-HTTP request")
90100
raise ValueError("connect_sse can only handle HTTP requests")
91101

102+
# Validate request headers for DNS rebinding protection
103+
request = Request(scope, receive)
104+
error_response = await self._security.validate_request(request, is_post=False)
105+
if error_response:
106+
await error_response(scope, receive, send)
107+
raise ValueError("Request validation failed")
108+
92109
logger.debug("Setting up SSE connection")
93110
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
94111
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
@@ -160,6 +177,11 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
160177
logger.debug("Handling POST message")
161178
request = Request(scope, receive)
162179

180+
# Validate request headers for DNS rebinding protection
181+
error_response = await self._security.validate_request(request, is_post=True)
182+
if error_response:
183+
return await error_response(scope, receive, send)
184+
163185
session_id_param = request.query_params.get("session_id")
164186
if session_id_param is None:
165187
logger.warning("Received request without session_id")

src/mcp/server/streamable_http.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from starlette.responses import Response
2525
from starlette.types import Receive, Scope, Send
2626

27+
from mcp.server.transport_security import (
28+
TransportSecurityMiddleware,
29+
TransportSecuritySettings,
30+
)
2731
from mcp.shared.message import ServerMessageMetadata, SessionMessage
2832
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
2933
from mcp.types import (
@@ -130,12 +134,14 @@ class StreamableHTTPServerTransport:
130134
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
131135
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
132136
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
137+
_security: TransportSecurityMiddleware
133138

134139
def __init__(
135140
self,
136141
mcp_session_id: str | None,
137142
is_json_response_enabled: bool = False,
138143
event_store: EventStore | None = None,
144+
security_settings: TransportSecuritySettings | None = None,
139145
) -> None:
140146
"""
141147
Initialize a new StreamableHTTP server transport.
@@ -148,6 +154,7 @@ def __init__(
148154
event_store: Event store for resumability support. If provided,
149155
resumability will be enabled, allowing clients to
150156
reconnect and resume messages.
157+
security_settings: Optional security settings for DNS rebinding protection.
151158
152159
Raises:
153160
ValueError: If the session ID contains invalid characters.
@@ -158,6 +165,7 @@ def __init__(
158165
self.mcp_session_id = mcp_session_id
159166
self.is_json_response_enabled = is_json_response_enabled
160167
self._event_store = event_store
168+
self._security = TransportSecurityMiddleware(security_settings)
161169
self._request_streams: dict[
162170
RequestId,
163171
tuple[
@@ -251,6 +259,14 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
251259
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
252260
"""Application entry point that handles all HTTP requests"""
253261
request = Request(scope, receive)
262+
263+
# Validate request headers for DNS rebinding protection
264+
is_post = request.method == "POST"
265+
error_response = await self._security.validate_request(request, is_post=is_post)
266+
if error_response:
267+
await error_response(scope, receive, send)
268+
return
269+
254270
if self._terminated:
255271
# If the session has been terminated, return 404 Not Found
256272
response = self._create_error_response(

src/mcp/server/streamable_http_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
EventStore,
2323
StreamableHTTPServerTransport,
2424
)
25+
from mcp.server.transport_security import TransportSecuritySettings
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -59,11 +60,13 @@ def __init__(
5960
event_store: EventStore | None = None,
6061
json_response: bool = False,
6162
stateless: bool = False,
63+
security_settings: TransportSecuritySettings | None = None,
6264
):
6365
self.app = app
6466
self.event_store = event_store
6567
self.json_response = json_response
6668
self.stateless = stateless
69+
self.security_settings = security_settings
6770

6871
# Session tracking (only used if not stateless)
6972
self._session_creation_lock = anyio.Lock()
@@ -161,6 +164,7 @@ async def _handle_stateless_request(
161164
mcp_session_id=None, # No session tracking in stateless mode
162165
is_json_response_enabled=self.json_response,
163166
event_store=None, # No event store in stateless mode
167+
security_settings=self.security_settings,
164168
)
165169

166170
# Start server in a new task
@@ -219,6 +223,7 @@ async def _handle_stateful_request(
219223
mcp_session_id=new_session_id,
220224
is_json_response_enabled=self.json_response,
221225
event_store=self.event_store, # May be None (no resumability)
226+
security_settings=self.security_settings,
222227
)
223228

224229
assert http_transport.mcp_session_id is not None

0 commit comments

Comments
 (0)