@@ -52,6 +52,10 @@ async def handle_sse(request):
5252from starlette .types import Receive , Scope , Send
5353
5454import mcp .types as types
55+ from mcp .server .transport_security import (
56+ TransportSecurityMiddleware ,
57+ TransportSecuritySettings ,
58+ )
5559from mcp .shared .message import ServerMessageMetadata , SessionMessage
5660
5761logger = logging .getLogger (__name__ )
@@ -71,16 +75,22 @@ class SseServerTransport:
7175
7276 _endpoint : str
7377 _read_stream_writers : dict [UUID , MemoryObjectSendStream [SessionMessage | Exception ]]
78+ _security : TransportSecurityMiddleware
7479
75- def __init__ (self , endpoint : str ) -> None :
80+ def __init__ (self , endpoint : str , security_settings : TransportSecuritySettings | None = None ) -> None :
7681 """
7782 Creates a new SSE server transport, which will direct the client to POST
7883 messages to the relative or absolute URL given.
84+
85+ Args:
86+ endpoint: The relative or absolute URL for POST messages.
87+ security_settings: Optional security settings for DNS rebinding protection.
7988 """
8089
8190 super ().__init__ ()
8291 self ._endpoint = endpoint
8392 self ._read_stream_writers = {}
93+ self ._security = TransportSecurityMiddleware (security_settings )
8494 logger .debug (f"SseServerTransport initialized with endpoint: { endpoint } " )
8595
8696 @asynccontextmanager
@@ -89,6 +99,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8999 logger .error ("connect_sse received non-HTTP request" )
90100 raise ValueError ("connect_sse can only handle HTTP requests" )
91101
102+ # Validate request headers for DNS rebinding protection
103+ request = Request (scope , receive )
104+ error_response = await self ._security .validate_request (request , is_post = False )
105+ if error_response :
106+ await error_response (scope , receive , send )
107+ raise ValueError ("Request validation failed" )
108+
92109 logger .debug ("Setting up SSE connection" )
93110 read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ]
94111 read_stream_writer : MemoryObjectSendStream [SessionMessage | Exception ]
@@ -160,6 +177,11 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
160177 logger .debug ("Handling POST message" )
161178 request = Request (scope , receive )
162179
180+ # Validate request headers for DNS rebinding protection
181+ error_response = await self ._security .validate_request (request , is_post = True )
182+ if error_response :
183+ return await error_response (scope , receive , send )
184+
163185 session_id_param = request .query_params .get ("session_id" )
164186 if session_id_param is None :
165187 logger .warning ("Received request without session_id" )
0 commit comments