Skip to content

Commit d5f5848

Browse files
use context varaibles instead of thread locals
Signed-off-by: Achille Roussel <[email protected]>
1 parent a406d2a commit d5f5848

File tree

2 files changed

+56
-16
lines changed

2 files changed

+56
-16
lines changed

src/dispatch/scheduler.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import asyncio
2+
import contextvars
23
import logging
34
import pickle
45
import sys
5-
import threading
66
from dataclasses import dataclass, field
77
from types import coroutine
88
from typing import (
@@ -32,19 +32,10 @@
3232
CoroutineID: TypeAlias = int
3333
CorrelationID: TypeAlias = int
3434

35-
36-
class ThreadLocal(threading.local):
37-
in_function_call: bool
38-
39-
def __init__(self):
40-
self.in_function_call = False
41-
42-
43-
thread_local = ThreadLocal()
44-
35+
_in_function_call = contextvars.ContextVar("dispatch.scheduler.in_function_call", default=False)
4536

4637
def in_function_call() -> bool:
47-
return thread_local.in_function_call
38+
return bool(_in_function_call.get())
4839

4940

5041
@dataclass
@@ -343,15 +334,15 @@ def __init__(
343334

344335
async def run(self, input: Input) -> Output:
345336
try:
346-
thread_local.in_function_call = True
337+
token = _in_function_call.set(True)
347338
return await self._run(input)
348339
except Exception as e:
349340
logger.exception(
350341
"unexpected exception occurred during coroutine scheduling"
351342
)
352343
return Output.error(Error.from_exception(e))
353344
finally:
354-
thread_local.in_function_call = False
345+
_in_function_call.reset(token)
355346

356347
def _init_state(self, input: Input) -> State:
357348
logger.debug("starting main coroutine")

src/dispatch/test/__init__.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import unittest
55
from datetime import datetime, timedelta
66
from functools import wraps
7-
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar
7+
from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar
88

99
import aiohttp
1010
from aiohttp import web
@@ -193,6 +193,7 @@ def make_request(call: Call) -> RunRequest:
193193
else:
194194
error = Error(type="status", message=str(res.status))
195195
return CallResult(
196+
correlation_id=call.correlation_id,
196197
dispatch_id=dispatch_id,
197198
error=error,
198199
)
@@ -203,6 +204,7 @@ def make_request(call: Call) -> RunRequest:
203204
continue
204205
result = res.exit.result
205206
return CallResult(
207+
correlation_id=call.correlation_id,
206208
dispatch_id=dispatch_id,
207209
output=result.output if result.HasField("output") else None,
208210
error=result.error if result.HasField("error") else None,
@@ -317,6 +319,7 @@ def test(self):
317319
endpoint=DISPATCH_ENDPOINT_URL,
318320
client=Client(api_key=DISPATCH_API_KEY, api_url=DISPATCH_API_URL),
319321
)
322+
set_default_registry(_registry)
320323

321324

322325
@_registry.function
@@ -354,7 +357,33 @@ async def broken_nested(name: str) -> str:
354357
return await broken()
355358

356359

357-
set_default_registry(_registry)
360+
@_registry.function
361+
async def distributed_merge_sort(values: List[int]) -> List[int]:
362+
if len(values) <= 1:
363+
return values
364+
i = len(values) // 2
365+
366+
(l, r) = await dispatch.gather(
367+
distributed_merge_sort(values[:i]),
368+
distributed_merge_sort(values[i:]),
369+
)
370+
371+
return merge(l, r)
372+
373+
374+
def merge(l: List[int], r: List[int]) -> List[int]:
375+
result = []
376+
i = j = 0
377+
while i < len(l) and j < len(r):
378+
if l[i] < r[j]:
379+
result.append(l[i])
380+
i += 1
381+
else:
382+
result.append(r[j])
383+
j += 1
384+
result.extend(l[i:])
385+
result.extend(r[j:])
386+
return result
358387

359388

360389
class TestCase(unittest.TestCase):
@@ -473,6 +502,26 @@ async def test_call_nested_function_with_error(self):
473502
with self.assertRaises(ValueError) as e:
474503
await broken_nested("hello")
475504

505+
@aiotest
506+
async def test_distributed_merge_sort_no_values(self):
507+
values: List[int] = []
508+
self.assertEqual(await distributed_merge_sort(values), sorted(values))
509+
510+
@aiotest
511+
async def test_distributed_merge_sort_one_value(self):
512+
values: List[int] = [1]
513+
self.assertEqual(await distributed_merge_sort(values), sorted(values))
514+
515+
@aiotest
516+
async def test_distributed_merge_sort_two_values(self):
517+
values: List[int] = [1, 5]
518+
self.assertEqual(await distributed_merge_sort(values), sorted(values))
519+
520+
@aiotest
521+
async def test_distributed_merge_sort_many_values(self):
522+
values: List[int] = [1, 5, 3, 2, 4, 6, 7, 8, 9, 0]
523+
self.assertEqual(await distributed_merge_sort(values), sorted(values))
524+
476525

477526
class ClientRequestContentLengthMissing(aiohttp.ClientRequest):
478527
def update_headers(self, skip_auto_headers):

0 commit comments

Comments
 (0)