@@ -49,21 +49,16 @@ namespace Microsoft.Playwright.Tests.TestServer;
4949
5050public class SimpleServer
5151{
52- const int MaxMessageSize = 256 * 1024 ;
53-
54- private readonly IDictionary < string , Action < HttpContext > > _requestWaits ;
55- private readonly IList < Func < WebSocket , HttpContext , Task > > _waitForWebSocketConnectionRequestsWaits ;
52+ private readonly IDictionary < string , Action < HttpContext > > _requestSubscribers ;
53+ private readonly ConcurrentBag < Func < WebSocketWithEvents , HttpContext , Task > > _webSocketSubscribers ;
5654 private readonly IDictionary < string , Func < HttpContext , Task > > _routes ;
5755 private readonly IDictionary < string , ( string username , string password ) > _auths ;
5856 private readonly IDictionary < string , string > _csp ;
5957 private readonly IList < string > _gzipRoutes ;
6058 private readonly string _contentRoot ;
6159
62-
6360 private ArraySegment < byte > _onWebSocketConnectionData ;
6461 private readonly IWebHost _webHost ;
65- private static int _counter ;
66- private readonly Dictionary < int , WebSocket > _clients = new ( ) ;
6762
6863 public int Port { get ; }
6964 public string Prefix { get ; }
@@ -84,8 +79,8 @@ public SimpleServer(int port, string contentRoot, bool isHttps)
8479 EmptyPage = $ "{ Prefix } /empty.html";
8580
8681 var currentExecutionContext = TestExecutionContext . CurrentContext ;
87- _requestWaits = new ConcurrentDictionary < string , Action < HttpContext > > ( ) ;
88- _waitForWebSocketConnectionRequestsWaits = [ ] ;
82+ _requestSubscribers = new ConcurrentDictionary < string , Action < HttpContext > > ( ) ;
83+ _webSocketSubscribers = [ ] ;
8984 _routes = new ConcurrentDictionary < string , Func < HttpContext , Task > > ( ) ;
9085 _auths = new ConcurrentDictionary < string , ( string username , string password ) > ( ) ;
9186 _csp = new ConcurrentDictionary < string , string > ( ) ;
@@ -108,30 +103,23 @@ public SimpleServer(int port, string contentRoot, bool isHttps)
108103 var currentContext = typeof ( TestExecutionContext ) . GetField ( "AsyncLocalCurrentContext" , BindingFlags . NonPublic | BindingFlags . Static ) . GetValue ( null ) as AsyncLocal < TestExecutionContext > ;
109104 currentContext . Value = currentExecutionContext ;
110105 }
111- if ( context . Request . Path == "/ws" )
106+ if ( context . WebSockets . IsWebSocketRequest && context . Request . Path == "/ws" )
112107 {
113- if ( context . WebSockets . IsWebSocketRequest )
108+ var webSocket = await context . WebSockets . AcceptWebSocketAsync ( ) . ConfigureAwait ( false ) ;
109+ var testWebSocket = new WebSocketWithEvents ( webSocket , context . Request ) ;
110+ if ( _onWebSocketConnectionData != null )
114111 {
115- var webSocket = await context . WebSockets . AcceptWebSocketAsync ( ) . ConfigureAwait ( false ) ;
116- foreach ( var wait in _waitForWebSocketConnectionRequestsWaits )
117- {
118- _waitForWebSocketConnectionRequestsWaits . Remove ( wait ) ;
119- await wait ( webSocket , context ) . ConfigureAwait ( false ) ;
120- }
121- if ( _onWebSocketConnectionData != null )
122- {
123- await webSocket . SendAsync ( _onWebSocketConnectionData , WebSocketMessageType . Text , true , CancellationToken . None ) . ConfigureAwait ( false ) ;
124- }
125- await ReceiveLoopAsync ( webSocket , context . Request . Headers [ "User-Agent" ] . ToString ( ) . Contains ( "Firefox" ) , CancellationToken . None ) . ConfigureAwait ( false ) ;
112+ await webSocket . SendAsync ( _onWebSocketConnectionData , WebSocketMessageType . Text , true , CancellationToken . None ) . ConfigureAwait ( false ) ;
126113 }
127- else if ( ! context . Response . HasStarted )
114+ foreach ( var wait in _webSocketSubscribers )
128115 {
129- context . Response . StatusCode = 400 ;
116+ await wait ( testWebSocket , context ) . ConfigureAwait ( false ) ;
130117 }
118+ await testWebSocket . RunReceiveLoop ( ) . ConfigureAwait ( false ) ;
131119 return ;
132120 }
133121
134- if ( _requestWaits . TryGetValue ( context . Request . Path , out var requestWait ) )
122+ if ( _requestSubscribers . TryGetValue ( context . Request . Path , out var requestWait ) )
135123 {
136124 requestWait ( context ) ;
137125 }
@@ -264,9 +252,10 @@ public void Reset()
264252 _routes . Clear ( ) ;
265253 _auths . Clear ( ) ;
266254 _csp . Clear ( ) ;
267- _requestWaits . Clear ( ) ;
255+ _requestSubscribers . Clear ( ) ;
268256 _gzipRoutes . Clear ( ) ;
269257 _onWebSocketConnectionData = null ;
258+ _webSocketSubscribers . Clear ( ) ;
270259 }
271260
272261 public void EnableGzip ( string path ) => _gzipRoutes . Add ( path ) ;
@@ -290,33 +279,33 @@ public void SetRedirect(string from, string to) => SetRoute(from, context =>
290279 public async Task < T > WaitForRequest < T > ( string path , Func < HttpRequest , T > selector )
291280 {
292281 var taskCompletion = new TaskCompletionSource < T > ( ) ;
293- _requestWaits . Add ( path , context =>
282+ _requestSubscribers . Add ( path , context =>
294283 {
295284 taskCompletion . SetResult ( selector ( context . Request ) ) ;
296285 } ) ;
297286
298287 var request = await taskCompletion . Task . ConfigureAwait ( false ) ;
299- _requestWaits . Remove ( path ) ;
288+ _requestSubscribers . Remove ( path ) ;
300289
301290 return request ;
302291 }
303292
304293 public Task WaitForRequest ( string path ) => WaitForRequest ( path , _ => true ) ;
305294
306- public async Task < ( WebSocket , HttpRequest ) > WaitForWebSocketConnectionRequest ( )
295+ public Task < WebSocketWithEvents > WaitForWebSocketAsync ( )
307296 {
308- var taskCompletion = new TaskCompletionSource < ( WebSocket , HttpRequest ) > ( ) ;
309- OnceWebSocketConnection ( ( WebSocket ws , HttpContext context ) =>
297+ var tcs = new TaskCompletionSource < WebSocketWithEvents > ( ) ;
298+ OnceWebSocketConnection ( ( ws , _ ) =>
310299 {
311- taskCompletion . SetResult ( ( ws , context . Request ) ) ;
300+ tcs . SetResult ( ws ) ;
312301 return Task . CompletedTask ;
313302 } ) ;
314- return await taskCompletion . Task . ConfigureAwait ( false ) ;
303+ return tcs . Task ;
315304 }
316305
317- public void OnceWebSocketConnection ( Func < WebSocket , HttpContext , Task > handler )
306+ public void OnceWebSocketConnection ( Func < WebSocketWithEvents , HttpContext , Task > handler )
318307 {
319- _waitForWebSocketConnectionRequestsWaits . Add ( handler ) ;
308+ _webSocketSubscribers . Add ( handler ) ;
320309 }
321310
322311 private static bool Authenticate ( string username , string password , HttpContext context )
@@ -332,73 +321,4 @@ private static bool Authenticate(string username, string password, HttpContext c
332321 }
333322 return false ;
334323 }
335-
336- private async Task ReceiveLoopAsync ( WebSocket webSocket , bool sendCloseMessage , CancellationToken token )
337- {
338- int connectionId = NextConnectionId ( ) ;
339- _clients . Add ( connectionId , webSocket ) ;
340-
341- byte [ ] buffer = new byte [ MaxMessageSize ] ;
342-
343- try
344- {
345- while ( true )
346- {
347- var result = await webSocket . ReceiveAsync ( new ( buffer ) , token ) . ConfigureAwait ( false ) ;
348-
349- if ( result . MessageType == WebSocketMessageType . Close )
350- {
351- if ( sendCloseMessage )
352- {
353- await webSocket . CloseAsync ( WebSocketCloseStatus . NormalClosure , "Close" , CancellationToken . None ) . ConfigureAwait ( false ) ;
354- }
355- break ;
356- }
357-
358- var data = await ReadFrames ( result , webSocket , buffer , token ) . ConfigureAwait ( false ) ;
359-
360- if ( data . Count == 0 )
361- {
362- break ;
363- }
364- }
365- }
366- finally
367- {
368- _clients . Remove ( connectionId ) ;
369- }
370- }
371-
372- private async Task < ArraySegment < byte > > ReadFrames ( WebSocketReceiveResult result , WebSocket webSocket , byte [ ] buffer , CancellationToken token )
373- {
374- int count = result . Count ;
375-
376- while ( ! result . EndOfMessage )
377- {
378- if ( count >= MaxMessageSize )
379- {
380- string closeMessage = $ "Maximum message size: { MaxMessageSize } bytes.";
381- await webSocket . CloseAsync ( WebSocketCloseStatus . MessageTooBig , closeMessage , token ) . ConfigureAwait ( false ) ;
382- return new ( ) ;
383- }
384-
385- result = await webSocket . ReceiveAsync ( new ( buffer , count , MaxMessageSize - count ) , token ) . ConfigureAwait ( false ) ;
386- count += result . Count ;
387-
388- }
389- return new ( buffer , 0 , count ) ;
390- }
391-
392-
393- private static int NextConnectionId ( )
394- {
395- int id = Interlocked . Increment ( ref _counter ) ;
396-
397- if ( id == int . MaxValue )
398- {
399- throw new ( "connection id limit reached: " + id ) ;
400- }
401-
402- return id ;
403- }
404324}
0 commit comments