diff --git a/internal/mcpproxy/handlers.go b/internal/mcpproxy/handlers.go index 178821676d..569f73d9f9 100644 --- a/internal/mcpproxy/handlers.go +++ b/internal/mcpproxy/handlers.go @@ -234,6 +234,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { } err = m.handleSetLoggingLevel(ctx, s, w, msg, p, span) case "ping": + // Ping is intentionally not traced as it's a lightweight health check. err = m.handlePing(ctx, w, msg) case "prompts/list": p := &mcp.ListPromptsParams{} diff --git a/internal/tracing/mcp.go b/internal/tracing/mcp.go index d1afdc2323..3e66ec0a99 100644 --- a/internal/tracing/mcp.go +++ b/internal/tracing/mcp.go @@ -86,7 +86,9 @@ func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Requ parentCtx := m.propagator.Extract(ctx, mc) // Start the span with options appropriate for the semantic convention. - newCtx, span := m.tracer.Start(parentCtx, "mcp.request", trace.WithSpanKind(trace.SpanKindClient)) + // Convert method name to span name following mcp-go SDK patterns + spanName := getSpanName(req.Method) + newCtx, span := m.tracer.Start(parentCtx, spanName, trace.WithSpanKind(trace.SpanKindClient)) // Always inject trace context into the header mutation if provided. // This ensures trace propagation works even for unsampled spans. @@ -174,3 +176,37 @@ func (c metaMapCarrier) Keys() []string { return keys } + +// getSpanName converts MCP method names to span names following mcp-go SDK patterns. 
+func getSpanName(method string) string { + switch method { + case "initialize": + return "Initialize" + case "tools/list": + return "ListTools" + case "tools/call": + return "CallTool" + case "prompts/list": + return "ListPrompts" + case "prompts/get": + return "GetPrompt" + case "resources/list": + return "ListResources" + case "resources/read": + return "ReadResource" + case "resources/subscribe": + return "Subscribe" + case "resources/unsubscribe": + return "Unsubscribe" + case "resources/templates/list": + return "ListResourceTemplates" + case "logging/setLevel": + return "SetLoggingLevel" + case "completion/complete": + return "Complete" + case "ping": + return "Ping" + default: + return method + } +} diff --git a/internal/tracing/mcp_test.go b/internal/tracing/mcp_test.go index 40d1750cdf..d12a9a546e 100644 --- a/internal/tracing/mcp_test.go +++ b/internal/tracing/mcp_test.go @@ -6,6 +6,7 @@ package tracing import ( + "context" "testing" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -15,6 +16,7 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/sdk/trace/tracetest" + oteltrace "go.opentelemetry.io/otel/trace" ) func TestTracer_StartSpanAndInjectMeta(t *testing.T) { @@ -134,3 +136,106 @@ func Test_getMCPAttributes(t *testing.T) { }) } } + +func Test_getSpanName(t *testing.T) { + tests := []struct { + method string + expected string + }{ + {method: "initialize", expected: "Initialize"}, + {method: "tools/list", expected: "ListTools"}, + {method: "tools/call", expected: "CallTool"}, + {method: "prompts/list", expected: "ListPrompts"}, + {method: "prompts/get", expected: "GetPrompt"}, + {method: "resources/list", expected: "ListResources"}, + {method: "resources/read", expected: "ReadResource"}, + {method: "resources/subscribe", expected: "Subscribe"}, + {method: "resources/unsubscribe", expected: "Unsubscribe"}, + {method: "resources/templates/list", expected: "ListResourceTemplates"}, + 
{method: "logging/setLevel", expected: "SetLoggingLevel"}, + {method: "completion/complete", expected: "Complete"}, + {method: "ping", expected: "Ping"}, + } + + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + actual := getSpanName(tt.method) + require.Equal(t, tt.expected, actual) + }) + } +} + +func TestMCPTracer_SpanName(t *testing.T) { + tests := []struct { + name string + method string + params mcp.Params + expectedSpanName string + }{ + { + name: "tools/list", + method: "tools/list", + params: &mcp.ListToolsParams{}, + expectedSpanName: "ListTools", + }, + { + name: "tools/call", + method: "tools/call", + params: &mcp.CallToolParams{Name: "test-tool"}, + expectedSpanName: "CallTool", + }, + { + name: "prompts/list", + method: "prompts/list", + params: &mcp.ListPromptsParams{}, + expectedSpanName: "ListPrompts", + }, + { + name: "prompts/get", + method: "prompts/get", + params: &mcp.GetPromptParams{Name: "test-prompt"}, + expectedSpanName: "GetPrompt", + }, + { + name: "resources/list", + method: "resources/list", + params: &mcp.ListResourcesParams{}, + expectedSpanName: "ListResources", + }, + { + name: "resources/read", + method: "resources/read", + params: &mcp.ReadResourceParams{URI: "test://uri"}, + expectedSpanName: "ReadResource", + }, + { + name: "initialize", + method: "initialize", + params: &mcp.InitializeParams{}, + expectedSpanName: "Initialize", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + exporter := tracetest.NewInMemoryExporter() + tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) + + tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator()) + + reqID, _ := jsonrpc.MakeID("test-id") + req := &jsonrpc.Request{ID: reqID, Method: tt.method} + + span := tracer.StartSpanAndInjectMeta(context.Background(), req, tt.params) + require.NotNil(t, span) + span.EndSpan() + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + actualSpan := spans[0] + + require.Equal(t, 
tt.expectedSpanName, actualSpan.Name) + require.Equal(t, oteltrace.SpanKindClient, actualSpan.SpanKind) + }) + } +} diff --git a/tests/extproc/mcp/env.go b/tests/extproc/mcp/env.go index a46f87bf76..1f58ddfff9 100644 --- a/tests/extproc/mcp/env.go +++ b/tests/extproc/mcp/env.go @@ -17,7 +17,6 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/require" - v1 "go.opentelemetry.io/proto/otlp/trace/v1" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/testing/testotel" @@ -241,16 +240,18 @@ func (m *mcpEnv) newSession(t *testing.T) *mcpSession { ret.session, err = m.client.Connect(t.Context(), &mcp.StreamableClientTransport{Endpoint: m.baseURL}, nil) require.NoError(t, err) span := m.collector.TakeSpan() - require.Equal(t, "mcp.request", span.Name) - require.Equal(t, v1.Span_SPAN_KIND_CLIENT, span.Kind) - t.Log("created new MCP session with ID ", ret.session.ID(), ", first span: ", span.String()) - require.NotNil(t, span) + requireMCPSpan(t, span, "Initialize", map[string]string{ + "mcp.method.name": "initialize", + "mcp.client.name": "demo-http-client", + "mcp.client.title": "", + "mcp.client.version": "0.1.0", + }) // **NOTE*** Do not add any direct access to the in-memory server session. Otherwise, the tests will result in // not being able to run in end-to-end tests. The test code must solely operate through the client side sessions. 
t.Cleanup(func() { - require.NoError(t, ret.session.Close()) + _ = ret.session.Close() m.mux.Lock() defer m.mux.Unlock() delete(m.sessions, ret.session.ID()) diff --git a/tests/extproc/mcp/mcp_test.go b/tests/extproc/mcp/mcp_test.go index f6706a3fd9..9c7c5a83e6 100644 --- a/tests/extproc/mcp/mcp_test.go +++ b/tests/extproc/mcp/mcp_test.go @@ -15,12 +15,15 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/modelcontextprotocol/go-sdk/mcp" dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" "github.com/prometheus/common/model" "github.com/stretchr/testify/require" - v1 "go.opentelemetry.io/proto/otlp/common/v1" + commonv1 "go.opentelemetry.io/proto/otlp/common/v1" + tracev1 "go.opentelemetry.io/proto/otlp/trace/v1" + "google.golang.org/protobuf/testing/protocmp" "k8s.io/apimachinery/pkg/util/sets" "github.com/envoyproxy/ai-gateway/tests/internal/testmcp" @@ -31,76 +34,67 @@ const ( defaultMCPPath = "/mcp" ) +var tests = []struct { + name string + testFn func(t *testing.T, m *mcpEnv) +}{ + {name: "ListTools", testFn: testListTools}, + {name: "ToolCall", testFn: testToolCall}, + {name: "ToolCallDumbEcho", testFn: testToolCallDumbEcho}, + {name: "ToolCallError", testFn: testToolCallError}, + {name: "ToolCountDown", testFn: testToolCountDown}, + {name: "Ping", testFn: testPing}, + {name: "LoggingSetLevel", testFn: testLoggingSetLevel}, + {name: "ListPrompts", testFn: testListPrompts}, + {name: "CodeReviewPrompts", testFn: testCodeReviewPrompts}, + {name: "PromptChangeNotifications", testFn: testPromptChangeNotifications}, + {name: "ListResources", testFn: testListResources}, + {name: "ReadResource", testFn: testReadResource}, + {name: "ReadResourceNotFound", testFn: testReadResourceNotFound}, + {name: "ListResourceTemplates", testFn: testListResourceTemplates}, + {name: "ResourceSubscribe", testFn: testResourceSubscribe}, + {name: "ResourceListChangeNotifications", testFn: testResourceListChangeNotifications}, + {name: 
"ListRootsAndChangeRoots", testFn: testListRootsAndChangeRoots}, + {name: "SamplingCreateMessage", testFn: testSamplingCreateMessage}, + {name: "Elicit", testFn: testElicit}, + {name: "NotificationCancelled", testFn: testNotificationCancelled}, + {name: "Complete", testFn: testComplete}, +} + func TestMCP(t *testing.T) { - tests := []struct { - name string - testFn func(t *testing.T, m *mcpEnv) - }{ - {name: "ListTools", testFn: testListTools}, - {name: "ToolCall", testFn: testToolCall}, - {name: "ToolCallDumbEcho", testFn: testToolCallDumbEcho}, - {name: "ToolCallError", testFn: testToolCallError}, - {name: "ToolCountDown", testFn: testToolCountDown}, - {name: "Ping", testFn: testPing}, - {name: "LoggingSetLevel", testFn: testLoggingSetLevel}, - {name: "ListPrompts", testFn: testListPrompts}, - {name: "CodeReviewPrompts", testFn: testCodeReviewPrompts}, - {name: "PromptChangeNotifications", testFn: testPromptChangeNotifications}, - {name: "ListResources", testFn: testListResources}, - {name: "ReadResource", testFn: testReadResource}, - {name: "ReadResourceNotFound", testFn: testReadResourceNotFound}, - {name: "ListResourceTemplates", testFn: testListResourceTemplates}, - {name: "ResourceSubscribe", testFn: testResourceSubscribe}, - {name: "ResourceListChangeNotifications", testFn: testResourceListChangeNotifications}, - {name: "ListRootsAndChangeRoots", testFn: testListRootsAndChangeRoots}, - {name: "SamplingCreateMessage", testFn: testSamplingCreateMessage}, - {name: "Elicit", testFn: testElicit}, - {name: "NotificationCancelled", testFn: testNotificationCancelled}, - {name: "Complete", testFn: testComplete}, + env := requireNewMCPEnv(t, false, 1200*time.Second, defaultMCPPath) + for _, tc := range tests { + t.Run(tc.name+"/force_json=false", func(t *testing.T) { + tc.testFn(t, env) + }) } - _ = tests - - t.Run("default/force_json=true", func(t *testing.T) { - env := requireNewMCPEnv(t, true, 1200*time.Second, defaultMCPPath) - for _, tc := range tests { - 
t.Run(tc.name+"/force_json=true", func(t *testing.T) { - tc.testFn(t, env) - }) - } - }) - t.Run("default/force_json=false", func(t *testing.T) { - env := requireNewMCPEnv(t, false, 1200*time.Second, defaultMCPPath) - for _, tc := range tests { - t.Run(tc.name+"/force_json=false", func(t *testing.T) { - tc.testFn(t, env) - }) - } - }) - t.Run("custom_write_timeout", func(t *testing.T) { - env := requireNewMCPEnv(t, false, 2*time.Second, defaultMCPPath) - for _, tc := range tests { - t.Run(tc.name+"/custom_write_timeout", func(t *testing.T) { - tc.testFn(t, env) - }) - } - }) - t.Run("/mcp/yet/another/path", func(t *testing.T) { - env := requireNewMCPEnv(t, false, 1200*time.Second, "/mcp/yet/another/path") - t.Run("call", func(t *testing.T) { - testToolCallDumbEcho(t, env) +} + +func TestMCP_forceJSONResponse(t *testing.T) { + env := requireNewMCPEnv(t, true, 1200*time.Second, defaultMCPPath) + for _, tc := range tests { + t.Run(tc.name+"/force_json=true", func(t *testing.T) { + tc.testFn(t, env) }) - t.Run("list", func(t *testing.T) { - testListToolsRequireOnlyDumb(t, env) + } +} + +func TestMCP_customWriteTimeout(t *testing.T) { + env := requireNewMCPEnv(t, false, 2*time.Second, defaultMCPPath) + for _, tc := range tests { + t.Run(tc.name+"/custom_write_timeout", func(t *testing.T) { + tc.testFn(t, env) }) + } +} + +func TestMCP_differentPath(t *testing.T) { + env := requireNewMCPEnv(t, false, 1200*time.Second, "/mcp/yet/another/path") + t.Run("call", func(t *testing.T) { + testToolCallDumbEcho(t, env) }) - t.Run("/awesome-path", func(t *testing.T) { - env := requireNewMCPEnv(t, false, 1200*time.Second, "/awesome-path") - t.Run("call", func(t *testing.T) { - testToolCallDumbEcho(t, env) - }) - t.Run("list", func(t *testing.T) { - testListToolsRequireOnlyDumb(t, env) - }) + t.Run("list", func(t *testing.T) { + testListToolsRequireOnlyDumb(t, env) }) } @@ -112,10 +106,9 @@ func testListTools(t *testing.T, m *mcpEnv) { for _, tool := range tools.Tools { names = 
append(names, tool.Name) } - span := m.collector.TakeSpan() - requireKeyValue(t, span.Attributes, "mcp.method.name", stringAnyValue("tools/list")) - // equal to the count of backends. - require.Len(t, span.Events, 2) + requireMCPSpan(t, m.collector.TakeSpan(), "ListTools", map[string]string{ + "mcp.method.name": "tools/list", + }) // Hardcode names rather than using testmcp.*Tool.Tool.Name because some // tools are created dynamically (e.g. add_prompt). @@ -162,10 +155,7 @@ func testToolCall(t *testing.T, m *mcpEnv) { require.Len(t, res.Content, 1) require.IsType(t, &mcp.TextContent{}, res.Content[0]) require.Equal(t, helloText, res.Content[0].(*mcp.TextContent).Text) - span := m.collector.TakeSpan() - requireKeyValue(t, span.Attributes, "mcp.method.name", stringAnyValue("tools/call")) - require.Len(t, span.Events, 1) - require.Equal(t, "route to backend", span.Events[0].Name) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolEcho.Tool.Name, false, "") res, err = s.session.CallTool(t.Context(), &mcp.CallToolParams{ Name: "default-mcp-backend__" + testmcp.ToolSum.Tool.Name, @@ -176,8 +166,7 @@ func testToolCall(t *testing.T, m *mcpEnv) { require.Len(t, res.Content, 1) require.IsType(t, &mcp.TextContent{}, res.Content[0]) require.Equal(t, "42", res.Content[0].(*mcp.TextContent).Text) - span = m.collector.TakeSpan() - requireKeyValue(t, span.Attributes, "mcp.method.name", stringAnyValue("tools/call")) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolSum.Tool.Name, false, "") } func testToolCallDumbEcho(t *testing.T, m *mcpEnv) { @@ -193,6 +182,7 @@ func testToolCallDumbEcho(t *testing.T, m *mcpEnv) { require.Len(t, res.Content, 1) require.IsType(t, &mcp.TextContent{}, res.Content[0]) require.Equal(t, "dumb echo: "+helloText, res.Content[0].(*mcp.TextContent).Text) + requireToolSpan(t, m.collector.TakeSpan(), "dumb-mcp-backend", testmcp.ToolDumbEcho.Tool.Name, false, "") } func testToolCallError(t *testing.T, 
m *mcpEnv) { @@ -210,6 +200,7 @@ func testToolCallError(t *testing.T, m *mcpEnv) { require.Len(t, res.Content, 1) require.IsType(t, &mcp.TextContent{}, res.Content[0]) require.Equal(t, errTool, res.Content[0].(*mcp.TextContent).Text) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolError.Tool.Name, false, "") }) // Protocol errors or tool invocation errors (such as validation errors) are @@ -222,6 +213,7 @@ func testToolCallError(t *testing.T, m *mcpEnv) { require.Error(t, err) require.Nil(t, res) require.Contains(t, err.Error(), "minLength") + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolError.Tool.Name, false, "") }) } @@ -243,6 +235,10 @@ func testToolCountDown(t *testing.T, m *mcpEnv) { Level: "error", }) require.NoError(t, err) + requireMCPSpan(t, m.collector.TakeSpan(), "SetLoggingLevel", map[string]string{ + "mcp.method.name": "logging/setLevel", + "mcp.logging.level": "error", + }) var res *mcp.CallToolResult callErrorCh := make(chan error, 1) @@ -255,6 +251,7 @@ func testToolCountDown(t *testing.T, m *mcpEnv) { callErrorCh <- err doneBool.Store(true) }() + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolCountDown.Tool.Name, false, "") // we cannot assume the order of notifications, so we use a set to track them. counts := sets.New[int]() @@ -280,6 +277,7 @@ func testToolCountDown(t *testing.T, m *mcpEnv) { require.Equal(t, expectedMsg, notif.Params.Message) require.Contains(t, counts.UnsortedList(), n) counts.Delete(n) + // Progress notifications from server-to-client do not create spans in the gateway. } // we cannot assume the order of logging messages, so we just check if we received one. 
@@ -297,6 +295,7 @@ func testToolCountDown(t *testing.T, m *mcpEnv) { require.Equal(t, mcp.LoggingLevel("error"), param.Level) require.Contains(t, param.Data, "count down: ") expectedLogs.Delete(param.Data.(string)) + // Logging messages from server-to-client do not create spans in the gateway. } require.Empty(t, expectedLogs) @@ -318,6 +317,7 @@ func testPing(t *testing.T, m *mcpEnv) { err := s.session.Ping(t.Context(), &mcp.PingParams{}) require.NoError(t, err) + // Pings do not create spans in the gateway. } func testLoggingSetLevel(t *testing.T, m *mcpEnv) { @@ -326,6 +326,10 @@ func testLoggingSetLevel(t *testing.T, m *mcpEnv) { Level: "debug", }) require.NoError(t, err) + requireMCPSpan(t, m.collector.TakeSpan(), "SetLoggingLevel", map[string]string{ + "mcp.method.name": "logging/setLevel", + "mcp.logging.level": "debug", + }) } func testListPrompts(t *testing.T, m *mcpEnv) { @@ -335,6 +339,9 @@ func testListPrompts(t *testing.T, m *mcpEnv) { require.Len(t, list.Prompts, 1) require.Equal(t, defaultMCPBackendResourcePrefix+testmcp.CodeReviewPrompt.Name, list.Prompts[0].Name) require.Equal(t, testmcp.CodeReviewPrompt.Description, list.Prompts[0].Description) + requireMCPSpan(t, m.collector.TakeSpan(), "ListPrompts", map[string]string{ + "mcp.method.name": "prompts/list", + }) } func testCodeReviewPrompts(t *testing.T, m *mcpEnv) { @@ -349,6 +356,10 @@ func testCodeReviewPrompts(t *testing.T, m *mcpEnv) { require.Equal(t, mcp.Role("user"), resp.Messages[0].Role) require.IsType(t, &mcp.TextContent{}, resp.Messages[0].Content) require.Contains(t, resp.Messages[0].Content.(*mcp.TextContent).Text, "Please review the following code: 1+1") + requireMCPSpan(t, m.collector.TakeSpan(), "GetPrompt", map[string]string{ + "mcp.method.name": "prompts/get", + "mcp.prompt.name": defaultMCPBackendResourcePrefix + "code_review", + }) } func testPromptChangeNotifications(t *testing.T, m *mcpEnv) { @@ -356,12 +367,16 @@ func testPromptChangeNotifications(t *testing.T, m *mcpEnv) { 
list, err := s.session.ListPrompts(t.Context(), &mcp.ListPromptsParams{}) require.NoError(t, err) require.Len(t, list.Prompts, 1) + requireMCPSpan(t, m.collector.TakeSpan(), "ListPrompts", map[string]string{ + "mcp.method.name": "prompts/list", + }) res, err := s.session.CallTool(t.Context(), &mcp.CallToolParams{ Name: "default-mcp-backend__" + testmcp.ToolAddPromptName, }) require.NoError(t, err) require.False(t, res.IsError) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolAddPromptName, false, "") // Wait for the notification. var req *mcp.PromptListChangedRequest @@ -378,6 +393,9 @@ func testPromptChangeNotifications(t *testing.T, m *mcpEnv) { list, err = s.session.ListPrompts(t.Context(), &mcp.ListPromptsParams{}) require.NoError(t, err) require.Len(t, list.Prompts, 2) + requireMCPSpan(t, m.collector.TakeSpan(), "ListPrompts", map[string]string{ + "mcp.method.name": "prompts/list", + }) } func testListResources(t *testing.T, m *mcpEnv) { @@ -387,6 +405,9 @@ func testListResources(t *testing.T, m *mcpEnv) { require.Len(t, list.Resources, 1) require.Equal(t, defaultMCPBackendResourcePrefix+testmcp.DummyResource.Name, list.Resources[0].Name) require.Equal(t, testmcp.DummyResource.Description, list.Resources[0].Description) + requireMCPSpan(t, m.collector.TakeSpan(), "ListResources", map[string]string{ + "mcp.method.name": "resources/list", + }) } func testReadResource(t *testing.T, m *mcpEnv) { @@ -399,6 +420,10 @@ func testReadResource(t *testing.T, m *mcpEnv) { require.Equal(t, testmcp.DummyResource.URI, r.Contents[0].URI) require.Equal(t, testmcp.DummyResource.MIMEType, r.Contents[0].MIMEType) require.Equal(t, "dummy", string(r.Contents[0].Blob)) + requireMCPSpan(t, m.collector.TakeSpan(), "ReadResource", map[string]string{ + "mcp.method.name": "resources/read", + "mcp.resource.uri": defaultMCPBackendResourcePrefix + "file:///dummy.txt", + }) } func testReadResourceNotFound(t *testing.T, m *mcpEnv) { @@ -409,6 +434,10 @@ func 
testReadResourceNotFound(t *testing.T, m *mcpEnv) { require.Error(t, err) require.ErrorContains(t, err, "Resource not found") require.Nil(t, r) + requireMCPSpan(t, m.collector.TakeSpan(), "ReadResource", map[string]string{ + "mcp.method.name": "resources/read", + "mcp.resource.uri": defaultMCPBackendResourcePrefix + "file:///notfound.txt", + }) } func testListResourceTemplates(t *testing.T, m *mcpEnv) { @@ -418,6 +447,9 @@ func testListResourceTemplates(t *testing.T, m *mcpEnv) { require.Len(t, list.ResourceTemplates, 1) require.Equal(t, defaultMCPBackendResourcePrefix+testmcp.DummyResourceTemplate.Name, list.ResourceTemplates[0].Name) require.Equal(t, testmcp.DummyResourceTemplate.Description, list.ResourceTemplates[0].Description) + requireMCPSpan(t, m.collector.TakeSpan(), "ListResourceTemplates", map[string]string{ + "mcp.method.name": "resources/templates/list", + }) } func testResourceSubscribe(t *testing.T, m *mcpEnv) { @@ -427,11 +459,18 @@ func testResourceSubscribe(t *testing.T, m *mcpEnv) { require.Len(t, list.Resources, 1) require.Equal(t, defaultMCPBackendResourcePrefix+testmcp.DummyResource.Name, list.Resources[0].Name) require.Equal(t, testmcp.DummyResource.Description, list.Resources[0].Description) + requireMCPSpan(t, m.collector.TakeSpan(), "ListResources", map[string]string{ + "mcp.method.name": "resources/list", + }) err = s.session.Subscribe(t.Context(), &mcp.SubscribeParams{ URI: defaultMCPBackendResourcePrefix + list.Resources[0].URI, }) require.NoError(t, err) + requireMCPSpan(t, m.collector.TakeSpan(), "Subscribe", map[string]string{ + "mcp.method.name": "resources/subscribe", + "mcp.resource.uri": defaultMCPBackendResourcePrefix + list.Resources[0].URI, + }) // Update the resource. 
res, err := s.session.CallTool(t.Context(), &mcp.CallToolParams{ @@ -440,8 +479,9 @@ func testResourceSubscribe(t *testing.T, m *mcpEnv) { }) require.NoError(t, err) require.False(t, res.IsError) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolResourceUpdateNotificationName, false, "") // Wait for the subscribe notification. - requireEventuallyNotificationCountMessages(t, s, "subscribe: 1") + requireEventuallyNotificationCountMessages(t, s, m, "subscribe: 1") // Wait for the notification. var req *mcp.ResourceUpdatedNotificationRequest @@ -458,8 +498,12 @@ func testResourceSubscribe(t *testing.T, m *mcpEnv) { URI: defaultMCPBackendResourcePrefix + list.Resources[0].URI, }) require.NoError(t, err) + requireMCPSpan(t, m.collector.TakeSpan(), "Unsubscribe", map[string]string{ + "mcp.method.name": "resources/unsubscribe", + "mcp.resource.uri": defaultMCPBackendResourcePrefix + list.Resources[0].URI, + }) // Wait for the unsubscribe notification. - requireEventuallyNotificationCountMessages(t, s, "unsubscribe: 1") + requireEventuallyNotificationCountMessages(t, s, m, "unsubscribe: 1") res, err = s.session.CallTool(t.Context(), &mcp.CallToolParams{ Name: "default-mcp-backend__" + testmcp.ToolResourceUpdateNotificationName, @@ -467,6 +511,7 @@ func testResourceSubscribe(t *testing.T, m *mcpEnv) { }) require.NoError(t, err) require.False(t, res.IsError) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolResourceUpdateNotificationName, false, "") // Wait for the notification. 
select { @@ -482,6 +527,9 @@ func testResourceListChangeNotifications(t *testing.T, m *mcpEnv) { list, err := s.session.ListResources(t.Context(), &mcp.ListResourcesParams{}) require.NoError(t, err) require.Len(t, list.Resources, 1) + requireMCPSpan(t, m.collector.TakeSpan(), "ListResources", map[string]string{ + "mcp.method.name": "resources/list", + }) res, err := s.session.CallTool(t.Context(), &mcp.CallToolParams{ Name: "default-mcp-backend__" + testmcp.ToolAddOrDeleteDummyResourceName, @@ -489,6 +537,8 @@ func testResourceListChangeNotifications(t *testing.T, m *mcpEnv) { }) require.NoError(t, err) require.False(t, res.IsError) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolAddOrDeleteDummyResourceName, false, "") + // Clean up, otherwise it will affect ListResources in other tests. t.Cleanup(func() { res, err = s.session.CallTool(context.Background(), &mcp.CallToolParams{ @@ -497,6 +547,8 @@ func testResourceListChangeNotifications(t *testing.T, m *mcpEnv) { }) require.NoError(t, err) require.False(t, res.IsError) + // Consume the span from the cleanup operation. + _ = m.collector.TakeSpan() }) // Wait for the notification. 
@@ -514,6 +566,9 @@ func testResourceListChangeNotifications(t *testing.T, m *mcpEnv) { list, err = s.session.ListResources(t.Context(), &mcp.ListResourcesParams{}) require.NoError(t, err) require.Len(t, list.Resources, 2) + requireMCPSpan(t, m.collector.TakeSpan(), "ListResources", map[string]string{ + "mcp.method.name": "resources/list", + }) } func testListRootsAndChangeRoots(t *testing.T, m *mcpEnv) { @@ -527,12 +582,16 @@ func testListRootsAndChangeRoots(t *testing.T, m *mcpEnv) { require.Len(t, res.Content, 1) require.IsType(t, &mcp.TextContent{}, res.Content[0]) require.Contains(t, res.Content[0].(*mcp.TextContent).Text, fmt.Sprintf("root %q found", mcpDefaultRootName)) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolContainsRootTool.Tool.Name, false, "") m.mux.Lock() defer m.mux.Unlock() // This will trigger a notifications/roots/list_changed notification from client to server. m.client.RemoveRoots(mcpDefaultRootURI) - requireEventuallyNotificationCountMessages(t, s, "roots_list_changed: 1") + // Assert the span from the client-to-server notification + requireMCPSpan(t, m.collector.TakeSpan(), "notifications/roots/list_changed", map[string]string{ + "mcp.method.name": "notifications/roots/list_changed", + }) // Now the default root should not be found. 
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) @@ -546,9 +605,12 @@ func testListRootsAndChangeRoots(t *testing.T, m *mcpEnv) { require.Len(t, res.Content, 1) require.IsType(t, &mcp.TextContent{}, res.Content[0]) require.Contains(t, res.Content[0].(*mcp.TextContent).Text, fmt.Sprintf("root %q not found", mcpDefaultRootName)) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolContainsRootTool.Tool.Name, false, "") + + requireEventuallyNotificationCountMessages(t, s, m, "roots_list_changed: 1") requireMetricGreaterThan(t, m, "mcp_method_count_total", map[string]string{ - "mcp_method_name": "notifications/resources/list_changed", + "mcp_method_name": "notifications/roots/list_changed", }, 0) requireMetricGreaterThan(t, m, "mcp_method_count_total", map[string]string{ @@ -567,6 +629,7 @@ func testSamplingCreateMessage(t *testing.T, m *mcpEnv) { }) require.NoError(t, err) require.False(t, res.IsError) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolCreateMessage.Tool.Name, false, "") // Wait for the request from the server. var req *mcp.CreateMessageRequest @@ -586,6 +649,11 @@ func testSamplingCreateMessage(t *testing.T, m *mcpEnv) { require.NotNil(t, req.Params) require.IsTypef(t, &mcp.CreateMessageParams{}, req.Params, "expected CreateMessageParams, got %T", req.Params) + // The gateway encodes progress tokens as: base64(original)__{type}__{backend} + // where type is 's' for string, 'i' for int, 'f' for float. + // Original token "sampling-foo" becomes "c2FtcGxpbmctZm9v__s__default-mcp-backend".
+ requireNotificationProgressSpan(t, m.collector.TakeSpan(), "foo", "c2FtcGxpbmctZm9v__s__default-mcp-backend") + requireMetricGreaterThan(t, m, "mcp_progress_notifications_total", nil, 0) } @@ -596,6 +664,7 @@ func testElicit(t *testing.T, m *mcpEnv) { }) require.NoError(t, err) require.False(t, res.IsError) + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolElicitEmail.Tool.Name, false, "") // Wait for the request from the server. var req *mcp.ElicitRequest @@ -607,6 +676,8 @@ func testElicit(t *testing.T, m *mcpEnv) { require.NotNil(t, req) require.NotNil(t, req.Params) require.IsTypef(t, &mcp.ElicitParams{}, req.Params, "expected ElicitParams, got %T", req.Params) + // Elicit requests from server-to-client do not create spans in the gateway. + // These are server-initiated requests handled without tracing. } func testNotificationCancelled(t *testing.T, m *mcpEnv) { @@ -625,6 +696,11 @@ func testNotificationCancelled(t *testing.T, m *mcpEnv) { select { case <-time.After(time.Microsecond * 500): cancel() + // Wait for the goroutine to complete so its span doesn't leak to the next test. + <-doneCh + // Consume the CallTool span from the cancelled operation. + // This span will have an exception event due to cancellation. + requireToolSpan(t, m.collector.TakeSpan(), "default-mcp-backend", testmcp.ToolDelay.Tool.Name, false, "context canceled") // we cannot do the test in TearDownSuite, // we need to wait a while for notifications/cancelled, // metric won't be updated if the test exits too early. 
@@ -651,6 +727,11 @@ func testComplete(t *testing.T, m *mcpEnv) { require.NoError(t, err) completionValues := []string{"python", "pytorch", "pyside"} require.Equal(t, completionValues, result.Completion.Values) + requireMCPSpan(t, m.collector.TakeSpan(), "Complete", map[string]string{ + "mcp.method.name": "completion/complete", + "mcp.complete.argument.name": "language", + "mcp.complete.argument.value": "py", + }) } var metricParser = expfmt.NewTextParser(model.UTF8Validation) @@ -733,19 +814,22 @@ func requireMetricGreaterThan(t *testing.T, m *mcpEnv, metricName string, metric }, retrieveMetricsTime, retrieveMetricsTick) } -func requireEventuallyNotificationCountMessages(t *testing.T, s *mcpSession, expected string) { +func requireEventuallyNotificationCountMessages(t *testing.T, s *mcpSession, m *mcpEnv, expected string) { require.Eventually(t, func() bool { res, err := s.session.CallTool(t.Context(), &mcp.CallToolParams{ Name: "default-mcp-backend__" + testmcp.ToolNotificationCountsName, }) if err != nil { t.Log("error calling tool: ", err) + _ = m.collector.TakeSpan() // Drain the span from the failed call return false } if res.IsError { t.Log("tool returned error: ", res.Content) + _ = m.collector.TakeSpan() // Drain the span from the error call return false } + _ = m.collector.TakeSpan() // Drain the span - we're polling so we don't assert intermediate spans for _, content := range res.Content { txt, ok := content.(*mcp.TextContent) @@ -760,22 +844,102 @@ func requireEventuallyNotificationCountMessages(t *testing.T, s *mcpSession, exp }, time.Second*3, time.Millisecond*500) } -func stringAnyValue(s string) *v1.AnyValue { - return &v1.AnyValue{ - Value: &v1.AnyValue_StringValue{ - StringValue: s, - }, +// requireMCPSpan verifies that a span has the expected name and attributes. +// It combines base MCP attributes with additional attributes provided by the caller, +// then compares the entire attribute map against the span's attributes. 
+func requireMCPSpan(t *testing.T, span *tracev1.Span, expectedName string, additionalAttrs map[string]string) { + t.Helper() + require.NotNil(t, span, "expected span but got nil") + require.Equalf(t, expectedName, span.Name, "span name mismatch, full span: %s", span.String()) + + // Extract all attributes from span into map[string]string + attrsFromSpan := make(map[string]string) + for _, attr := range span.Attributes { + if attr.Value.Value != nil { + if _, ok := attr.Value.Value.(*commonv1.AnyValue_StringValue); ok { + attrsFromSpan[attr.Key] = attr.Value.GetStringValue() + } + } } + + // Combine base attributes with additional attributes + combined := make(map[string]string) + // Base attributes that are always present + combined["mcp.protocol.version"] = "2025-06-18" + combined["mcp.transport"] = "http" + // mcp.request.id is dynamic, so we copy it from span + if reqID, ok := attrsFromSpan["mcp.request.id"]; ok { + combined["mcp.request.id"] = reqID + } + // Add additional attributes provided by caller + for k, v := range additionalAttrs { + combined[k] = v + } + + require.Equalf(t, combined, attrsFromSpan, "span attributes mismatch, full span: %s", span.String()) } -func requireKeyValue(t *testing.T, attrs []*v1.KeyValue, key string, val *v1.AnyValue) { - found := false - for _, a := range attrs { - if a.Key == key { - found = a.Value.String() == val.String() - break +// requireToolSpan verifies a CallTool span contains the expected attributes and events. +// +// - backendName: the MCP backend name (e.g. "default-mcp-backend") +// - toolName: the unprefixed tool name (e.g. "echo"). 
The function will verify the +// span contains the full prefixed name: backendName + "__" + toolName +// - isNew: whether this is a new backend session (mcp.session.new attribute) +// - exceptionMessage: expected exception message substring, or empty string +func requireToolSpan(t *testing.T, span *tracev1.Span, backendName string, toolName string, isNew bool, exceptionMessage string) { + t.Helper() + + requireMCPSpan(t, span, "CallTool", map[string]string{ + "mcp.method.name": "tools/call", + "mcp.tool.name": backendName + "__" + toolName, + }) + // Verify the "route to backend" event and optionally exception event + expectedEvents := []*tracev1.Span_Event{ + { + Name: "route to backend", + Attributes: []*commonv1.KeyValue{ + {Key: "mcp.backend.name", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: backendName}}}, + {Key: "mcp.session.new", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_BoolValue{BoolValue: isNew}}}, + }, + }, + } + if exceptionMessage != "" { + expectedEvents = append(expectedEvents, &tracev1.Span_Event{ + Name: "exception", + Attributes: []*commonv1.KeyValue{ + {Key: "exception.type", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "internal_error"}}}, + {Key: "exception.message", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: exceptionMessage}}}, + }, + }) + } + + // Normalize span attributes for comparison + for _, event := range span.Events { + event.TimeUnixNano = 0 + var filteredAttrs []*commonv1.KeyValue + for _, attr := range event.Attributes { + if attr.Key == "mcp.session.id" { + continue // the call site won't know the backend session ID + } + if attr.Key == "exception.message" && exceptionMessage != "" { + // exception messages are substring match due to IPs, etc. 
+ actualMsg := attr.Value.GetStringValue() + require.Contains(t, actualMsg, exceptionMessage) + attr.Value = &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: exceptionMessage}} + } + filteredAttrs = append(filteredAttrs, attr) } + event.Attributes = filteredAttrs } + require.Empty(t, cmp.Diff(expectedEvents, span.Events, protocmp.Transform())) +} - require.Truef(t, found, "%s=%s not found: %v", key, val.String(), attrs) +// requireNotificationProgressSpan verifies a notifications/progress span with the expected attributes. +func requireNotificationProgressSpan(t *testing.T, span *tracev1.Span, message string, token string) { + t.Helper() + requireMCPSpan(t, span, "notifications/progress", map[string]string{ + "mcp.method.name": "notifications/progress", + "mcp.notifications.progress.message": message, + "mcp.notifications.progress.token": token, + }) }