@@ -32,7 +32,7 @@ dependencies = [

[project.optional-dependencies]
instruments = [
-    "mcp >= 1.3.0",
"mcp >= 1.6.0",
]

[project.entry-points.opentelemetry_instrumentor]
@@ -1,14 +1,14 @@
-from typing import Any, Awaitable, Callable, Collection, TypeVar, cast
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Callable, Collection, Tuple, cast

from opentelemetry import context, propagate
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore
from opentelemetry.instrumentation.utils import unwrap
-from wrapt import register_post_import_hook, wrap_function_wrapper
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper

from openinference.instrumentation.mcp.package import _instruments

-T = TypeVar("T")


class MCPInstrumentor(BaseInstrumentor): # type: ignore
"""
@@ -19,51 +19,154 @@ def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs: Any) -> None:
-        register_post_import_hook(self._patch, "mcp")
-
-    def _patch(self, module: Any) -> None:
-        wrap_function_wrapper(
-            "mcp.client.session",
-            "ClientSession.send_request",
-            self._client_request_wrapper,
register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.client.sse", "sse_client", self._transport_wrapper
),
"mcp.client.sse",
)
register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.server.sse", "SseServerTransport.connect_sse", self._transport_wrapper
),
"mcp.server.sse",
)
register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.client.stdio", "stdio_client", self._transport_wrapper
),
"mcp.client.stdio",
)
-        wrap_function_wrapper(
-            "mcp.server.lowlevel.server",
-            "Server._handle_request",
-            self._server_request_wrapper,
register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.server.stdio", "stdio_server", self._transport_wrapper
),
"mcp.server.stdio",
)

# While we prefer to instrument the lowest level primitive, the transports above, it doesn't
# mean context will be propagated to handlers automatically. Notably, the MCP SDK passes
# server messages to a handler with a separate stream in between, losing context. We go
# ahead and instrument this second stream just to propagate context so transports can still
# be used independently while also supporting the major usage of the MCP SDK. Notably, this
# may be a reasonable generic instrumentation for anyio itself to allow its streams to
# propagate context broadly.
Comment on lines +47 to +53 (Collaborator): thanks for the explanation here.
register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.server.session", "ServerSession.__init__", self._base_session_init_wrapper
),
"mcp.server.session",
)
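# ---------------------------------------------------------------------------
# Standalone sketch, not part of this module or of this diff: how wrapt's
# register_post_import_hook, used throughout _instrument above, behaves. The
# callback receives the imported module and runs once the named module has been
# imported (immediately if it already was), which is why instrumentation can be
# installed before any mcp import happens. The "json" target below is purely
# illustrative.
from wrapt import register_post_import_hook

register_post_import_hook(lambda module: print("loaded:", module.__name__), "json")

import json  # noqa: E402,F401  # triggers (or has already triggered) the hook above
# ---------------------------------------------------------------------------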

def _uninstrument(self, **kwargs: Any) -> None:
-        unwrap("mcp.client.session.ClientSession", "send_request")
-        unwrap("mcp.server.lowlevel.server", "_handle_request")
unwrap("mcp.client.stdio", "stdio_client")
unwrap("mcp.server.stdio", "stdio_server")

@asynccontextmanager
async def _transport_wrapper(
self, wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
) -> AsyncGenerator[Tuple["InstrumentedStreamReader", "InstrumentedStreamWriter"], None]:
async with wrapped(*args, **kwargs) as (read_stream, write_stream):
yield InstrumentedStreamReader(read_stream), InstrumentedStreamWriter(write_stream)

def _base_session_init_wrapper(
self, wrapped: Callable[..., None], instance: Any, args: Any, kwargs: Any
) -> None:
wrapped(*args, **kwargs)
reader = getattr(instance, "_incoming_message_stream_reader", None)
writer = getattr(instance, "_incoming_message_stream_writer", None)
if reader and writer:
setattr(
instance, "_incoming_message_stream_reader", ContextAttachingStreamReader(reader)
)
setattr(instance, "_incoming_message_stream_writer", ContextSavingStreamWriter(writer))


class InstrumentedStreamReader(ObjectProxy): # type: ignore
# ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
async def __aenter__(self) -> Any:
return await self.__wrapped__.__aenter__()

async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

async def __aiter__(self) -> AsyncGenerator[Any, None]:
from mcp.types import JSONRPCMessage, JSONRPCRequest

async for item in self.__wrapped__:
request = cast(JSONRPCMessage, item).root

if not isinstance(request, JSONRPCRequest):
yield item
continue

-    def _client_request_wrapper(
-        self, wrapped: Callable[..., T], instance: Any, args: Any, kwargs: Any
-    ) -> T:
-        from mcp.types import JSONRPCMessage, Request, RequestParams
if request.params:
meta = request.params.get("_meta")
if meta:
ctx = propagate.extract(meta)
restore = context.attach(ctx)
try:
yield item
continue
finally:
context.detach(restore)
yield item

-        message = cast(JSONRPCMessage, args[0])
-        request = cast(Request[RequestParams, Any], message.root)

class InstrumentedStreamWriter(ObjectProxy): # type: ignore
# ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
async def __aenter__(self) -> Any:
return await self.__wrapped__.__aenter__()

async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

async def send(self, item: Any) -> Any:
from mcp.types import JSONRPCMessage, JSONRPCRequest

request = cast(JSONRPCMessage, item).root
if not isinstance(request, JSONRPCRequest):
return await self.__wrapped__.send(item)
meta = None
if not request.params:
-            request.params = RequestParams()
-        if not request.params.meta:
-            request.params.meta = RequestParams.Meta()
-        propagate.get_global_textmap().inject(request.params.meta.__pydantic_extra__)
-        return wrapped(*args, **kwargs)

-    async def _server_request_wrapper(
-        self, wrapped: Callable[..., Awaitable[T]], instance: Any, args: Any, kwargs: Any
-    ) -> T:
-        from mcp.types import Request, RequestParams
-
-        request = cast(Request[RequestParams, Any], args[1])
-        if hasattr(request, "params") and hasattr(request.params, "meta"):
-            meta = request.params.meta
-            if meta and hasattr(meta, "__pydantic_extra__"):
-                ctx = propagate.extract(meta.__pydantic_extra__)
-                restore = context.attach(ctx)
-                try:
-                    return await wrapped(*args, **kwargs)
-                finally:
-                    context.detach(restore)
-        return await wrapped(*args, **kwargs)
request.params = {}
meta = request.params.setdefault("_meta", {})
propagate.get_global_textmap().inject(meta)
return await self.__wrapped__.send(item)
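# ---------------------------------------------------------------------------
# Standalone sketch, not part of this module or of this diff: the propagation
# round trip the two stream proxies rely on. InstrumentedStreamWriter injects
# the current trace context into the request's "_meta" mapping;
# InstrumentedStreamReader extracts it and attaches it so spans started while
# handling the request join the caller's trace. Assumes the default W3C
# TraceContext propagator; the tracer name and span names are illustrative.
from opentelemetry import context, propagate, trace
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("example")

# "Client" side, analogous to InstrumentedStreamWriter.send: inject into a dict.
with tracer.start_as_current_span("client-request"):
    meta: dict = {}
    propagate.get_global_textmap().inject(meta)
    print(meta)  # e.g. {'traceparent': '00-<trace_id>-<span_id>-01'}

# "Server" side, analogous to InstrumentedStreamReader.__aiter__: extract and attach.
token = context.attach(propagate.extract(meta))
try:
    with tracer.start_as_current_span("server-handler") as span:
        # The handler span lands in the same trace as the client span above.
        assert format(span.get_span_context().trace_id, "032x") in meta["traceparent"]
finally:
    context.detach(token)
# ---------------------------------------------------------------------------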


@dataclass(slots=True, frozen=True)
class ItemWithContext:
item: Any
ctx: context.Context


class ContextSavingStreamWriter(ObjectProxy): # type: ignore
# ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
async def __aenter__(self) -> Any:
return await self.__wrapped__.__aenter__()

async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

async def send(self, item: Any) -> Any:
ctx = context.get_current()
return await self.__wrapped__.send(ItemWithContext(item, ctx))


class ContextAttachingStreamReader(ObjectProxy): # type: ignore
# ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
async def __aenter__(self) -> Any:
return await self.__wrapped__.__aenter__()

async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

async def __aiter__(self) -> AsyncGenerator[Any, None]:
async for item in self.__wrapped__:
item_with_context = cast(ItemWithContext, item)
restore = context.attach(item_with_context.ctx)
try:
yield item_with_context.item
finally:
context.detach(restore)
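# ---------------------------------------------------------------------------
# Standalone sketch, not part of this module or of this diff: the problem the
# ContextSavingStreamWriter / ContextAttachingStreamReader pair above solves.
# OpenTelemetry context attached in the task that sends an item through an
# anyio memory stream is not visible in the task that receives it, so the
# sender's context is shipped alongside the item and re-attached on the
# receiving side. The key "example-key" and the producer/consumer tasks are
# illustrative only.
import anyio

from opentelemetry import context


async def _demo() -> None:
    send, receive = anyio.create_memory_object_stream()

    async def producer() -> None:
        token = context.attach(context.set_value("example-key", "example-value"))
        try:
            # Plain send: the consumer task will not see "example-key".
            await send.send("plain")
            # The instrumented approach: ship the current context with the item.
            await send.send(("wrapped", context.get_current()))
        finally:
            context.detach(token)
        await send.aclose()

    async def consumer() -> None:
        async for item in receive:
            if isinstance(item, tuple):
                payload, ctx = item
                restore = context.attach(ctx)
                try:
                    print(payload, context.get_value("example-key"))  # example-value
                finally:
                    context.detach(restore)
            else:
                print(item, context.get_value("example-key"))  # None

    async with anyio.create_task_group() as tg:
        tg.start_soon(consumer)
        tg.start_soon(producer)


anyio.run(_demo)
# ---------------------------------------------------------------------------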
@@ -1 +1 @@
-_instruments = ("mcp >= 1.3.0",)
_instruments = ("mcp >= 1.6.0",)
@@ -1,4 +1,4 @@
-mcp==1.3.0
mcp==1.6.0

httpx
opentelemetry-exporter-otlp-proto-http
@@ -1,7 +1,6 @@
import os
from typing import Literal, cast

-from mcp.server.fastmcp import FastMCP
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
@@ -19,13 +18,20 @@

MCPInstrumentor().instrument(tracer_provider=tracer_provider)

# Make sure instrumentation is loaded before MCP.
from mcp.server.fastmcp import Context, FastMCP # noqa: E402

from tests.whoami import TestClientResult, WhoamiRequest # noqa: E402

server = FastMCP(port=0)


@server.tool()
-def hello() -> str:
async def hello(ctx: Context) -> str: # type: ignore
with tracer.start_as_current_span("hello"):
-        return "World!"
response = await ctx.session.send_request(WhoamiRequest(method="whoami"), TestClientResult)
name = response.root.name
return f"Hello {name}!"


try:
@@ -7,35 +7,63 @@

import pytest
from mcp import ClientSession
-from mcp.client.sse import sse_client
-from mcp.client.stdio import StdioServerParameters, stdio_client
-from mcp.types import TextContent
from mcp.shared.session import RequestResponder
from mcp.types import ClientResult, ServerNotification, ServerRequest, TextContent
from opentelemetry.trace import Tracer

from tests.collector import OTLPServer, Telemetry
from tests.whoami import TestClientResult, TestServerRequest, WhoamiResult


# The way MCP SDK creates async tasks means we need this to be called inline with the test,
# not as a fixture.
@asynccontextmanager
-async def mcp_client(transport: str, otlp_endpoint: str) -> AsyncGenerator[ClientSession, None]:
async def mcp_client(
transport: str, tracer: Tracer, otlp_endpoint: str
) -> AsyncGenerator[ClientSession, None]:
# Lazy import to get instrumented versions. Users will use opentelemetry-instrument or otherwise
# initialize instrumentation as early as possible and should not run into issues, but we control
# instrumentation through fixtures instead.
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client

async def message_handler(
message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception,
) -> None:
if not isinstance(message, RequestResponder) or message.request.root.method != "whoami":
return
with message as responder, tracer.start_as_current_span("whoami"):
await responder.respond(TestClientResult(WhoamiResult(name="OpenInference"))) # type: ignore

server_script = str(Path(__file__).parent / "mcpserver.py")
pythonpath = str(Path(__file__).parent.parent)
match transport:
case "stdio":
async with stdio_client(
StdioServerParameters(
command=sys.executable,
args=[server_script],
-                    env={"MCP_TRANSPORT": "stdio", "OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint},
env={
"MCP_TRANSPORT": "stdio",
"OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint,
"PYTHONPATH": pythonpath,
},
)
-            ) as (reader, writer), ClientSession(reader, writer) as client:
) as (reader, writer), ClientSession(
reader, writer, message_handler=message_handler
) as client:
client._receive_request_type = TestServerRequest
await client.initialize()
yield client
case "sse":
proc = await asyncio.create_subprocess_exec(
sys.executable,
server_script,
-                env={"MCP_TRANSPORT": "sse", "OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint},
env={
"MCP_TRANSPORT": "sse",
"OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint,
"PYTHONPATH": pythonpath,
},
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
@@ -50,7 +78,8 @@ async def mcp_client(transport: str, otlp_endpoint: str) -> AsyncGenerator[Clien
async with sse_client(f"http://localhost:{port}/sse") as (
reader,
writer,
-            ), ClientSession(reader, writer) as client:
), ClientSession(reader, writer, message_handler=message_handler) as client:
client._receive_request_type = TestServerRequest
await client.initialize()
yield client
break
@@ -63,27 +92,34 @@ async def mcp_client(transport: str, otlp_endpoint: str) -> AsyncGenerator[Clien
async def test_hello(
transport: str, tracer: Tracer, telemetry: Telemetry, otlp_collector: OTLPServer
) -> None:
-    async with mcp_client(transport, f"http://localhost:{otlp_collector.server_port}/") as client:
async with mcp_client(
transport, tracer, f"http://localhost:{otlp_collector.server_port}/"
) as client:
with tracer.start_as_current_span("root"):
tools_res = await client.list_tools()
assert len(tools_res.tools) == 1
assert tools_res.tools[0].name == "hello"
tool_res = await client.call_tool("hello")
content = tool_res.content[0]
assert isinstance(content, TextContent)
-            assert content.text == "World!"
assert content.text == "Hello OpenInference!"

assert len(telemetry.traces) == 2
for resource_spans in telemetry.traces:
assert len(resource_spans.scope_spans) == 1
for scope_spans in resource_spans.scope_spans:
-            assert len(scope_spans.spans) == 1
match scope_spans.scope.name:
case "mcp-test-client":
-                    client_span = scope_spans.spans[0]
for span in scope_spans.spans:
match span.name:
case "root":
root_span = span
case "whoami":
whoami_span = span
case "mcp-test-server":
server_span = scope_spans.spans[0]
-    assert client_span.name == "root"
assert root_span.name == "root"
assert server_span.name == "hello"
-    assert server_span.trace_id == client_span.trace_id
-    assert server_span.parent_span_id == client_span.span_id
assert whoami_span.name == "whoami"
assert server_span.trace_id == root_span.trace_id
assert server_span.parent_span_id == root_span.span_id
assert whoami_span.trace_id == root_span.trace_id
assert whoami_span.parent_span_id == server_span.span_id
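
For reference, the lazy imports and the "Make sure instrumentation is loaded before MCP." comment above assume instrumentation is initialized as early as possible in the host application. A minimal sketch of that setup follows; the console exporter and the FastMCP usage here are illustrative, not taken from this PR.

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

from openinference.instrumentation.mcp import MCPInstrumentor

tracer_provider = TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
MCPInstrumentor().instrument(tracer_provider=tracer_provider)

# Import mcp only after instrument() so the post-import hooks can wrap the transports.
from mcp.server.fastmcp import FastMCP  # noqa: E402

server = FastMCP()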