Skip to content

Commit 0fdea84

Browse files
authored
feat: add fair async task scheduling (#639)
1 parent d14c9b3 commit 0fdea84

11 files changed

Lines changed: 1397 additions & 89 deletions

File tree

architecture/dataset-builders.md

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ Preparation (`_prepare_async_run`):
3535
4. Constructs `CompletionTracker`, `RowGroupBufferManager`, `AsyncTaskScheduler`
3636
5. Hooks `ProcessorRunner` for pre-batch and post-batch stages
3737

38-
`AsyncTaskScheduler` runs on a dedicated async loop with semaphore-based concurrency, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially.
38+
`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, semaphore-based capacity limits, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks are admitted through a virtual-time fair queue so one hot column or model-backed generator cannot consume the whole submission window before peer work gets a turn.
3939

4040
### Execution Graph
4141

@@ -123,7 +123,7 @@ DatasetBuilder.build()
123123
→ CompletionTracker.with_graph()
124124
→ AsyncTaskScheduler(semaphores, salvage_rounds)
125125
→ scheduler.run()
126-
→ for each row group, dispatch ready tasks from frontier
126+
→ for each row group, fairly admit ready tasks from frontier
127127
→ tasks execute generators, update CompletionTracker
128128
→ checkpoints via RowGroupBufferManager
129129
→ collect TaskTraces, emit telemetry
@@ -133,6 +133,7 @@ DatasetBuilder.build()
133133

134134
- **Dual execution engines behind one API.** The sequential engine is simpler and easier to debug; the async engine adds row-group parallelism for throughput. Users switch via an environment variable without changing their code.
135135
- **DAG-driven ordering** ensures columns with dependencies (e.g., a judge column that depends on a text column) are generated in the correct order, regardless of the order they appear in the config.
136+
- **Fair async admission** keeps the scheduler flowing across ready columns and model groups. Global semaphores still bound memory/coroutine growth, while per-group virtual-time queues prevent a large ready frontier from degenerating into a column-by-column wave. LLM admission caps are peer-sensitive: a solo model group can fill available global capacity, but once another scheduling group has queued work, the saturated group yields until peers get admission slots or admitted tasks complete.
136137
- **Salvage rounds in async mode** retry failed tasks after all other tasks in a round complete, improving resilience against transient LLM failures without blocking the entire generation.
137138
- **Unified DAG construction.** `topologically_sort_column_configs` (in `execution_graph.py`) determines column ordering using Kahn's algorithm; the runtime `ExecutionGraph` adds strategy-aware dependency tracking for the async scheduler.
138139

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 167 additions & 55 deletions
Large diffs are not rendered by default.

packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -85,14 +85,14 @@
8585

8686
from data_designer.engine.dataset_builders.async_scheduler import (
8787
DEFAULT_TASK_POOL_SIZE,
88-
LLM_WAIT_POOL_MULTIPLIER,
88+
GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER,
8989
AsyncTaskScheduler,
9090
)
9191
from data_designer.engine.dataset_builders.utils.async_concurrency import (
9292
AsyncConcurrentExecutor,
9393
ensure_async_engine_loop,
9494
)
95-
from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker
95+
from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta
9696
from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager
9797

9898

@@ -996,13 +996,18 @@ def _prepare_async_run(
996996

997997
# Pre-batch processor callback: runs after seed tasks complete for a row group.
998998
# If it raises, the scheduler propagates the error as DatasetGenerationError (fail-fast).
999-
def on_seeds_complete(rg_id: int, rg_size: int) -> None:
999+
def on_seeds_complete(rg_id: int, rg_size: int) -> FrontierDelta:
10001000
df = buffer_manager.get_dataframe(rg_id)
10011001
df = self._processor_runner.run_pre_batch_on_df(df, strict_row_count=True)
10021002
buffer_manager.replace_dataframe(rg_id, df)
1003+
deltas: list[FrontierDelta] = []
10031004
for ri in range(rg_size):
10041005
if buffer_manager.is_dropped(rg_id, ri) and not tracker.is_dropped(rg_id, ri):
1005-
tracker.drop_row(rg_id, ri)
1006+
deltas.append(tracker.drop_row(rg_id, ri))
1007+
return FrontierDelta(
1008+
added=tuple(task for delta in deltas for task in delta.added),
1009+
removed=tuple(task for delta in deltas for task in delta.removed),
1010+
)
10061011

10071012
# Post-batch processor callback: runs after all columns, before finalization.
10081013
def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
@@ -1022,7 +1027,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
10221027
row_groups=row_groups,
10231028
buffer_manager=buffer_manager,
10241029
max_submitted_tasks=DEFAULT_TASK_POOL_SIZE,
1025-
max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, LLM_WAIT_POOL_MULTIPLIER * aggregate),
1030+
max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER * aggregate),
10261031
on_finalize_row_group=on_finalize_row_group,
10271032
on_seeds_complete=(
10281033
on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None

packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py

Lines changed: 70 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
from collections import defaultdict
7+
from dataclasses import dataclass
78
from typing import TYPE_CHECKING
89

910
from data_designer.config.column_configs import GenerationStrategy
@@ -13,6 +14,18 @@
1314
from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph
1415

1516

17+
@dataclass(frozen=True)
18+
class FrontierDelta:
19+
"""Tasks added to or removed from the ready frontier by a tracker mutation."""
20+
21+
added: tuple[Task, ...] = ()
22+
removed: tuple[Task, ...] = ()
23+
24+
@property
25+
def empty(self) -> bool:
26+
return not self.added and not self.removed
27+
28+
1629
class CompletionTracker:
1730
"""Tracks which cells (column, row_group, row_index) are done.
1831
@@ -42,24 +55,34 @@ def with_graph(cls, graph: ExecutionGraph, row_groups: list[tuple[int, int]]) ->
4255
tracker._row_group_sizes = dict(row_groups)
4356
return tracker
4457

45-
def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> None:
58+
def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> FrontierDelta:
4659
self._validate_row_group(row_group)
4760
self._validate_strategy(column, GenerationStrategy.CELL_BY_CELL, "mark_cell_complete")
4861
self._completed[row_group][column].add(row_index)
62+
removed: list[Task] = []
63+
added: list[Task] = []
4964
if self._graph is not None:
50-
self._frontier.discard(Task(column=column, row_group=row_group, row_index=row_index, task_type="cell"))
51-
self._enqueue_downstream(column, row_group, row_index=row_index)
65+
task = Task(column=column, row_group=row_group, row_index=row_index, task_type="cell")
66+
if self._discard_frontier_task(task):
67+
removed.append(task)
68+
added.extend(self._enqueue_downstream(column, row_group, row_index=row_index))
69+
return self._record_delta(added=added, removed=removed)
5270

53-
def mark_row_range_complete(self, column: str, row_group: int, row_group_size: int) -> None:
71+
def mark_row_range_complete(self, column: str, row_group: int, row_group_size: int) -> FrontierDelta:
5472
expected = self._validate_row_group(row_group)
5573
self._validate_strategy(column, GenerationStrategy.FULL_COLUMN, "mark_row_range_complete")
5674
if expected is not None and row_group_size != expected:
5775
raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}")
5876
self._completed[row_group][column] = set(range(row_group_size))
5977
self._batch_complete[row_group].add(column)
78+
removed: list[Task] = []
79+
added: list[Task] = []
6080
if self._graph is not None:
61-
self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch"))
62-
self._enqueue_downstream(column, row_group, row_index=None)
81+
task = Task(column=column, row_group=row_group, row_index=None, task_type="batch")
82+
if self._discard_frontier_task(task):
83+
removed.append(task)
84+
added.extend(self._enqueue_downstream(column, row_group, row_index=None))
85+
return self._record_delta(added=added, removed=removed)
6386

6487
def is_complete(self, ref: SliceRef) -> bool:
6588
return ref.row_index in self._completed.get(ref.row_group, {}).get(ref.column, set())
@@ -89,15 +112,20 @@ def is_column_complete_for_rg(self, column: str, row_group_index: int) -> bool:
89112
dropped = self._dropped.get(row_group_index, set())
90113
return all(ri in completed or ri in dropped for ri in range(rg_size))
91114

92-
def drop_row(self, row_group: int, row_index: int) -> None:
115+
def drop_row(self, row_group: int, row_index: int) -> FrontierDelta:
93116
self._validate_row_group(row_group)
94117
self._dropped[row_group].add(row_index)
118+
removed: list[Task] = []
119+
added: list[Task] = []
95120
if self._graph is not None:
96121
# Remove cell tasks for this row from the frontier
97122
for col in self._graph.columns:
98-
self._frontier.discard(Task(column=col, row_group=row_group, row_index=row_index, task_type="cell"))
123+
task = Task(column=col, row_group=row_group, row_index=row_index, task_type="cell")
124+
if self._discard_frontier_task(task):
125+
removed.append(task)
99126
# Dropping a row may unblock batch downstream tasks
100-
self._reevaluate_batch_tasks(row_group)
127+
added.extend(self._reevaluate_batch_tasks(row_group))
128+
return self._record_delta(added=added, removed=removed)
101129

102130
def is_dropped(self, row_group: int, row_index: int) -> bool:
103131
return row_index in self._dropped.get(row_group, set())
@@ -129,6 +157,10 @@ def get_ready_tasks(self, dispatched: set[Task], admitted_rgs: set[int] | None =
129157
t for t in self._frontier if t not in dispatched and (admitted_rgs is None or t.row_group in admitted_rgs)
130158
]
131159

160+
def is_frontier_task(self, task: Task) -> bool:
161+
"""Return whether *task* is still in the ready frontier."""
162+
return task in self._frontier
163+
132164
def seed_frontier(self) -> None:
133165
"""Populate the frontier with root tasks (columns with no upstream deps).
134166
@@ -147,10 +179,26 @@ def seed_frontier(self) -> None:
147179
else:
148180
self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch"))
149181

150-
def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> None:
182+
def _record_delta(self, *, added: list[Task], removed: list[Task]) -> FrontierDelta:
183+
return FrontierDelta(added=tuple(added), removed=tuple(removed))
184+
185+
def _add_frontier_task(self, task: Task) -> bool:
186+
if task in self._frontier:
187+
return False
188+
self._frontier.add(task)
189+
return True
190+
191+
def _discard_frontier_task(self, task: Task) -> bool:
192+
if task not in self._frontier:
193+
return False
194+
self._frontier.remove(task)
195+
return True
196+
197+
def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> list[Task]:
151198
"""Add newly-ready downstream tasks to the frontier."""
152199
if self._graph is None:
153200
raise RuntimeError("This method requires a graph to be set.")
201+
added: list[Task] = []
154202
rg_completed = self._completed.get(row_group, {})
155203
rg_dropped = self._dropped.get(row_group, set())
156204
rg_batch_complete = self._batch_complete.get(row_group, set())
@@ -175,7 +223,8 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None
175223
and all(row_index in s for s in cell_up_completed)
176224
):
177225
task = Task(column=down, row_group=row_group, row_index=row_index, task_type="cell")
178-
self._frontier.add(task)
226+
if self._add_frontier_task(task):
227+
added.append(task)
179228
else:
180229
# Batch completion: check all non-dropped, non-complete rows
181230
down_completed = rg_completed.get(down, set())
@@ -184,19 +233,23 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None
184233
continue
185234
if all(ri in s for s in cell_up_completed):
186235
task = Task(column=down, row_group=row_group, row_index=ri, task_type="cell")
187-
self._frontier.add(task)
236+
if self._add_frontier_task(task):
237+
added.append(task)
188238
else:
189239
# FULL_COLUMN downstream: ready when all cell upstreams are fully complete
190240
if down not in rg_batch_complete and self._are_cell_ups_complete(
191241
cell_ups, rg_completed, rg_size, rg_dropped
192242
):
193243
task = Task(column=down, row_group=row_group, row_index=None, task_type="batch")
194-
self._frontier.add(task)
244+
if self._add_frontier_task(task):
245+
added.append(task)
246+
return added
195247

196-
def _reevaluate_batch_tasks(self, row_group: int) -> None:
248+
def _reevaluate_batch_tasks(self, row_group: int) -> list[Task]:
197249
"""Check if any batch tasks became ready after a row was dropped."""
198250
if self._graph is None:
199251
raise RuntimeError("This method requires a graph to be set.")
252+
added: list[Task] = []
200253
rg_completed = self._completed.get(row_group, {})
201254
rg_dropped = self._dropped.get(row_group, set())
202255
rg_batch_complete = self._batch_complete.get(row_group, set())
@@ -212,7 +265,9 @@ def _reevaluate_batch_tasks(self, row_group: int) -> None:
212265
continue
213266
if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped):
214267
task = Task(column=col, row_group=row_group, row_index=None, task_type="batch")
215-
self._frontier.add(task)
268+
if self._add_frontier_task(task):
269+
added.append(task)
270+
return added
216271

217272
def _are_cell_ups_complete(
218273
self,

0 commit comments

Comments
 (0)