diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 78f99e20..e99aa2ae 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -96,17 +96,17 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes } var rpcRequest = message as JsonRpcRequest; - JsonRpcMessage? rpcResponseCandidate = null; + JsonRpcMessageWithId? rpcResponseOrError = null; if (response.Content.Headers.ContentType?.MediaType == "application/json") { var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - rpcResponseCandidate = await ProcessMessageAsync(responseContent, cancellationToken).ConfigureAwait(false); + rpcResponseOrError = await ProcessMessageAsync(responseContent, rpcRequest, cancellationToken).ConfigureAwait(false); } else if (response.Content.Headers.ContentType?.MediaType == "text/event-stream") { using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken); - rpcResponseCandidate = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false); + rpcResponseOrError = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false); } if (rpcRequest is null) @@ -114,12 +114,12 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes return response; } - if (rpcResponseCandidate is not JsonRpcMessageWithId messageWithId || messageWithId.Id != rpcRequest.Id) + if (rpcResponseOrError is null) { throw new McpException($"Streamable HTTP POST response completed without a reply to request with ID: {rpcRequest.Id}"); } - if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseCandidate is JsonRpcResponse) + if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseOrError is JsonRpcResponse) { // We've successfully initialized! Copy session-id and start GET request if any. if (response.Headers.TryGetValues("mcp-session-id", out var sessionIdValues)) @@ -193,20 +193,20 @@ private async Task ReceiveUnsolicitedMessagesAsync() continue; } - var message = await ProcessMessageAsync(sseEvent.Data, cancellationToken).ConfigureAwait(false); + var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false); - // The server SHOULD end the response here anyway, but we won't leave it to chance. This transport makes + // The server SHOULD end the HTTP response body here anyway, but we won't leave it to chance. This transport makes // a GET request for any notifications that might need to be sent after the completion of each POST. - if (message is JsonRpcMessageWithId messageWithId && relatedRpcRequest?.Id == messageWithId.Id) + if (rpcResponseOrError is not null) { - return messageWithId; + return rpcResponseOrError; } } return null; } - private async Task ProcessMessageAsync(string data, CancellationToken cancellationToken) + private async Task ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) { try { @@ -218,7 +218,12 @@ private async Task ReceiveUnsolicitedMessagesAsync() } await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - return message; + if (message is JsonRpcResponse or JsonRpcError && + message is JsonRpcMessageWithId rpcResponseOrError && + rpcResponseOrError.Id == relatedRpcRequest?.Id) + { + return rpcResponseOrError; + } } catch (JsonException ex) { diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 27b8ccdb..174cf9bd 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -60,7 +60,7 @@ public async ValueTask DisposeAsync() { yield return message; - if (message.Data is JsonRpcMessageWithId response && response.Id == _pendingRequest) + if (message.Data is JsonRpcResponse or JsonRpcError && ((JsonRpcMessageWithId)message.Data).Id == _pendingRequest) { // Complete the SSE response stream now that all pending requests have been processed. break; diff --git a/tests/Common/Utils/LoggedTest.cs b/tests/Common/Utils/LoggedTest.cs index a2e9e2ba..646f248a 100644 --- a/tests/Common/Utils/LoggedTest.cs +++ b/tests/Common/Utils/LoggedTest.cs @@ -12,16 +12,16 @@ public LoggedTest(ITestOutputHelper testOutputHelper) { CurrentTestOutputHelper = testOutputHelper, }; - LoggerProvider = new XunitLoggerProvider(_delegatingTestOutputHelper); + XunitLoggerProvider = new XunitLoggerProvider(_delegatingTestOutputHelper); LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => { - builder.AddProvider(LoggerProvider); + builder.AddProvider(XunitLoggerProvider); }); } public ITestOutputHelper TestOutputHelper => _delegatingTestOutputHelper; - public ILoggerFactory LoggerFactory { get; } - public ILoggerProvider LoggerProvider { get; } + public ILoggerFactory LoggerFactory { get; set; } + public ILoggerProvider XunitLoggerProvider { get; } public virtual void Dispose() { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index ef4277a9..0a423850 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -52,7 +52,7 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat await app.StartAsync(TestContext.Current.CancellationToken); - var mcpClient = await ConnectAsync(requestPath); + await using var mcpClient = await ConnectAsync(requestPath); Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index be8763ae..b2f65f82 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -31,7 +31,7 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat await app.StartAsync(TestContext.Current.CancellationToken); - var mcpClient = await ConnectAsync(requestPath); + await using var mcpClient = await ConnectAsync(requestPath); Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); } @@ -135,7 +135,7 @@ public async Task SseMode_Works_WithSseEndpoint() await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectAsync(options: new() + await using var mcpClient = await ConnectAsync(transportOptions: new() { Endpoint = new Uri("http://localhost/sse"), TransportMode = HttpTransportMode.Sse diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 6d153220..61fc54fa 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -1,9 +1,12 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; using System.ComponentModel; using System.Net; using System.Security.Claims; @@ -20,18 +23,21 @@ protected void ConfigureStateless(HttpServerTransportOptions options) options.Stateless = Stateless; } - protected async Task ConnectAsync(string? path = null, SseClientTransportOptions? options = null) + protected async Task ConnectAsync( + string? path = null, + SseClientTransportOptions? transportOptions = null, + McpClientOptions? clientOptions = null) { // Default behavior when no options are provided path ??= UseStreamableHttp ? "/" : "/sse"; - await using var transport = new SseClientTransport(options ?? new SseClientTransportOptions() + await using var transport = new SseClientTransport(transportOptions ?? new SseClientTransportOptions() { Endpoint = new Uri($"http://localhost{path}"), TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, }, HttpClient, LoggerFactory); - return await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + return await McpClientFactory.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken); } [Fact] @@ -71,7 +77,7 @@ IHttpContextAccessor is not currently supported with non-stateless Streamable HT await app.StartAsync(TestContext.Current.CancellationToken); - var mcpClient = await ConnectAsync(); + await using var mcpClient = await ConnectAsync(); var response = await mcpClient.CallToolAsync( "EchoWithUserName", @@ -111,13 +117,90 @@ public async Task Messages_FromNewUser_AreRejected() Assert.Equal(HttpStatusCode.Forbidden, httpRequestException.StatusCode); } - protected ClaimsPrincipal CreateUser(string name) + [Fact] + public async Task Sampling_DoesNotCloseStream_Prematurely() + { + Assert.SkipWhen(Stateless, "Sampling is not supported in stateless mode."); + + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); + + var mockLoggerProvider = new MockLoggerProvider(); + Builder.Logging.AddProvider(mockLoggerProvider); + Builder.Logging.SetMinimumLevel(LogLevel.Debug); + + await using var app = Builder.Build(); + + // Reset the LoggerFactory used by the client to use the MockLoggerProvider as well. + LoggerFactory = app.Services.GetRequiredService(); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var sampleCount = 0; + var clientOptions = new McpClientOptions + { + Capabilities = new() + { + Sampling = new() + { + SamplingHandler = async (parameters, _, _) => + { + Assert.NotNull(parameters?.Messages); + var message = Assert.Single(parameters.Messages); + Assert.Equal(Role.User, message.Role); + Assert.Equal("text", message.Content.Type); + Assert.Equal("Test prompt for sampling", message.Content.Text); + + sampleCount++; + return new CreateMessageResult + { + Model = "test-model", + Role = Role.Assistant, + Content = new Content + { + Type = "text", + Text = "Sampling response from client" + } + }; + }, + }, + }, + }; + + await using var mcpClient = await ConnectAsync(clientOptions: clientOptions); + + var result = await mcpClient.CallToolAsync("sampling-tool", new Dictionary + { + ["prompt"] = "Test prompt for sampling" + }, cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(result); + Assert.False(result.IsError); + var textContent = Assert.Single(result.Content); + Assert.Equal("text", textContent.Type); + Assert.Equal("Sampling completed successfully. Client responded: Sampling response from client", textContent.Text); + + Assert.Equal(2, sampleCount); + + // Verify that the tool call and the sampling request both used the same ID to ensure we cover against regressions. + // https://github.com/modelcontextprotocol/csharp-sdk/issues/464 + Assert.Single(mockLoggerProvider.LogMessages, m => + m.Category == "ModelContextProtocol.Client.McpClient" && + m.Message.Contains("request '2' for method 'tools/call'")); + + Assert.Single(mockLoggerProvider.LogMessages, m => + m.Category == "ModelContextProtocol.Server.McpServer" && + m.Message.Contains("request '2' for method 'sampling/createMessage'")); + } + + private ClaimsPrincipal CreateUser(string name) => new ClaimsPrincipal(new ClaimsIdentity( [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)], "TestAuthType", "name", "role")); [McpServerToolType] - protected class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor) + private class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor) { [McpServerTool, Description("Echoes the input back to the client with their user name.")] public string EchoWithUserName(string message) @@ -127,4 +210,37 @@ public string EchoWithUserName(string message) return $"{userName}: {message}"; } } + + [McpServerToolType] + private class SamplingRegressionTools + { + [McpServerTool(Name = "sampling-tool")] + public static async Task SamplingToolAsync(IMcpServer server, string prompt, CancellationToken cancellationToken) + { + // This tool reproduces the scenario described in https://github.com/modelcontextprotocol/csharp-sdk/issues/464 + // 1. The client calls tool with request ID 2, because it's the first request after the initialize request. + // 2. This tool makes two sampling requests which use IDs 1 and 2. + // 3. In the old buggy Streamable HTTP transport code, this would close the SSE response stream, + // because the second sampling request used an ID matching the tool call. + var samplingRequest = new CreateMessageRequestParams + { + Messages = [ + new SamplingMessage + { + Role = Role.User, + Content = new Content + { + Type = "text", + Text = prompt + }, + } + ], + }; + + await server.SampleAsync(samplingRequest, cancellationToken); + var samplingResult = await server.SampleAsync(samplingRequest, cancellationToken); + + return $"Sampling completed successfully. Client responded: {samplingResult.Content.Text}"; + } + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs index 597b39de..45aaeb5b 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs @@ -18,7 +18,7 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) Builder = WebApplication.CreateSlimBuilder(); Builder.Services.RemoveAll(); Builder.Services.AddSingleton(_inMemoryTransport); - Builder.Services.AddSingleton(LoggerProvider); + Builder.Services.AddSingleton(XunitLoggerProvider); HttpClient = new HttpClient(new SocketsHttpHandler() {