Skip to content

Commit 1e45dec

Browse files
committed
CABI: factor out common code in Subtask and canon_lower
1 parent a6c34b2 commit 1e45dec

File tree

1 file changed

+55
-72
lines changed

1 file changed

+55
-72
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 55 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -645,36 +645,19 @@ class State(IntEnum):
645645
CANCELLED_BEFORE_RETURNED = 4
646646

647647
state: State
648-
supertask: Optional[Task]
648+
task: Task
649649
lenders: Optional[list[ResourceHandle]]
650650
request_cancel_begin: asyncio.Future
651651
request_cancel_end: asyncio.Future
652652

653-
def __init__(self, supertask):
653+
def __init__(self, task):
654654
Waitable.__init__(self)
655655
self.state = Subtask.State.STARTING
656-
self.supertask = supertask
656+
self.task = task
657657
self.lenders = []
658658
self.request_cancel_begin = asyncio.Future()
659659
self.request_cancel_end = asyncio.Future()
660660

661-
async def call_sync(self, callee, on_start, on_resolve):
662-
def sync_on_start():
663-
assert(self.state == Subtask.State.STARTING)
664-
self.state = Subtask.State.STARTED
665-
return on_start()
666-
667-
def sync_on_resolve(result):
668-
assert(result is not None)
669-
assert(self.state == Subtask.State.STARTED)
670-
self.state = Subtask.State.RETURNED
671-
on_resolve(result)
672-
673-
await Task.call_sync(self.supertask, callee, sync_on_start, sync_on_resolve)
674-
675-
def cancelled(self):
676-
return self.request_cancel_begin.done()
677-
678661
def resolved(self):
679662
match self.state:
680663
case (Subtask.State.STARTING |
@@ -686,32 +669,18 @@ def resolved(self):
686669
return True
687670

688671
async def request_cancel(self):
689-
assert(not self.cancelled() and not self.resolved())
672+
assert(not self.cancellation_requested() and not self.resolved())
690673
self.request_cancel_begin.set_result(None)
691674
await self.request_cancel_end
692675

676+
def cancellation_requested(self):
677+
return self.request_cancel_begin.done()
678+
693679
async def call_async(self, callee, on_start, on_resolve):
694680
async def do_call():
695-
await callee(self.supertask, async_on_start, async_on_resolve, async_on_block)
681+
await callee(self.task, on_start, on_resolve, async_on_block)
696682
relinquish_control()
697683

698-
def async_on_start():
699-
assert(self.state == Subtask.State.STARTING)
700-
self.state = Subtask.State.STARTED
701-
return on_start()
702-
703-
def async_on_resolve(result):
704-
if result is None:
705-
if self.state == Subtask.State.STARTING:
706-
self.state = Subtask.State.CANCELLED_BEFORE_STARTED
707-
else:
708-
assert(self.state == Subtask.State.STARTED)
709-
self.state = Subtask.State.CANCELLED_BEFORE_RETURNED
710-
else:
711-
assert(self.state == Subtask.State.STARTED)
712-
self.state = Subtask.State.RETURNED
713-
on_resolve(result)
714-
715684
async def async_on_block(awaitable):
716685
relinquish_control()
717686
if not self.request_cancel_end.done():
@@ -2007,52 +1976,65 @@ async def call_and_trap_on_throw(callee, task, args):
20071976
async def canon_lower(opts, ft, callee, task, flat_args):
20081977
trap_if(not task.inst.may_leave)
20091978
subtask = Subtask(task)
1979+
20101980
cx = LiftLowerContext(opts, task.inst, subtask)
20111981
flat_ft = flatten_functype(opts, ft, 'lower')
20121982
assert(types_match_values(flat_ft.params, flat_args))
20131983
flat_args = CoreValueIter(flat_args)
20141984

20151985
if opts.sync:
2016-
def on_start():
2017-
return lift_flat_values(cx, MAX_FLAT_PARAMS, flat_args, ft.param_types())
2018-
2019-
flat_results = None
2020-
def on_resolve(result):
2021-
nonlocal flat_results
2022-
flat_results = lower_flat_values(cx, MAX_FLAT_RESULTS, result, ft.result_type(), flat_args)
1986+
max_flat_params = MAX_FLAT_PARAMS
1987+
max_flat_results = MAX_FLAT_RESULTS
1988+
else:
1989+
max_flat_params = MAX_FLAT_ASYNC_PARAMS
1990+
max_flat_results = 0
20231991

2024-
await subtask.call_sync(callee, on_start, on_resolve)
2025-
assert(types_match_values(flat_ft.results, flat_results))
2026-
subtask.deliver_resolve()
2027-
return flat_results
1992+
on_progress = lambda:()
1993+
flat_results = None
20281994

20291995
def on_start():
20301996
on_progress()
2031-
return lift_flat_values(cx, MAX_FLAT_ASYNC_PARAMS, flat_args, ft.param_types())
1997+
assert(subtask.state == Subtask.State.STARTING)
1998+
subtask.state = Subtask.State.STARTED
1999+
return lift_flat_values(cx, max_flat_params, flat_args, ft.param_types())
20322000

20332001
def on_resolve(result):
20342002
on_progress()
2035-
if result is not None:
2036-
[] = lower_flat_values(cx, 0, result, ft.result_type(), flat_args)
2037-
2038-
subtaski = None
2039-
def on_progress():
2040-
if subtaski is not None:
2041-
def subtask_event():
2042-
if subtask.resolved():
2043-
subtask.deliver_resolve()
2044-
return (EventCode.SUBTASK, subtaski, subtask.state)
2045-
subtask.set_event(subtask_event)
2046-
2047-
await subtask.call_async(callee, on_start, on_resolve)
2048-
if subtask.resolved():
2049-
subtask.deliver_resolve()
2050-
return [Subtask.State.RETURNED]
2003+
if result is None:
2004+
assert(subtask.cancellation_requested())
2005+
if subtask.state == Subtask.State.STARTING:
2006+
subtask.state = Subtask.State.CANCELLED_BEFORE_STARTED
2007+
else:
2008+
assert(subtask.state == Subtask.State.STARTED)
2009+
subtask.state = Subtask.State.CANCELLED_BEFORE_RETURNED
2010+
else:
2011+
assert(subtask.state == Subtask.State.STARTED)
2012+
subtask.state = Subtask.State.RETURNED
2013+
nonlocal flat_results
2014+
flat_results = lower_flat_values(cx, max_flat_results, result, ft.result_type(), flat_args)
20512015

2052-
subtaski = task.inst.table.add(subtask)
2053-
assert(0 < subtaski <= Table.MAX_LENGTH < 2**28)
2054-
assert(0 <= subtask.state < 2**4)
2055-
return [subtask.state | (subtaski << 4)]
2016+
if opts.sync:
2017+
await task.call_sync(callee, on_start, on_resolve)
2018+
assert(types_match_values(flat_ft.results, flat_results))
2019+
subtask.deliver_resolve()
2020+
return flat_results
2021+
else:
2022+
await subtask.call_async(callee, on_start, on_resolve)
2023+
if subtask.resolved():
2024+
assert(flat_results == [])
2025+
subtask.deliver_resolve()
2026+
return [Subtask.State.RETURNED]
2027+
else:
2028+
subtaski = task.inst.table.add(subtask)
2029+
def on_progress():
2030+
def subtask_event():
2031+
if subtask.resolved():
2032+
subtask.deliver_resolve()
2033+
return (EventCode.SUBTASK, subtaski, subtask.state)
2034+
subtask.set_event(subtask_event)
2035+
assert(0 < subtaski <= Table.MAX_LENGTH < 2**28)
2036+
assert(0 <= subtask.state < 2**4)
2037+
return [subtask.state | (subtaski << 4)]
20562038

20572039
### `canon resource.new`
20582040

@@ -2213,7 +2195,8 @@ async def canon_subtask_cancel(sync, task, i):
22132195
trap_if(not task.inst.may_leave)
22142196
subtask = task.inst.table.get(i)
22152197
trap_if(not isinstance(subtask, Subtask))
2216-
trap_if(subtask.resolve_delivered() or subtask.cancelled())
2198+
trap_if(subtask.resolve_delivered())
2199+
trap_if(subtask.cancellation_requested())
22172200
if subtask.resolved():
22182201
assert(subtask.has_pending_event())
22192202
else:

0 commit comments

Comments
 (0)