Skip to content

Commit ceb4768

Browse files
committed
WIP: cooperative threads
1 parent 702267c commit ceb4768

File tree

2 files changed

+125
-119
lines changed

2 files changed

+125
-119
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 125 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,10 @@ class CanonicalOptions(LiftLowerOptions):
214214

215215
### Runtime State
216216

217-
scheduler = asyncio.Lock()
218-
219217
#### Component Instance State
220218

221219
class ComponentInstance:
220+
store: Store
222221
table: Table
223222
may_leave: bool
224223
backpressure: bool
@@ -455,6 +454,51 @@ def drop(self):
455454
trap_if(len(self.elems) > 0)
456455
trap_if(self.num_waiting > 0)
457456

457+
#### Thread State
458+
459+
class Cancelled(IntEnum):
460+
FALSE = 0
461+
TRUE = 1
462+
463+
class Thread:
464+
store: Store
465+
awaitable: Optional[Awaitable]
466+
on_resume: Optional[asyncio.Future]
467+
on_suspend_or_exit: Optional[asyncio.Future]
468+
469+
def __init__(self, store, callee, on_start, on_resolve):
470+
self.store = store
471+
self.awaitable = None
472+
self.on_resume = asyncio.Future()
473+
self.on_suspend_or_exit = None
474+
async def thread():
475+
assert(await self.on_resume == Cancelled.FALSE)
476+
callee(self, on_start, on_resolve)
477+
self.on_suspend_or_exit.set_result(None)
478+
asyncio.create_task(thread())
479+
store.waiting.append(self)
480+
481+
async def resume(self, cancelled = Cancelled.FALSE):
482+
assert(cancelled or (not self.awaitable or self.awaitable.done()))
483+
self.awaitable = None
484+
self.store.waiting.remove(self)
485+
self.on_resume.set_result(cancelled)
486+
self.on_resume = None
487+
assert(not self.on_suspend_or_exit)
488+
self.on_suspend_or_exit = asyncio.Future()
489+
await self.on_suspend_or_exit
490+
assert(self.awaitable)
491+
self.store.waiting.append(self)
492+
493+
async def suspend(self, awaitable) -> Cancelled:
494+
assert(not self.awaitable)
495+
self.awaitable = awaitable
496+
self.on_suspend_or_exit.set_result(None)
497+
self.on_suspend_or_exit = None
498+
assert(not self.on_resume)
499+
self.on_resume = asyncio.Future()
500+
return await thread.on_resume
501+
458502
#### Task State
459503

460504
class Cancelled(IntEnum):
@@ -477,29 +521,28 @@ class State(Enum):
477521
inst: ComponentInstance
478522
ft: FuncType
479523
supertask: Optional[Task]
524+
thread: Thread
480525
on_resolve: OnResolve
481-
on_block: OnBlock
482526
num_borrows: int
483527
context: ContextLocalStorage
484528

485-
def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
529+
def __init__(self, opts, inst, ft, supertask, thread, on_resolve):
486530
self.state = Task.State.INITIAL
487531
self.opts = opts
488532
self.inst = inst
489533
self.ft = ft
490534
self.supertask = supertask
535+
self.thread = thread
491536
self.on_resolve = on_resolve
492-
self.on_block = on_block
493537
self.num_borrows = 0
494538
self.context = ContextLocalStorage()
495539

496540
async def enter(self):
497-
assert(scheduler.locked())
498541
self.trap_if_on_the_stack(self.inst)
499542
if not self.may_enter(self) or self.inst.pending_tasks:
500543
f = asyncio.Future()
501544
self.inst.pending_tasks.append((self, f))
502-
if await self.on_block(f) == Cancelled.TRUE:
545+
if await self.thread.suspend(f) == Cancelled.TRUE:
503546
[i] = [i for i,(t,_) in enumerate(self.inst.pending_tasks) if t == self]
504547
self.inst.pending_tasks.pop(i)
505548
self.on_resolve(None)
@@ -531,44 +574,28 @@ def maybe_start_pending_task(self):
531574
pending_future.set_result(None)
532575
return
533576

534-
async def sync_wait(self, awaitable, cancellable = False) -> Cancelled:
577+
async def sync_wait(self, awaitable) -> None:
535578
awaitable = asyncio.ensure_future(awaitable)
536579
if awaitable.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
537580
return
538581
assert(self.inst.unblocked.is_set())
539582
self.inst.unblocked.clear()
540-
cancelled = await self.on_block(awaitable)
541-
if cancelled and not cancellable:
583+
if await self.thread.suspend(awaitable) == Cancelled.TRUE:
542584
assert(self.state == Task.State.INITIAL)
543585
self.state = Task.State.PENDING_CANCEL
544-
cancelled = await self.on_block(awaitable)
545-
assert(not cancelled)
586+
assert(await self.thread.suspend(awaitable) == Cancelled.FALSE)
546587
self.inst.unblocked.set()
547-
return cancelled
548588

549589
async def async_wait(self, awaitable) -> Cancelled:
550590
self.maybe_start_pending_task()
551591
awaitable = asyncio.ensure_future(awaitable)
552592
if awaitable.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
553593
return
554-
cancelled = await self.on_block(awaitable)
594+
cancelled = await self.thread.suspend(awaitable)
555595
while not self.inst.unblocked.is_set():
556-
cancelled |= await self.on_block(self.inst.unblocked.wait())
596+
cancelled |= await self.thread.suspend(self.inst.unblocked.wait())
557597
return cancelled
558598

559-
async def call_sync(self, callee, on_start, on_return):
560-
async def sync_on_block(awaitable):
561-
if await self.on_block(awaitable) == Cancelled.TRUE:
562-
assert(self.state == Task.State.INITIAL)
563-
self.state = Task.State.PENDING_CANCEL
564-
assert(await self.on_block(awaitable) == Cancelled.FALSE)
565-
return False
566-
567-
assert(self.inst.unblocked.is_set())
568-
self.inst.unblocked.clear()
569-
await callee(self, on_start, on_return, sync_on_block)
570-
self.inst.unblocked.set()
571-
572599
async def wait_for_event(self, waitable_set, sync) -> EventTuple:
573600
if self.state == Task.State.PENDING_CANCEL:
574601
self.state = Task.State.CANCEL_DELIVERED
@@ -579,14 +606,13 @@ async def wait_for_event(self, waitable_set, sync) -> EventTuple:
579606
while not e:
580607
maybe_event = waitable_set.maybe_has_pending_event.wait()
581608
if sync:
582-
cancelled = await self.sync_wait(maybe_event, cancellable = True)
609+
await self.sync_wait(maybe_event)
583610
else:
584-
cancelled = await self.async_wait(maybe_event)
585-
if cancelled:
586-
assert(self.state == Task.State.INITIAL)
587-
self.state = Task.State.CANCEL_DELIVERED
588-
e = (EventCode.TASK_CANCELLED, 0, 0)
589-
break
611+
if await self.async_wait(maybe_event) == Cancelled.TRUE:
612+
assert(self.state == Task.State.INITIAL)
613+
self.state = Task.State.CANCEL_DELIVERED
614+
e = (EventCode.TASK_CANCELLED, 0, 0)
615+
break
590616
e = waitable_set.poll()
591617
waitable_set.num_waiting -= 1
592618
return e
@@ -595,12 +621,11 @@ async def yield_(self, sync) -> EventTuple:
595621
if self.state == Task.State.PENDING_CANCEL:
596622
self.state = Task.State.CANCEL_DELIVERED
597623
return (EventCode.TASK_CANCELLED, 0, 0)
624+
if sync:
625+
await self.sync_wait(asyncio.sleep(0))
626+
return (EventCode.NONE, 0, 0)
598627
else:
599-
if sync:
600-
cancelled = await self.sync_wait(asyncio.sleep(0), cancellable = True)
601-
else:
602-
cancelled = await self.async_wait(asyncio.sleep(0))
603-
if cancelled:
628+
if await self.async_wait(asyncio.sleep(0)) == Cancelled.TRUE:
604629
assert(self.state == Task.State.INITIAL)
605630
self.state = Task.State.CANCEL_DELIVERED
606631
return (EventCode.TASK_CANCELLED, 0, 0)
@@ -630,7 +655,6 @@ def cancel(self):
630655
self.state = Task.State.RESOLVED
631656

632657
def exit(self):
633-
assert(scheduler.locked())
634658
trap_if(self.state != Task.State.RESOLVED)
635659
assert(self.num_borrows == 0)
636660
if self.opts.sync:
@@ -650,34 +674,46 @@ class State(IntEnum):
650674

651675
state: State
652676
task: Task
677+
callee_thread: Optional[Thread]
653678
lenders: Optional[list[ResourceHandle]]
654-
request_cancel_begin: asyncio.Future
655-
request_cancel_end: asyncio.Future
679+
cancelled: bool
656680

657681
def __init__(self, task):
658682
Waitable.__init__(self)
659683
self.state = Subtask.State.STARTING
660684
self.task = task
685+
self.callee_thread = None
661686
self.lenders = []
662-
self.request_cancel_begin = asyncio.Future()
663-
self.request_cancel_end = asyncio.Future()
687+
self.cancelled = False
664688

665-
async def call_sync(self, callee, on_start, on_resolve):
666-
def sync_on_start():
689+
async def call(self, callee, on_start, on_resolve):
690+
def update_on_start():
667691
assert(self.state == Subtask.State.STARTING)
668692
self.state = Subtask.State.STARTED
669693
return on_start()
670694

671-
def sync_on_resolve(result):
672-
assert(result is not None)
695+
def update_on_resolve(result):
673696
assert(self.state == Subtask.State.STARTED)
674-
self.state = Subtask.State.RETURNED
697+
if result is None:
698+
assert(self.cancelled)
699+
if self.state == Subtask.State.STARTING:
700+
self.state = Subtask.State.CANCELLED_BEFORE_STARTED
701+
else:
702+
assert(self.state == Subtask.State.STARTED)
703+
self.state = Subtask.State.CANCELLED_BEFORE_RETURNED
704+
else:
705+
assert(self.state == Subtask.State.STARTED)
706+
self.state = Subtask.State.RETURNED
675707
on_resolve(result)
676708

677-
await Task.call_sync(self.task, callee, sync_on_start, sync_on_resolve)
709+
assert(not self.callee_thread)
710+
self.callee_thread = Thread(self.supertask.inst.store, callee, update_on_start, update_on_resolve)
711+
await self.callee_thread.resume()
678712

679-
def cancelled(self):
680-
return self.request_cancel_begin.done()
713+
async def wait_until_resolved(self):
714+
while not self.resolved():
715+
await self.supertask.sync_wait(self.callee_thread.awaitable)
716+
await self.callee_thread.resume()
681717

682718
def resolved(self):
683719
match self.state:
@@ -689,57 +725,13 @@ def resolved(self):
689725
Subtask.State.CANCELLED_BEFORE_RETURNED):
690726
return True
691727

692-
async def request_cancel(self):
693-
assert(not self.cancelled() and not self.resolved())
694-
self.request_cancel_begin.set_result(None)
695-
await self.request_cancel_end
696-
697-
async def call_async(self, callee, on_start, on_resolve):
698-
async def do_call():
699-
await callee(self.task, async_on_start, async_on_resolve, async_on_block)
700-
relinquish_control()
728+
def cancellation_requested(self):
729+
return self.cancelled
701730

702-
def async_on_start():
703-
assert(self.state == Subtask.State.STARTING)
704-
self.state = Subtask.State.STARTED
705-
return on_start()
706-
707-
def async_on_resolve(result):
708-
if result is None:
709-
if self.state == Subtask.State.STARTING:
710-
self.state = Subtask.State.CANCELLED_BEFORE_STARTED
711-
else:
712-
assert(self.state == Subtask.State.STARTED)
713-
self.state = Subtask.State.CANCELLED_BEFORE_RETURNED
714-
else:
715-
assert(self.state == Subtask.State.STARTED)
716-
self.state = Subtask.State.RETURNED
717-
on_resolve(result)
718-
719-
async def async_on_block(awaitable):
720-
relinquish_control()
721-
if not self.request_cancel_end.done():
722-
await asyncio.wait([awaitable, self.request_cancel_begin],
723-
return_when = asyncio.FIRST_COMPLETED)
724-
if self.request_cancel_begin.done():
725-
return True
726-
else:
727-
await awaitable
728-
assert(awaitable.done())
729-
await scheduler.acquire()
730-
return False
731-
732-
def relinquish_control():
733-
if not ret.done():
734-
ret.set_result(None)
735-
elif self.request_cancel_begin.done() and not self.request_cancel_end.done():
736-
self.request_cancel_end.set_result(None)
737-
else:
738-
scheduler.release()
739-
740-
ret = asyncio.Future()
741-
asyncio.create_task(do_call())
742-
await ret
731+
async def request_cancel(self):
732+
assert(not self.cancellation_requested() and not self.resolved())
733+
self.cancelled = True
734+
await self.callee_thread.resume(Cancelled.TRUE)
743735

744736
def add_lender(self, lending_handle):
745737
assert(not self.resolve_delivered() and not self.resolved())
@@ -981,6 +973,27 @@ def drop(self):
981973
trap_if(not self.done)
982974
FutureEnd.drop(self)
983975

976+
### Store State
977+
978+
class Store:
979+
loop: asyncio.AbstractEventLoop
980+
waiting: list[Thread]
981+
982+
def __init__(self):
983+
self.loop = asyncio.new_event_loop()
984+
self.waiting = []
985+
986+
def call_export(self, callee, on_start, on_resolve):
987+
self.run_until_complete(Thread(self, callee, on_start, on_resolve).resume())
988+
989+
def tick(self, i):
990+
if not DETERMINISTIC_PROFILE:
991+
random.shuffle(self.waiting)
992+
for thread in self.waiting:
993+
if thread.awaitable.done():
994+
self.run_until_complete(thread.resume())
995+
return
996+
984997
### Despecialization
985998

986999
def despecialize(t):
@@ -1936,8 +1949,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
19361949

19371950
### `canon lift`
19381951

1939-
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_block):
1940-
task = Task(opts, inst, ft, caller, on_resolve, on_block)
1952+
async def canon_lift(opts, inst, ft, callee, caller, thread, on_start, on_resolve):
1953+
task = Task(opts, inst, ft, caller, thread, on_resolve)
19411954
if not await task.enter():
19421955
return
19431956

@@ -2022,10 +2035,13 @@ def on_start():
20222035

20232036
flat_results = None
20242037
def on_resolve(result):
2038+
assert(result is not None)
20252039
nonlocal flat_results
20262040
flat_results = lower_flat_values(cx, MAX_FLAT_RESULTS, result, ft.result_type(), flat_args)
20272041

2028-
await subtask.call_sync(callee, on_start, on_resolve)
2042+
await subtask.call(callee, on_start, on_resolve)
2043+
await subtask.wait_until_resolved()
2044+
20292045
assert(types_match_values(flat_ft.results, flat_results))
20302046
subtask.deliver_resolve()
20312047
return flat_results
@@ -2048,7 +2064,7 @@ def subtask_event():
20482064
return (EventCode.SUBTASK, subtaski, subtask.state)
20492065
subtask.set_event(subtask_event)
20502066

2051-
await subtask.call_async(callee, on_start, on_resolve)
2067+
await subtask.call(callee, on_start, on_resolve)
20522068
if subtask.resolved():
20532069
subtask.deliver_resolve()
20542070
return [Subtask.State.RETURNED]
@@ -2217,16 +2233,13 @@ async def canon_subtask_cancel(sync, task, i):
22172233
trap_if(not task.inst.may_leave)
22182234
subtask = task.inst.table.get(i)
22192235
trap_if(not isinstance(subtask, Subtask))
2220-
trap_if(subtask.resolve_delivered() or subtask.cancelled())
2236+
trap_if(subtask.resolve_delivered() or subtask.cancellation_requested())
22212237
if subtask.resolved():
22222238
assert(subtask.has_pending_event())
22232239
else:
22242240
await subtask.request_cancel()
22252241
if sync:
2226-
while not subtask.resolved():
2227-
if subtask.has_pending_event():
2228-
_ = subtask.get_event()
2229-
await task.sync_wait(subtask.wait_for_pending_event())
2242+
await subtask.wait_until_resolved()
22302243
else:
22312244
if not subtask.resolved():
22322245
return [BLOCKED]

0 commit comments

Comments
 (0)