Skip to content
19 changes: 19 additions & 0 deletions src/Common/Polyfills/System/Text/EncodingExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ public static int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Spa
}
}
}

/// <summary>
/// Decodes all the bytes in the specified span into a string.
/// </summary>
public static string GetString(this Encoding encoding, ReadOnlySpan<byte> bytes)
{
if (bytes.IsEmpty)
{
return string.Empty;
}

unsafe
{
fixed (byte* bytesPtr = bytes)
{
return encoding.GetString(bytesPtr, bytes.Length);
}
}
}
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Client;
/// <summary>Provides the client side of a stdio-based session transport.</summary>
internal sealed class StdioClientSessionTransport(
StdioClientTransportOptions options, Process process, string endpointName, Queue<string> stderrRollingLog, ILoggerFactory? loggerFactory) :
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory)
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, endpointName, loggerFactory)
{
private readonly StdioClientTransportOptions _options = options;
private readonly Process _process = process;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol;
using System.Buffers;
using System.IO.Pipelines;
using System.Text;
using System.Text.Json;

Expand All @@ -12,7 +14,7 @@ internal class StreamClientSessionTransport : TransportBase

internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false);

private readonly TextReader _serverOutput;
private readonly PipeReader _serverOutputPipe;
private readonly Stream _serverInputStream;
private readonly SemaphoreSlim _sendLock = new(1, 1);
private CancellationTokenSource? _shutdownCts = new();
Expand All @@ -27,9 +29,6 @@ internal class StreamClientSessionTransport : TransportBase
/// <param name="serverOutput">
/// The server's output stream. Messages read from this stream will be received from the server.
/// </param>
/// <param name="encoding">
/// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
/// </param>
/// <param name="endpointName">
/// A name that identifies this transport endpoint in logs.
/// </param>
Expand All @@ -40,18 +39,14 @@ internal class StreamClientSessionTransport : TransportBase
/// This constructor starts a background task to read messages from the server output stream.
/// The transport will be marked as connected once initialized.
/// </remarks>
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory)
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, string endpointName, ILoggerFactory? loggerFactory)
: base(endpointName, loggerFactory)
{
Throw.IfNull(serverInput);
Throw.IfNull(serverOutput);

_serverInputStream = serverInput;
#if NET
_serverOutput = new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#else
_serverOutput = new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#endif
_serverOutputPipe = PipeReader.Create(serverOutput);

SetConnected();

Expand Down Expand Up @@ -102,24 +97,8 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
try
{
LogTransportEnteringReadMessagesLoop(Name);

while (true)
{
if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line)
{
LogTransportEndOfStream(Name);
break;
}

if (string.IsNullOrWhiteSpace(line))
{
continue;
}

LogTransportReceivedMessageSensitive(Name, line);

await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
}
await _serverOutputPipe.ReadLinesAsync(ProcessLineAsync, cancellationToken).ConfigureAwait(false);
LogTransportEndOfStream(Name);
}
catch (OperationCanceledException)
{
Expand All @@ -137,25 +116,43 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
}
}

private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken)
private async Task ProcessLineAsync(ReadOnlySequence<byte> line, CancellationToken cancellationToken)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportReceivedMessageSensitive(Name, PipeReaderExtensions.GetUtf8String(line));
}

try
{
var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)));
if (message != null)
JsonRpcMessage? message;
if (line.IsSingleSegment)
{
message = JsonSerializer.Deserialize(line.First.Span, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}
else
{
var reader = new Utf8JsonReader(line, isFinalBlock: true, state: default);
message = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}

if (message is not null)
{
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
else
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, line);
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, PipeReaderExtensions.GetUtf8String(line));
}
}
}
catch (JsonException ex)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseFailedSensitive(Name, line, ex);
LogTransportMessageParseFailedSensitive(Name, PipeReaderExtensions.GetUtf8String(line), ex);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ public Task<ITransport> ConnectAsync(CancellationToken cancellationToken = defau
return Task.FromResult<ITransport>(new StreamClientSessionTransport(
_serverInput,
_serverOutput,
encoding: null,
"Client (stream)",
_loggerFactory));
}
Expand Down
79 changes: 79 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/PipeReaderExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using System.Buffers;
using System.IO.Pipelines;
using System.Text;

namespace ModelContextProtocol.Protocol;

/// <summary>Internal helper for reading newline-delimited UTF-8 lines from a <see cref="PipeReader"/>.</summary>
internal static class PipeReaderExtensions
{
/// <summary>
/// Reads newline-delimited lines from <paramref name="reader"/>, invoking
/// <paramref name="processLine"/> for each non-empty line, until the reader signals completion.
/// </summary>
internal static async Task ReadLinesAsync(
this PipeReader reader,
Func<ReadOnlySequence<byte>, CancellationToken, Task> processLine,
CancellationToken cancellationToken)
{
while (true)
{
ReadResult result = await reader.ReadAsync(cancellationToken).ConfigureAwait(false);
ReadOnlySequence<byte> buffer = result.Buffer;

SequencePosition? position;
while ((position = buffer.PositionOf((byte)'\n')) != null)
{
ReadOnlySequence<byte> line = buffer.Slice(0, position.Value);

// Trim trailing \r for Windows-style CRLF line endings.
if (EndsWithCarriageReturn(line))
{
line = line.Slice(0, line.Length - 1);
}

if (!line.IsEmpty)
{
await processLine(line, cancellationToken).ConfigureAwait(false);
}

// Advance past the '\n'.
buffer = buffer.Slice(buffer.GetPosition(1, position.Value));
}

reader.AdvanceTo(buffer.Start, buffer.End);

if (result.IsCompleted)
{
break;
}
}
}

/// <summary>Decodes a UTF-8 <see cref="ReadOnlySequence{T}"/> to a <see cref="string"/>.</summary>
internal static string GetUtf8String(in ReadOnlySequence<byte> sequence) =>
sequence.IsSingleSegment
? Encoding.UTF8.GetString(sequence.First.Span)
: Encoding.UTF8.GetString(sequence.ToArray());

private static bool EndsWithCarriageReturn(in ReadOnlySequence<byte> sequence)
{
if (sequence.IsSingleSegment)
{
ReadOnlySpan<byte> span = sequence.First.Span;
return span.Length > 0 && span[span.Length - 1] == (byte)'\r';
}

// Multi-segment: find the last non-empty segment to check its last byte.
ReadOnlyMemory<byte> last = default;
foreach (ReadOnlyMemory<byte> segment in sequence)
{
if (!segment.IsEmpty)
{
last = segment;
}
}

return !last.IsEmpty && last.Span[last.Length - 1] == (byte)'\r';
}
}
106 changes: 56 additions & 50 deletions src/ModelContextProtocol.Core/Server/StreamServerTransport.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Protocol;
using System.Text;
using System.Buffers;
using System.IO.Pipelines;
using System.Text.Json;

namespace ModelContextProtocol.Server;
Expand All @@ -20,7 +21,8 @@ public class StreamServerTransport : TransportBase

private readonly ILogger _logger;

private readonly TextReader _inputReader;
private readonly Stream _inputStream;
private readonly PipeReader _inputPipeReader;
private readonly Stream _outputStream;

private readonly SemaphoreSlim _sendLock = new(1, 1);
Expand All @@ -45,11 +47,8 @@ public StreamServerTransport(Stream inputStream, Stream outputStream, string? se

_logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;

#if NET
_inputReader = new StreamReader(inputStream, Encoding.UTF8);
#else
_inputReader = new CancellableStreamReader(inputStream, Encoding.UTF8);
#endif
_inputStream = inputStream;
_inputPipeReader = PipeReader.Create(inputStream);
_outputStream = outputStream;

SetConnected();
Expand Down Expand Up @@ -94,48 +93,8 @@ private async Task ReadMessagesAsync()
try
{
LogTransportEnteringReadMessagesLoop(Name);

while (!shutdownToken.IsCancellationRequested)
{
var line = await _inputReader.ReadLineAsync(shutdownToken).ConfigureAwait(false);
if (string.IsNullOrWhiteSpace(line))
{
if (line is null)
{
LogTransportEndOfStream(Name);
break;
}

continue;
}

LogTransportReceivedMessageSensitive(Name, line);

try
{
if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) is JsonRpcMessage message)
{
await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false);
}
else
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, line);
}
}
catch (JsonException ex)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseFailedSensitive(Name, line, ex);
}
else
{
LogTransportMessageParseFailed(Name, ex);
}

// Continue reading even if we fail to parse a message
}
}
await _inputPipeReader.ReadLinesAsync(ProcessLineAsync, shutdownToken).ConfigureAwait(false);
LogTransportEndOfStream(Name);
}
catch (OperationCanceledException)
{
Expand All @@ -152,6 +111,53 @@ private async Task ReadMessagesAsync()
}
}

private async Task ProcessLineAsync(ReadOnlySequence<byte> line, CancellationToken cancellationToken)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportReceivedMessageSensitive(Name, PipeReaderExtensions.GetUtf8String(line));
}

try
{
JsonRpcMessage? message;
if (line.IsSingleSegment)
{
message = JsonSerializer.Deserialize(line.First.Span, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}
else
{
var reader = new Utf8JsonReader(line, isFinalBlock: true, state: default);
message = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}

if (message is not null)
{
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
else
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, PipeReaderExtensions.GetUtf8String(line));
}
}
}
catch (JsonException ex)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseFailedSensitive(Name, PipeReaderExtensions.GetUtf8String(line), ex);
}
else
{
LogTransportMessageParseFailed(Name, ex);
}

// Continue reading even if we fail to parse a message.
}
}

/// <inheritdoc />
public override async ValueTask DisposeAsync()
{
Expand All @@ -170,7 +176,7 @@ public override async ValueTask DisposeAsync()

// Dispose of stdin/out. Cancellation may not be able to wake up operations
// synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor.
_inputReader?.Dispose();
_inputStream?.Dispose();
_outputStream?.Dispose();

// Make sure the work has quiesced.
Expand Down