4
4
import unittest
5
5
from datetime import datetime , timedelta
6
6
from functools import wraps
7
- from typing import Any , Callable , Coroutine , Dict , Optional , TypeVar
7
+ from typing import Any , Callable , Coroutine , Dict , List , Optional , TypeVar
8
8
9
9
import aiohttp
10
10
from aiohttp import web
@@ -193,6 +193,7 @@ def make_request(call: Call) -> RunRequest:
193
193
else :
194
194
error = Error (type = "status" , message = str (res .status ))
195
195
return CallResult (
196
+ correlation_id = call .correlation_id ,
196
197
dispatch_id = dispatch_id ,
197
198
error = error ,
198
199
)
@@ -203,6 +204,7 @@ def make_request(call: Call) -> RunRequest:
203
204
continue
204
205
result = res .exit .result
205
206
return CallResult (
207
+ correlation_id = call .correlation_id ,
206
208
dispatch_id = dispatch_id ,
207
209
output = result .output if result .HasField ("output" ) else None ,
208
210
error = result .error if result .HasField ("error" ) else None ,
@@ -317,6 +319,7 @@ def test(self):
317
319
endpoint = DISPATCH_ENDPOINT_URL ,
318
320
client = Client (api_key = DISPATCH_API_KEY , api_url = DISPATCH_API_URL ),
319
321
)
322
+ set_default_registry (_registry )
320
323
321
324
322
325
@_registry .function
@@ -354,7 +357,33 @@ async def broken_nested(name: str) -> str:
354
357
return await broken ()
355
358
356
359
357
- set_default_registry (_registry )
360
+ @_registry .function
361
+ async def distributed_merge_sort (values : List [int ]) -> List [int ]:
362
+ if len (values ) <= 1 :
363
+ return values
364
+ i = len (values ) // 2
365
+
366
+ (l , r ) = await dispatch .gather (
367
+ distributed_merge_sort (values [:i ]),
368
+ distributed_merge_sort (values [i :]),
369
+ )
370
+
371
+ return merge (l , r )
372
+
373
+
374
+ def merge (l : List [int ], r : List [int ]) -> List [int ]:
375
+ result = []
376
+ i = j = 0
377
+ while i < len (l ) and j < len (r ):
378
+ if l [i ] < r [j ]:
379
+ result .append (l [i ])
380
+ i += 1
381
+ else :
382
+ result .append (r [j ])
383
+ j += 1
384
+ result .extend (l [i :])
385
+ result .extend (r [j :])
386
+ return result
358
387
359
388
360
389
class TestCase (unittest .TestCase ):
@@ -473,6 +502,26 @@ async def test_call_nested_function_with_error(self):
473
502
with self .assertRaises (ValueError ) as e :
474
503
await broken_nested ("hello" )
475
504
505
+ @aiotest
506
+ async def test_distributed_merge_sort_no_values (self ):
507
+ values : List [int ] = []
508
+ self .assertEqual (await distributed_merge_sort (values ), sorted (values ))
509
+
510
+ @aiotest
511
+ async def test_distributed_merge_sort_one_value (self ):
512
+ values : List [int ] = [1 ]
513
+ self .assertEqual (await distributed_merge_sort (values ), sorted (values ))
514
+
515
+ @aiotest
516
+ async def test_distributed_merge_sort_two_values (self ):
517
+ values : List [int ] = [1 , 5 ]
518
+ self .assertEqual (await distributed_merge_sort (values ), sorted (values ))
519
+
520
+ @aiotest
521
+ async def test_distributed_merge_sort_many_values (self ):
522
+ values : List [int ] = [1 , 5 , 3 , 2 , 4 , 6 , 7 , 8 , 9 , 0 ]
523
+ self .assertEqual (await distributed_merge_sort (values ), sorted (values ))
524
+
476
525
477
526
class ClientRequestContentLengthMissing (aiohttp .ClientRequest ):
478
527
def update_headers (self , skip_auto_headers ):
0 commit comments