Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from data_designer.engine.dataset_builders.scheduling.resolver import TaskSchedulingResolver
from data_designer.engine.dataset_builders.scheduling.resources import (
SchedulableTask,
request_scheduler_resource_key,
stable_task_id,
)
from data_designer.engine.dataset_builders.scheduling.task_admission import (
Expand Down Expand Up @@ -145,6 +146,15 @@ class _DispatchOutcome:
admission_blocked: bool = False


@dataclass(frozen=True)
class _DeferredAdmissionAnalysis:
"""Deferred retry pressure as seen by adaptive row-group admission."""

blocks: bool
candidate_columns: tuple[str, ...]
independent_candidate_columns: tuple[str, ...]


class AsyncTaskScheduler:
"""Dependency-aware async task scheduler for the dataset builder.

Expand Down Expand Up @@ -329,6 +339,12 @@ def __init__(
self._row_group_admission_pressure_ticks = 0
self._row_group_admission_blocked_reasons: Counter[str] = Counter()
self._adaptive_max_admitted_rows = self._max_admitted_rows_guardrail()
self._row_group_admission_pending: tuple[int, int] | None = None
self._deferred_admission_cache: tuple[tuple[Task, ...], tuple[int, int], _DeferredAdmissionAnalysis] | None = (
None
)
self._transitive_upstream_cache: dict[str, frozenset[str]] = {}
self._transitive_downstream_cache: dict[str, frozenset[str]] = {}
self._request_pressure_provider = request_pressure_provider
self._request_pressure_advisory = request_pressure_advisory and request_pressure_provider is not None
self._request_pressure_advisory_skips = 0
Expand Down Expand Up @@ -944,13 +960,14 @@ def _maybe_update_adaptive_row_group_target(self) -> None:
self._row_group_admission_event.set()

def _adaptive_row_group_block_reason(self) -> str | None:
if self._deferred:
return "deferred_tasks"
next_size = self._next_unadmitted_row_group_size()
if next_size is None:
next_row_group = self._next_unadmitted_row_group()
if next_row_group is None:
return "no_pending_row_groups"
next_rg_id, next_size = next_row_group
if not self._row_group_row_guard_allows(next_size):
return "max_admitted_rows"
if self._deferred and self._deferred_admission_analysis(next_rg_id, next_size).blocks:
return "deferred_tasks"
queue_view = self._fair_queue.view()
queue_guard = self._max_in_flight_tasks * 4
if queue_view.queued_total >= queue_guard:
Expand All @@ -968,13 +985,169 @@ def _adaptive_row_group_block_reason(self) -> str | None:
return "queued_llm_demand"
return None

def _next_unadmitted_row_group_size(self) -> int | None:
for rg_id, rg_size in self._row_groups:
if rg_id not in self._rg_states and not self._tracker.is_row_group_complete(
rg_id, rg_size, self._graph.columns
):
return rg_size
return None
def _next_unadmitted_row_group(self) -> tuple[int, int] | None:
pending = self._row_group_admission_pending
if pending is None:
return None
rg_id, rg_size = pending
if rg_id in self._rg_states or self._tracker.is_row_group_complete(rg_id, rg_size, self._graph.columns):
return None
return pending

def _deferred_admission_analysis(self, row_group: int, row_group_size: int) -> _DeferredAdmissionAnalysis:
cache_key = (tuple(self._deferred), (row_group, row_group_size))
if self._deferred_admission_cache is not None and self._deferred_admission_cache[:2] == cache_key:
return self._deferred_admission_cache[2]
deferred_items = tuple(self._schedulable_task(task) for task in self._deferred)
deferred_keys = {key for item in deferred_items for key in self._localized_deferred_admission_keys(item)}
candidates = tuple(
(item, self._localized_deferred_admission_keys(item))
for item in self._row_group_admission_candidate_tasks(row_group, row_group_size)
)
blocked_columns: set[str] = set()
for item in deferred_items:
blocked_columns.update(self._task_output_columns(item.payload))
for item, keys in candidates:
if keys & deferred_keys:
blocked_columns.update(self._task_output_columns(item.payload))
independent_candidates = tuple(
item.payload.column
for item, keys in candidates
if not (keys & deferred_keys)
and not self._task_depends_on_any(item.payload, blocked_columns)
and (
self._is_resource_scoped_admission_candidate(item)
or not self._task_reaches_any(item.payload, blocked_columns)
)
)
blocks = bool(deferred_items) and not independent_candidates
analysis = _DeferredAdmissionAnalysis(
blocks=blocks,
candidate_columns=tuple(item.payload.column for item, _keys in candidates),
independent_candidate_columns=independent_candidates,
)
self._deferred_admission_cache = (*cache_key, analysis)
return analysis

def _row_group_admission_candidate_tasks(
self,
row_group: int,
row_group_size: int,
) -> tuple[SchedulableTask, ...]:
tasks: list[SchedulableTask] = []
seen_generators: set[int] = set()
for column in self._graph.get_topological_order():
generator_id = id(self._generators[column])
if generator_id in seen_generators:
continue
seen_generators.add(generator_id)
strategy = self._graph.get_strategy(column)
if strategy == GenerationStrategy.CELL_BY_CELL:
if row_group_size <= 0:
continue
task = Task(column=column, row_group=row_group, row_index=0, task_type="cell")
elif column in self._seed_cols:
task = Task(column=column, row_group=row_group, row_index=None, task_type="from_scratch")
else:
task = Task(column=column, row_group=row_group, row_index=None, task_type="batch")
tasks.append(self._schedulable_task(task))
return tuple(tasks)

def _localized_deferred_admission_keys(self, item: SchedulableTask) -> set[str]:
if item.request_resource_key is not None:
resource = item.request_resource_key
return {
f"request_resource:{_request_resource_label(resource)}",
f"scheduler_resource:{request_scheduler_resource_key(resource)}",
}
identity = "/".join(item.group.key.identity)
return {f"group:{item.group.key.kind}:{identity}"}

@staticmethod
def _is_localized_admission_resource(resource: str) -> bool:
return resource.startswith("request:")

def _is_resource_scoped_admission_candidate(self, item: SchedulableTask) -> bool:
return item.request_resource_key is not None or item.group.key.kind != "local"

def _task_output_columns(self, task: Task) -> tuple[str, ...]:
return self._task_flow_identity(task) or (task.column,)

def _task_depends_on_any(self, task: Task, blocked_columns: set[str]) -> bool:
return any(self._column_depends_on_any(column, blocked_columns) for column in self._task_output_columns(task))

def _task_reaches_any(self, task: Task, blocked_columns: set[str]) -> bool:
return any(self._column_reaches_any(column, blocked_columns) for column in self._task_output_columns(task))

def _column_depends_on_any(self, column: str, blocked_columns: set[str]) -> bool:
return bool(self._transitive_upstream_columns(column) & blocked_columns)

def _column_reaches_any(self, column: str, blocked_columns: set[str]) -> bool:
return bool(self._transitive_downstream_columns(column) & blocked_columns)

def _transitive_upstream_columns(self, column: str) -> frozenset[str]:
cached = self._transitive_upstream_cache.get(column)
if cached is not None:
return cached
result = self._walk_graph(column, upstream=True)
self._transitive_upstream_cache[column] = result
return result

def _transitive_downstream_columns(self, column: str) -> frozenset[str]:
cached = self._transitive_downstream_cache.get(column)
if cached is not None:
return cached
result = self._walk_graph(column, upstream=False)
self._transitive_downstream_cache[column] = result
return result

def _walk_graph(self, column: str, *, upstream: bool) -> frozenset[str]:
next_columns = self._graph.get_upstream_columns if upstream else self._graph.get_downstream_columns
to_visit = list(next_columns(column))
seen: set[str] = set()
while to_visit:
next_column = to_visit.pop()
if next_column in seen:
continue
seen.add(next_column)
to_visit.extend(next_columns(next_column))
return frozenset(seen)

def _deferred_admission_diagnostics(self) -> dict[str, object]:
deferred_items = tuple(self._schedulable_task(task) for task in self._deferred)
diagnostics: dict[str, object] = {
"count": len(self._deferred),
"scope": "localized" if self._deferred else "none",
"blocks_next_row_group": False,
"columns": dict(Counter(task.column for task in self._deferred)),
"request_resources": {},
"scheduler_resources": {},
"candidate_columns": (),
"independent_candidate_columns": (),
}
if not self._deferred:
return diagnostics
request_resource_counts = Counter(
label
for item in deferred_items
if (label := _request_resource_label(item.request_resource_key)) is not None
)
scheduler_resource_counts = Counter(
resource
for item in deferred_items
for resource in item.resource_request.amounts
if self._is_localized_admission_resource(resource)
)
diagnostics["request_resources"] = dict(request_resource_counts)
diagnostics["scheduler_resources"] = dict(scheduler_resource_counts)
next_row_group = self._next_unadmitted_row_group()
if next_row_group is None:
return diagnostics
analysis = self._deferred_admission_analysis(*next_row_group)
diagnostics["blocks_next_row_group"] = analysis.blocks
diagnostics["candidate_columns"] = analysis.candidate_columns
diagnostics["independent_candidate_columns"] = analysis.independent_candidate_columns
return diagnostics

def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]:
queue_view = self._fair_queue.view()
Expand All @@ -999,42 +1172,48 @@ def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]:
"llm_wait_leased": task_view.leased_resources.get("llm_wait", 0),
"llm_wait_available": task_view.resources_available.get("llm_wait", 0),
"blocked_reasons": dict(self._row_group_admission_blocked_reasons),
"deferred_admission": self._deferred_admission_diagnostics(),
}

async def _admit_row_groups(self) -> None:
"""Admit row groups as semaphore slots become available."""
all_admitted = True
for rg_id, rg_size in self._row_groups:
await self._wait_for_row_group_admission_capacity(rg_size)
if self._early_shutdown or self._fatal_worker_error is not None:
all_admitted = False
break
await self._rg_semaphore.acquire()
if self._early_shutdown or self._fatal_worker_error is not None:
self._rg_semaphore.release()
all_admitted = False
break
if not self._row_group_row_guard_allows(rg_size):
self._rg_semaphore.release()
try:
for rg_id, rg_size in self._row_groups:
self._row_group_admission_pending = (rg_id, rg_size)
await self._wait_for_row_group_admission_capacity(rg_size)
if self._early_shutdown or self._fatal_worker_error is not None:
all_admitted = False
break
await self._rg_semaphore.acquire()
if self._early_shutdown or self._fatal_worker_error is not None:
self._rg_semaphore.release()
all_admitted = False
break
self._rg_states[rg_id] = _RowGroupState(size=rg_size)

if self._buffer_manager is not None:
self._buffer_manager.init_row_group(rg_id, rg_size)

await self._dispatch_seeds(rg_id, rg_size)
self._emit_scheduler_event(
"row_group_admitted",
diagnostics=self._row_group_admission_diagnostics(reason="admitted")
| {"row_group": rg_id, "row_group_size": rg_size},
)
self._emit_scheduler_health_snapshot("row_group_admitted")
self._wake_event.set()
if not self._row_group_row_guard_allows(rg_size):
self._rg_semaphore.release()
await self._wait_for_row_group_admission_capacity(rg_size)
await self._rg_semaphore.acquire()
if self._early_shutdown or self._fatal_worker_error is not None:
self._rg_semaphore.release()
all_admitted = False
break
self._row_group_admission_pending = None
self._rg_states[rg_id] = _RowGroupState(size=rg_size)

if self._buffer_manager is not None:
self._buffer_manager.init_row_group(rg_id, rg_size)

await self._dispatch_seeds(rg_id, rg_size)
self._emit_scheduler_event(
"row_group_admitted",
diagnostics=self._row_group_admission_diagnostics(reason="admitted")
| {"row_group": rg_id, "row_group_size": rg_size},
)
self._emit_scheduler_health_snapshot("row_group_admitted")
self._wake_event.set()
finally:
self._row_group_admission_pending = None
self._all_rgs_admitted = all_admitted
self._wake_event.set()

Expand Down
Loading
Loading