diff --git a/src/protobuf-net.GrpcLite/Internal/Client/LiteCallInvoker.cs b/src/protobuf-net.GrpcLite/Internal/Client/LiteCallInvoker.cs index fdd739a..f4e9d95 100644 --- a/src/protobuf-net.GrpcLite/Internal/Client/LiteCallInvoker.cs +++ b/src/protobuf-net.GrpcLite/Internal/Client/LiteCallInvoker.cs @@ -18,6 +18,7 @@ internal sealed class LiteCallInvoker : CallInvoker, IConnection, IWorker private readonly ChannelWriter<(Frame Frame, FrameWriteFlags Flags)> _output; private readonly string _target; private readonly CancellationTokenSource _clientShutdown = new(); + private readonly CancellationToken _shutdownToken; private readonly ConcurrentDictionary _streams = new(); RefCountedMemoryPool IConnection.Pool => RefCountedMemoryPool.Shared; @@ -33,8 +34,15 @@ string IConnection.LastKnownUserAgent private int _nextId = ushort.MaxValue; // so that our first id is zero void IConnection.Remove(ushort id) => _streams.TryRemove(id, out _); + public void CompleteAllStreams() + { + foreach (var stream in _streams) + { + stream.Value.Cancel(); + } + } - CancellationToken IConnection.Shutdown => _clientShutdown.Token; + CancellationToken IConnection.Shutdown => _shutdownToken; private ClientStream AddClientStream(Method method, in CallOptions options) where TRequest : class where TResponse : class @@ -75,7 +83,9 @@ public LiteCallInvoker(string target, IFrameConnection connection, ILogger? logg this._target = target; this._connection = connection; this._logger = logger; - _ = connection.StartWriterAsync(this, out _output, _clientShutdown.Token); + + this._shutdownToken = _clientShutdown.Token; + _ = connection.StartWriterAsync(this, out _output, _shutdownToken); } ChannelWriter<(Frame Frame, FrameWriteFlags Flags)> IConnection.Output => _output; @@ -146,7 +156,7 @@ public void Execute() { _logger.SetSource(LogKind.Client, "invoker"); _logger.Debug(_target, static (state, _) => $"Starting call-invoker (client): {state}..."); - _ = this.RunAsync(_logger, _clientShutdown.Token); + _ = this.RunAsync(_logger, _shutdownToken); } catch (Exception ex) { diff --git a/src/protobuf-net.GrpcLite/Internal/ListenerEngine.cs b/src/protobuf-net.GrpcLite/Internal/ListenerEngine.cs index 0c89f7f..36d019b 100644 --- a/src/protobuf-net.GrpcLite/Internal/ListenerEngine.cs +++ b/src/protobuf-net.GrpcLite/Internal/ListenerEngine.cs @@ -21,6 +21,7 @@ internal interface IConnection ConcurrentDictionary Streams { get; } void Remove(ushort streamId); + void CompleteAllStreams(); CancellationToken Shutdown { get; } void Close(Exception? fault); @@ -167,13 +168,16 @@ await connection.ReadAllAsync(async frame => { logger.Debug(connection, static (state, _) => $"connection {state} ({(state.IsClient ? "client" : "server")}) exiting cleanly"); connection.Output.Complete(null); + connection.CompleteAllStreams(); + } catch (OperationCanceledException oce) when (oce.CancellationToken == cancellationToken) { } // alt-success catch (Exception ex) { logger.Error(frame, static (state, ex) => $"Error processing {state}: {ex?.Message}"); - connection?.Output.Complete(ex); + connection.Output.Complete(ex); + connection.CompleteAllStreams(); throw; } finally diff --git a/src/protobuf-net.GrpcLite/Internal/LiteStream.cs b/src/protobuf-net.GrpcLite/Internal/LiteStream.cs index b4030c0..1670773 100644 --- a/src/protobuf-net.GrpcLite/Internal/LiteStream.cs +++ b/src/protobuf-net.GrpcLite/Internal/LiteStream.cs @@ -813,6 +813,8 @@ public ValueTask SendHeaderAsync(string? host, in CallOptions options, FrameWrit protected void OnComplete() { IsActive = false; + + _suspendedContinuationPoint.SetResult(true); _connection?.Remove(Id); } diff --git a/src/protobuf-net.GrpcLite/Internal/Server/LiteConnection.cs b/src/protobuf-net.GrpcLite/Internal/Server/LiteConnection.cs index 272ecae..748b958 100644 --- a/src/protobuf-net.GrpcLite/Internal/Server/LiteConnection.cs +++ b/src/protobuf-net.GrpcLite/Internal/Server/LiteConnection.cs @@ -22,6 +22,13 @@ internal sealed class LiteConnection : IWorker, IConnection public int Id { get; } void IConnection.Remove(ushort id) => _streams.TryRemove(id, out _); + public void CompleteAllStreams() + { + foreach (var stream in _streams) + { + stream.Value.Cancel(); + } + } CancellationToken IConnection.Shutdown => _server.ServerShutdown;