Skip to content

Commit 5a60c2a

Browse files
committed
fix localized retry admission
Fixes #742 Signed-off-by: Eric W. Tramel <1223539+eric-tramel@users.noreply.github.com>
1 parent 83ee424 commit 5a60c2a

2 files changed

Lines changed: 711 additions & 37 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 215 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from data_designer.engine.dataset_builders.scheduling.resolver import TaskSchedulingResolver
4242
from data_designer.engine.dataset_builders.scheduling.resources import (
4343
SchedulableTask,
44+
request_scheduler_resource_key,
4445
stable_task_id,
4546
)
4647
from data_designer.engine.dataset_builders.scheduling.task_admission import (
@@ -145,6 +146,15 @@ class _DispatchOutcome:
145146
admission_blocked: bool = False
146147

147148

149+
@dataclass(frozen=True)
150+
class _DeferredAdmissionAnalysis:
151+
"""Deferred retry pressure as seen by adaptive row-group admission."""
152+
153+
blocks: bool
154+
candidate_columns: tuple[str, ...]
155+
independent_candidate_columns: tuple[str, ...]
156+
157+
148158
class AsyncTaskScheduler:
149159
"""Dependency-aware async task scheduler for the dataset builder.
150160
@@ -329,6 +339,12 @@ def __init__(
329339
self._row_group_admission_pressure_ticks = 0
330340
self._row_group_admission_blocked_reasons: Counter[str] = Counter()
331341
self._adaptive_max_admitted_rows = self._max_admitted_rows_guardrail()
342+
self._row_group_admission_pending: tuple[int, int] | None = None
343+
self._deferred_admission_cache: tuple[tuple[Task, ...], tuple[int, int], _DeferredAdmissionAnalysis] | None = (
344+
None
345+
)
346+
self._transitive_upstream_cache: dict[str, frozenset[str]] = {}
347+
self._transitive_downstream_cache: dict[str, frozenset[str]] = {}
332348
self._request_pressure_provider = request_pressure_provider
333349
self._request_pressure_advisory = request_pressure_advisory and request_pressure_provider is not None
334350
self._request_pressure_advisory_skips = 0
@@ -944,13 +960,14 @@ def _maybe_update_adaptive_row_group_target(self) -> None:
944960
self._row_group_admission_event.set()
945961

946962
def _adaptive_row_group_block_reason(self) -> str | None:
947-
if self._deferred:
948-
return "deferred_tasks"
949-
next_size = self._next_unadmitted_row_group_size()
950-
if next_size is None:
963+
next_row_group = self._next_unadmitted_row_group()
964+
if next_row_group is None:
951965
return "no_pending_row_groups"
966+
next_rg_id, next_size = next_row_group
952967
if not self._row_group_row_guard_allows(next_size):
953968
return "max_admitted_rows"
969+
if self._deferred and self._deferred_admission_analysis(next_rg_id, next_size).blocks:
970+
return "deferred_tasks"
954971
queue_view = self._fair_queue.view()
955972
queue_guard = self._max_in_flight_tasks * 4
956973
if queue_view.queued_total >= queue_guard:
@@ -968,13 +985,169 @@ def _adaptive_row_group_block_reason(self) -> str | None:
968985
return "queued_llm_demand"
969986
return None
970987

971-
def _next_unadmitted_row_group_size(self) -> int | None:
972-
for rg_id, rg_size in self._row_groups:
973-
if rg_id not in self._rg_states and not self._tracker.is_row_group_complete(
974-
rg_id, rg_size, self._graph.columns
975-
):
976-
return rg_size
977-
return None
988+
def _next_unadmitted_row_group(self) -> tuple[int, int] | None:
989+
pending = self._row_group_admission_pending
990+
if pending is None:
991+
return None
992+
rg_id, rg_size = pending
993+
if rg_id in self._rg_states or self._tracker.is_row_group_complete(rg_id, rg_size, self._graph.columns):
994+
return None
995+
return pending
996+
997+
def _deferred_admission_analysis(self, row_group: int, row_group_size: int) -> _DeferredAdmissionAnalysis:
998+
cache_key = (tuple(self._deferred), (row_group, row_group_size))
999+
if self._deferred_admission_cache is not None and self._deferred_admission_cache[:2] == cache_key:
1000+
return self._deferred_admission_cache[2]
1001+
deferred_items = tuple(self._schedulable_task(task) for task in self._deferred)
1002+
deferred_keys = {key for item in deferred_items for key in self._localized_deferred_admission_keys(item)}
1003+
candidates = tuple(
1004+
(item, self._localized_deferred_admission_keys(item))
1005+
for item in self._row_group_admission_candidate_tasks(row_group, row_group_size)
1006+
)
1007+
blocked_columns: set[str] = set()
1008+
for item in deferred_items:
1009+
blocked_columns.update(self._task_output_columns(item.payload))
1010+
for item, keys in candidates:
1011+
if keys & deferred_keys:
1012+
blocked_columns.update(self._task_output_columns(item.payload))
1013+
independent_candidates = tuple(
1014+
item.payload.column
1015+
for item, keys in candidates
1016+
if not (keys & deferred_keys)
1017+
and not self._task_depends_on_any(item.payload, blocked_columns)
1018+
and (
1019+
self._is_resource_scoped_admission_candidate(item)
1020+
or not self._task_reaches_any(item.payload, blocked_columns)
1021+
)
1022+
)
1023+
blocks = bool(deferred_items) and not independent_candidates
1024+
analysis = _DeferredAdmissionAnalysis(
1025+
blocks=blocks,
1026+
candidate_columns=tuple(item.payload.column for item, _keys in candidates),
1027+
independent_candidate_columns=independent_candidates,
1028+
)
1029+
self._deferred_admission_cache = (*cache_key, analysis)
1030+
return analysis
1031+
1032+
def _row_group_admission_candidate_tasks(
1033+
self,
1034+
row_group: int,
1035+
row_group_size: int,
1036+
) -> tuple[SchedulableTask, ...]:
1037+
tasks: list[SchedulableTask] = []
1038+
seen_generators: set[int] = set()
1039+
for column in self._graph.get_topological_order():
1040+
generator_id = id(self._generators[column])
1041+
if generator_id in seen_generators:
1042+
continue
1043+
seen_generators.add(generator_id)
1044+
strategy = self._graph.get_strategy(column)
1045+
if strategy == GenerationStrategy.CELL_BY_CELL:
1046+
if row_group_size <= 0:
1047+
continue
1048+
task = Task(column=column, row_group=row_group, row_index=0, task_type="cell")
1049+
elif column in self._seed_cols:
1050+
task = Task(column=column, row_group=row_group, row_index=None, task_type="from_scratch")
1051+
else:
1052+
task = Task(column=column, row_group=row_group, row_index=None, task_type="batch")
1053+
tasks.append(self._schedulable_task(task))
1054+
return tuple(tasks)
1055+
1056+
def _localized_deferred_admission_keys(self, item: SchedulableTask) -> set[str]:
1057+
if item.request_resource_key is not None:
1058+
resource = item.request_resource_key
1059+
return {
1060+
f"request_resource:{_request_resource_label(resource)}",
1061+
f"scheduler_resource:{request_scheduler_resource_key(resource)}",
1062+
}
1063+
identity = "/".join(item.group.key.identity)
1064+
return {f"group:{item.group.key.kind}:{identity}"}
1065+
1066+
@staticmethod
1067+
def _is_localized_admission_resource(resource: str) -> bool:
1068+
return resource.startswith("request:")
1069+
1070+
def _is_resource_scoped_admission_candidate(self, item: SchedulableTask) -> bool:
1071+
return item.request_resource_key is not None or item.group.key.kind != "local"
1072+
1073+
def _task_output_columns(self, task: Task) -> tuple[str, ...]:
1074+
return self._task_flow_identity(task) or (task.column,)
1075+
1076+
def _task_depends_on_any(self, task: Task, blocked_columns: set[str]) -> bool:
1077+
return any(self._column_depends_on_any(column, blocked_columns) for column in self._task_output_columns(task))
1078+
1079+
def _task_reaches_any(self, task: Task, blocked_columns: set[str]) -> bool:
1080+
return any(self._column_reaches_any(column, blocked_columns) for column in self._task_output_columns(task))
1081+
1082+
def _column_depends_on_any(self, column: str, blocked_columns: set[str]) -> bool:
1083+
return bool(self._transitive_upstream_columns(column) & blocked_columns)
1084+
1085+
def _column_reaches_any(self, column: str, blocked_columns: set[str]) -> bool:
1086+
return bool(self._transitive_downstream_columns(column) & blocked_columns)
1087+
1088+
def _transitive_upstream_columns(self, column: str) -> frozenset[str]:
1089+
cached = self._transitive_upstream_cache.get(column)
1090+
if cached is not None:
1091+
return cached
1092+
result = self._walk_graph(column, upstream=True)
1093+
self._transitive_upstream_cache[column] = result
1094+
return result
1095+
1096+
def _transitive_downstream_columns(self, column: str) -> frozenset[str]:
1097+
cached = self._transitive_downstream_cache.get(column)
1098+
if cached is not None:
1099+
return cached
1100+
result = self._walk_graph(column, upstream=False)
1101+
self._transitive_downstream_cache[column] = result
1102+
return result
1103+
1104+
def _walk_graph(self, column: str, *, upstream: bool) -> frozenset[str]:
1105+
next_columns = self._graph.get_upstream_columns if upstream else self._graph.get_downstream_columns
1106+
to_visit = list(next_columns(column))
1107+
seen: set[str] = set()
1108+
while to_visit:
1109+
next_column = to_visit.pop()
1110+
if next_column in seen:
1111+
continue
1112+
seen.add(next_column)
1113+
to_visit.extend(next_columns(next_column))
1114+
return frozenset(seen)
1115+
1116+
def _deferred_admission_diagnostics(self) -> dict[str, object]:
1117+
deferred_items = tuple(self._schedulable_task(task) for task in self._deferred)
1118+
diagnostics: dict[str, object] = {
1119+
"count": len(self._deferred),
1120+
"scope": "localized" if self._deferred else "none",
1121+
"blocks_next_row_group": False,
1122+
"columns": dict(Counter(task.column for task in self._deferred)),
1123+
"request_resources": {},
1124+
"scheduler_resources": {},
1125+
"candidate_columns": (),
1126+
"independent_candidate_columns": (),
1127+
}
1128+
if not self._deferred:
1129+
return diagnostics
1130+
request_resource_counts = Counter(
1131+
label
1132+
for item in deferred_items
1133+
if (label := _request_resource_label(item.request_resource_key)) is not None
1134+
)
1135+
scheduler_resource_counts = Counter(
1136+
resource
1137+
for item in deferred_items
1138+
for resource in item.resource_request.amounts
1139+
if self._is_localized_admission_resource(resource)
1140+
)
1141+
diagnostics["request_resources"] = dict(request_resource_counts)
1142+
diagnostics["scheduler_resources"] = dict(scheduler_resource_counts)
1143+
next_row_group = self._next_unadmitted_row_group()
1144+
if next_row_group is None:
1145+
return diagnostics
1146+
analysis = self._deferred_admission_analysis(*next_row_group)
1147+
diagnostics["blocks_next_row_group"] = analysis.blocks
1148+
diagnostics["candidate_columns"] = analysis.candidate_columns
1149+
diagnostics["independent_candidate_columns"] = analysis.independent_candidate_columns
1150+
return diagnostics
9781151

9791152
def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]:
9801153
queue_view = self._fair_queue.view()
@@ -999,42 +1172,48 @@ def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]:
9991172
"llm_wait_leased": task_view.leased_resources.get("llm_wait", 0),
10001173
"llm_wait_available": task_view.resources_available.get("llm_wait", 0),
10011174
"blocked_reasons": dict(self._row_group_admission_blocked_reasons),
1175+
"deferred_admission": self._deferred_admission_diagnostics(),
10021176
}
10031177

10041178
async def _admit_row_groups(self) -> None:
10051179
"""Admit row groups as semaphore slots become available."""
10061180
all_admitted = True
1007-
for rg_id, rg_size in self._row_groups:
1008-
await self._wait_for_row_group_admission_capacity(rg_size)
1009-
if self._early_shutdown or self._fatal_worker_error is not None:
1010-
all_admitted = False
1011-
break
1012-
await self._rg_semaphore.acquire()
1013-
if self._early_shutdown or self._fatal_worker_error is not None:
1014-
self._rg_semaphore.release()
1015-
all_admitted = False
1016-
break
1017-
if not self._row_group_row_guard_allows(rg_size):
1018-
self._rg_semaphore.release()
1181+
try:
1182+
for rg_id, rg_size in self._row_groups:
1183+
self._row_group_admission_pending = (rg_id, rg_size)
10191184
await self._wait_for_row_group_admission_capacity(rg_size)
1185+
if self._early_shutdown or self._fatal_worker_error is not None:
1186+
all_admitted = False
1187+
break
10201188
await self._rg_semaphore.acquire()
10211189
if self._early_shutdown or self._fatal_worker_error is not None:
10221190
self._rg_semaphore.release()
10231191
all_admitted = False
10241192
break
1025-
self._rg_states[rg_id] = _RowGroupState(size=rg_size)
1026-
1027-
if self._buffer_manager is not None:
1028-
self._buffer_manager.init_row_group(rg_id, rg_size)
1029-
1030-
await self._dispatch_seeds(rg_id, rg_size)
1031-
self._emit_scheduler_event(
1032-
"row_group_admitted",
1033-
diagnostics=self._row_group_admission_diagnostics(reason="admitted")
1034-
| {"row_group": rg_id, "row_group_size": rg_size},
1035-
)
1036-
self._emit_scheduler_health_snapshot("row_group_admitted")
1037-
self._wake_event.set()
1193+
if not self._row_group_row_guard_allows(rg_size):
1194+
self._rg_semaphore.release()
1195+
await self._wait_for_row_group_admission_capacity(rg_size)
1196+
await self._rg_semaphore.acquire()
1197+
if self._early_shutdown or self._fatal_worker_error is not None:
1198+
self._rg_semaphore.release()
1199+
all_admitted = False
1200+
break
1201+
self._row_group_admission_pending = None
1202+
self._rg_states[rg_id] = _RowGroupState(size=rg_size)
1203+
1204+
if self._buffer_manager is not None:
1205+
self._buffer_manager.init_row_group(rg_id, rg_size)
1206+
1207+
await self._dispatch_seeds(rg_id, rg_size)
1208+
self._emit_scheduler_event(
1209+
"row_group_admitted",
1210+
diagnostics=self._row_group_admission_diagnostics(reason="admitted")
1211+
| {"row_group": rg_id, "row_group_size": rg_size},
1212+
)
1213+
self._emit_scheduler_health_snapshot("row_group_admitted")
1214+
self._wake_event.set()
1215+
finally:
1216+
self._row_group_admission_pending = None
10381217
self._all_rgs_admitted = all_admitted
10391218
self._wake_event.set()
10401219

0 commit comments

Comments
 (0)