diff --git a/python/instrumentation/openinference-instrumentation-mcp/pyproject.toml b/python/instrumentation/openinference-instrumentation-mcp/pyproject.toml index f71a24aefe..07fe9f6096 100644 --- a/python/instrumentation/openinference-instrumentation-mcp/pyproject.toml +++ b/python/instrumentation/openinference-instrumentation-mcp/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ [project.optional-dependencies] instruments = [ - "mcp >= 1.3.0", + "mcp >= 1.6.0", ] [project.entry-points.opentelemetry_instrumentor] diff --git a/python/instrumentation/openinference-instrumentation-mcp/src/openinference/instrumentation/mcp/__init__.py b/python/instrumentation/openinference-instrumentation-mcp/src/openinference/instrumentation/mcp/__init__.py index d103e2039a..ec859c48d1 100644 --- a/python/instrumentation/openinference-instrumentation-mcp/src/openinference/instrumentation/mcp/__init__.py +++ b/python/instrumentation/openinference-instrumentation-mcp/src/openinference/instrumentation/mcp/__init__.py @@ -1,14 +1,14 @@ -from typing import Any, Awaitable, Callable, Collection, TypeVar, cast +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Callable, Collection, Tuple, cast from opentelemetry import context, propagate from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore from opentelemetry.instrumentation.utils import unwrap -from wrapt import register_post_import_hook, wrap_function_wrapper +from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper from openinference.instrumentation.mcp.package import _instruments -T = TypeVar("T") - class MCPInstrumentor(BaseInstrumentor): # type: ignore """ @@ -19,51 +19,154 @@ def instrumentation_dependencies(self) -> Collection[str]: return _instruments def _instrument(self, **kwargs: Any) -> None: - register_post_import_hook(self._patch, "mcp") - - def _patch(self, module: Any) -> None: - wrap_function_wrapper( - "mcp.client.session", - "ClientSession.send_request", - self._client_request_wrapper, + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.client.sse", "sse_client", self._transport_wrapper + ), + "mcp.client.sse", + ) + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.sse", "SseServerTransport.connect_sse", self._transport_wrapper + ), + "mcp.server.sse", + ) + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.client.stdio", "stdio_client", self._transport_wrapper + ), + "mcp.client.stdio", ) - wrap_function_wrapper( - "mcp.server.lowlevel.server", - "Server._handle_request", - self._server_request_wrapper, + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.stdio", "stdio_server", self._transport_wrapper + ), + "mcp.server.stdio", + ) + + # While we prefer to instrument the lowest level primitive, the transports above, it doesn't + # mean context will be propagated to handlers automatically. Notably, the MCP SDK passes + # server messages to a handler with a separate stream in between, losing context. We go + # ahead and instrument this second stream just to propagate context so transports can still + # be used independently while also supporting the major usage of the MCP SDK. Notably, this + # may be a reasonable generic instrumentation for anyio itself to allow its streams to + # propagate context broadly. + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.session", "ServerSession.__init__", self._base_session_init_wrapper + ), + "mcp.server.session", ) def _uninstrument(self, **kwargs: Any) -> None: - unwrap("mcp.client.session.ClientSession", "send_request") - unwrap("mcp.server.lowlevel.server", "_handle_request") + unwrap("mcp.client.stdio", "stdio_client") + unwrap("mcp.server.stdio", "stdio_server") + + @asynccontextmanager + async def _transport_wrapper( + self, wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any + ) -> AsyncGenerator[Tuple["InstrumentedStreamReader", "InstrumentedStreamWriter"], None]: + async with wrapped(*args, **kwargs) as (read_stream, write_stream): + yield InstrumentedStreamReader(read_stream), InstrumentedStreamWriter(write_stream) + + def _base_session_init_wrapper( + self, wrapped: Callable[..., None], instance: Any, args: Any, kwargs: Any + ) -> None: + wrapped(*args, **kwargs) + reader = getattr(instance, "_incoming_message_stream_reader", None) + writer = getattr(instance, "_incoming_message_stream_writer", None) + if reader and writer: + setattr( + instance, "_incoming_message_stream_reader", ContextAttachingStreamReader(reader) + ) + setattr(instance, "_incoming_message_stream_writer", ContextSavingStreamWriter(writer)) + + +class InstrumentedStreamReader(ObjectProxy): # type: ignore + # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73 + async def __aenter__(self) -> Any: + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def __aiter__(self) -> AsyncGenerator[Any, None]: + from mcp.types import JSONRPCMessage, JSONRPCRequest + + async for item in self.__wrapped__: + request = cast(JSONRPCMessage, item).root + + if not isinstance(request, JSONRPCRequest): + yield item + continue - def _client_request_wrapper( - self, wrapped: Callable[..., T], instance: Any, args: Any, kwargs: Any - ) -> T: - from mcp.types import JSONRPCMessage, Request, RequestParams + if request.params: + meta = request.params.get("_meta") + if meta: + ctx = propagate.extract(meta) + restore = context.attach(ctx) + try: + yield item + continue + finally: + context.detach(restore) + yield item - message = cast(JSONRPCMessage, args[0]) - request = cast(Request[RequestParams, Any], message.root) + +class InstrumentedStreamWriter(ObjectProxy): # type: ignore + # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73 + async def __aenter__(self) -> Any: + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def send(self, item: Any) -> Any: + from mcp.types import JSONRPCMessage, JSONRPCRequest + + request = cast(JSONRPCMessage, item).root + if not isinstance(request, JSONRPCRequest): + return await self.__wrapped__.send(item) + meta = None if not request.params: - request.params = RequestParams() - if not request.params.meta: - request.params.meta = RequestParams.Meta() - propagate.get_global_textmap().inject(request.params.meta.__pydantic_extra__) - return wrapped(*args, **kwargs) - - async def _server_request_wrapper( - self, wrapped: Callable[..., Awaitable[T]], instance: Any, args: Any, kwargs: Any - ) -> T: - from mcp.types import Request, RequestParams - - request = cast(Request[RequestParams, Any], args[1]) - if hasattr(request, "params") and hasattr(request.params, "meta"): - meta = request.params.meta - if meta and hasattr(meta, "__pydantic_extra__"): - ctx = propagate.extract(meta.__pydantic_extra__) - restore = context.attach(ctx) - try: - return await wrapped(*args, **kwargs) - finally: - context.detach(restore) - return await wrapped(*args, **kwargs) + request.params = {} + meta = request.params.setdefault("_meta", {}) + propagate.get_global_textmap().inject(meta) + return await self.__wrapped__.send(item) + + +@dataclass(slots=True, frozen=True) +class ItemWithContext: + item: Any + ctx: context.Context + + +class ContextSavingStreamWriter(ObjectProxy): # type: ignore + # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73 + async def __aenter__(self) -> Any: + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def send(self, item: Any) -> Any: + ctx = context.get_current() + return await self.__wrapped__.send(ItemWithContext(item, ctx)) + + +class ContextAttachingStreamReader(ObjectProxy): # type: ignore + # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73 + async def __aenter__(self) -> Any: + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def __aiter__(self) -> AsyncGenerator[Any, None]: + async for item in self.__wrapped__: + item_with_context = cast(ItemWithContext, item) + restore = context.attach(item_with_context.ctx) + try: + yield item_with_context.item + finally: + context.detach(restore) diff --git a/python/instrumentation/openinference-instrumentation-mcp/src/openinference/instrumentation/mcp/package.py b/python/instrumentation/openinference-instrumentation-mcp/src/openinference/instrumentation/mcp/package.py index 0b3ae1af0b..e43aa85663 100644 --- a/python/instrumentation/openinference-instrumentation-mcp/src/openinference/instrumentation/mcp/package.py +++ b/python/instrumentation/openinference-instrumentation-mcp/src/openinference/instrumentation/mcp/package.py @@ -1 +1 @@ -_instruments = ("mcp >= 1.3.0",) +_instruments = ("mcp >= 1.6.0",) diff --git a/python/instrumentation/openinference-instrumentation-mcp/test-requirements.txt b/python/instrumentation/openinference-instrumentation-mcp/test-requirements.txt index 890e64accd..6dd1ad33dc 100644 --- a/python/instrumentation/openinference-instrumentation-mcp/test-requirements.txt +++ b/python/instrumentation/openinference-instrumentation-mcp/test-requirements.txt @@ -1,4 +1,4 @@ -mcp==1.3.0 +mcp==1.6.0 httpx opentelemetry-exporter-otlp-proto-http diff --git a/python/instrumentation/openinference-instrumentation-mcp/tests/mcpserver.py b/python/instrumentation/openinference-instrumentation-mcp/tests/mcpserver.py index e0c6ca2b20..eabf3442b7 100644 --- a/python/instrumentation/openinference-instrumentation-mcp/tests/mcpserver.py +++ b/python/instrumentation/openinference-instrumentation-mcp/tests/mcpserver.py @@ -1,7 +1,6 @@ import os from typing import Literal, cast -from mcp.server.fastmcp import FastMCP from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk import trace as trace_sdk from opentelemetry.sdk.trace.export import SimpleSpanProcessor @@ -19,13 +18,20 @@ MCPInstrumentor().instrument(tracer_provider=tracer_provider) +# Make sure instrumentation is loaded before MCP. +from mcp.server.fastmcp import Context, FastMCP # noqa: E402 + +from tests.whoami import TestClientResult, WhoamiRequest # noqa: E402 + server = FastMCP(port=0) @server.tool() -def hello() -> str: +async def hello(ctx: Context) -> str: # type: ignore with tracer.start_as_current_span("hello"): - return "World!" + response = await ctx.session.send_request(WhoamiRequest(method="whoami"), TestClientResult) + name = response.root.name + return f"Hello {name}!" try: diff --git a/python/instrumentation/openinference-instrumentation-mcp/tests/test_instrumenter.py b/python/instrumentation/openinference-instrumentation-mcp/tests/test_instrumenter.py index 126c305874..174872ed1e 100644 --- a/python/instrumentation/openinference-instrumentation-mcp/tests/test_instrumenter.py +++ b/python/instrumentation/openinference-instrumentation-mcp/tests/test_instrumenter.py @@ -7,35 +7,63 @@ import pytest from mcp import ClientSession -from mcp.client.sse import sse_client -from mcp.client.stdio import StdioServerParameters, stdio_client -from mcp.types import TextContent +from mcp.shared.session import RequestResponder +from mcp.types import ClientResult, ServerNotification, ServerRequest, TextContent from opentelemetry.trace import Tracer from tests.collector import OTLPServer, Telemetry +from tests.whoami import TestClientResult, TestServerRequest, WhoamiResult # The way MCP SDK creates async tasks means we need this to be called inline with the test, # not as a fixture. @asynccontextmanager -async def mcp_client(transport: str, otlp_endpoint: str) -> AsyncGenerator[ClientSession, None]: +async def mcp_client( + transport: str, tracer: Tracer, otlp_endpoint: str +) -> AsyncGenerator[ClientSession, None]: + # Lazy import to get instrumented versions. Users will use opentelemetry-instrument or otherwise + # initialize instrumentation as early as possible and should not run into issues, but we control + # instrumentation through fixtures instead. + from mcp.client.sse import sse_client + from mcp.client.stdio import StdioServerParameters, stdio_client + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if not isinstance(message, RequestResponder) or message.request.root.method != "whoami": + return + with message as responder, tracer.start_as_current_span("whoami"): + await responder.respond(TestClientResult(WhoamiResult(name="OpenInference"))) # type: ignore + server_script = str(Path(__file__).parent / "mcpserver.py") + pythonpath = str(Path(__file__).parent.parent) match transport: case "stdio": async with stdio_client( StdioServerParameters( command=sys.executable, args=[server_script], - env={"MCP_TRANSPORT": "stdio", "OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint}, + env={ + "MCP_TRANSPORT": "stdio", + "OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint, + "PYTHONPATH": pythonpath, + }, ) - ) as (reader, writer), ClientSession(reader, writer) as client: + ) as (reader, writer), ClientSession( + reader, writer, message_handler=message_handler + ) as client: + client._receive_request_type = TestServerRequest await client.initialize() yield client case "sse": proc = await asyncio.create_subprocess_exec( sys.executable, server_script, - env={"MCP_TRANSPORT": "sse", "OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint}, + env={ + "MCP_TRANSPORT": "sse", + "OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint, + "PYTHONPATH": pythonpath, + }, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) @@ -50,7 +78,8 @@ async def mcp_client(transport: str, otlp_endpoint: str) -> AsyncGenerator[Clien async with sse_client(f"http://localhost:{port}/sse") as ( reader, writer, - ), ClientSession(reader, writer) as client: + ), ClientSession(reader, writer, message_handler=message_handler) as client: + client._receive_request_type = TestServerRequest await client.initialize() yield client break @@ -63,7 +92,9 @@ async def mcp_client(transport: str, otlp_endpoint: str) -> AsyncGenerator[Clien async def test_hello( transport: str, tracer: Tracer, telemetry: Telemetry, otlp_collector: OTLPServer ) -> None: - async with mcp_client(transport, f"http://localhost:{otlp_collector.server_port}/") as client: + async with mcp_client( + transport, tracer, f"http://localhost:{otlp_collector.server_port}/" + ) as client: with tracer.start_as_current_span("root"): tools_res = await client.list_tools() assert len(tools_res.tools) == 1 @@ -71,19 +102,24 @@ async def test_hello( tool_res = await client.call_tool("hello") content = tool_res.content[0] assert isinstance(content, TextContent) - assert content.text == "World!" + assert content.text == "Hello OpenInference!" - assert len(telemetry.traces) == 2 for resource_spans in telemetry.traces: - assert len(resource_spans.scope_spans) == 1 for scope_spans in resource_spans.scope_spans: - assert len(scope_spans.spans) == 1 match scope_spans.scope.name: case "mcp-test-client": - client_span = scope_spans.spans[0] + for span in scope_spans.spans: + match span.name: + case "root": + root_span = span + case "whoami": + whoami_span = span case "mcp-test-server": server_span = scope_spans.spans[0] - assert client_span.name == "root" + assert root_span.name == "root" assert server_span.name == "hello" - assert server_span.trace_id == client_span.trace_id - assert server_span.parent_span_id == client_span.span_id + assert whoami_span.name == "whoami" + assert server_span.trace_id == root_span.trace_id + assert server_span.parent_span_id == root_span.span_id + assert whoami_span.trace_id == root_span.trace_id + assert whoami_span.parent_span_id == server_span.span_id diff --git a/python/instrumentation/openinference-instrumentation-mcp/tests/whoami.py b/python/instrumentation/openinference-instrumentation-mcp/tests/whoami.py new file mode 100644 index 0000000000..69b30df5fa --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-mcp/tests/whoami.py @@ -0,0 +1,21 @@ +from typing import Literal + +from mcp.types import Request, RequestParams, Result +from pydantic import RootModel + + +class WhoamiRequest(Request[RequestParams | None, Literal["whoami"]]): + method: Literal["whoami"] + params: RequestParams | None = None + + +class WhoamiResult(Result): + name: str + + +class TestServerRequest(RootModel[WhoamiRequest]): + pass + + +class TestClientResult(RootModel[WhoamiResult]): + pass