4141from data_designer .engine .dataset_builders .scheduling .resolver import TaskSchedulingResolver
4242from data_designer .engine .dataset_builders .scheduling .resources import (
4343 SchedulableTask ,
44+ request_scheduler_resource_key ,
4445 stable_task_id ,
4546)
4647from data_designer .engine .dataset_builders .scheduling .task_admission import (
@@ -145,6 +146,15 @@ class _DispatchOutcome:
145146 admission_blocked : bool = False
146147
147148
149+ @dataclass (frozen = True )
150+ class _DeferredAdmissionAnalysis :
151+ """Deferred retry pressure as seen by adaptive row-group admission."""
152+
153+ blocks : bool
154+ candidate_columns : tuple [str , ...]
155+ independent_candidate_columns : tuple [str , ...]
156+
157+
148158class AsyncTaskScheduler :
149159 """Dependency-aware async task scheduler for the dataset builder.
150160
@@ -329,6 +339,12 @@ def __init__(
329339 self ._row_group_admission_pressure_ticks = 0
330340 self ._row_group_admission_blocked_reasons : Counter [str ] = Counter ()
331341 self ._adaptive_max_admitted_rows = self ._max_admitted_rows_guardrail ()
342+ self ._row_group_admission_pending : tuple [int , int ] | None = None
343+ self ._deferred_admission_cache : tuple [tuple [Task , ...], tuple [int , int ], _DeferredAdmissionAnalysis ] | None = (
344+ None
345+ )
346+ self ._transitive_upstream_cache : dict [str , frozenset [str ]] = {}
347+ self ._transitive_downstream_cache : dict [str , frozenset [str ]] = {}
332348 self ._request_pressure_provider = request_pressure_provider
333349 self ._request_pressure_advisory = request_pressure_advisory and request_pressure_provider is not None
334350 self ._request_pressure_advisory_skips = 0
@@ -944,13 +960,14 @@ def _maybe_update_adaptive_row_group_target(self) -> None:
944960 self ._row_group_admission_event .set ()
945961
946962 def _adaptive_row_group_block_reason (self ) -> str | None :
947- if self ._deferred :
948- return "deferred_tasks"
949- next_size = self ._next_unadmitted_row_group_size ()
950- if next_size is None :
963+ next_row_group = self ._next_unadmitted_row_group ()
964+ if next_row_group is None :
951965 return "no_pending_row_groups"
966+ next_rg_id , next_size = next_row_group
952967 if not self ._row_group_row_guard_allows (next_size ):
953968 return "max_admitted_rows"
969+ if self ._deferred and self ._deferred_admission_analysis (next_rg_id , next_size ).blocks :
970+ return "deferred_tasks"
954971 queue_view = self ._fair_queue .view ()
955972 queue_guard = self ._max_in_flight_tasks * 4
956973 if queue_view .queued_total >= queue_guard :
@@ -968,13 +985,169 @@ def _adaptive_row_group_block_reason(self) -> str | None:
968985 return "queued_llm_demand"
969986 return None
970987
971- def _next_unadmitted_row_group_size (self ) -> int | None :
972- for rg_id , rg_size in self ._row_groups :
973- if rg_id not in self ._rg_states and not self ._tracker .is_row_group_complete (
974- rg_id , rg_size , self ._graph .columns
975- ):
976- return rg_size
977- return None
988+ def _next_unadmitted_row_group (self ) -> tuple [int , int ] | None :
989+ pending = self ._row_group_admission_pending
990+ if pending is None :
991+ return None
992+ rg_id , rg_size = pending
993+ if rg_id in self ._rg_states or self ._tracker .is_row_group_complete (rg_id , rg_size , self ._graph .columns ):
994+ return None
995+ return pending
996+
def _deferred_admission_analysis(self, row_group: int, row_group_size: int) -> _DeferredAdmissionAnalysis:
    """Decide whether the deferred tasks should hold back admitting ``row_group``.

    The result is memoized in ``self._deferred_admission_cache``, keyed by the
    current deferred tasks and the ``(row_group, row_group_size)`` pair, so
    repeated calls with unchanged inputs reuse the stored analysis.

    Args:
        row_group: Identifier of the row group that would be admitted next.
        row_group_size: Number of rows in that row group.

    Returns:
        An analysis whose ``blocks`` flag is true only when deferred tasks
        exist and no candidate column is independent of them.
    """
    deferred_snapshot = tuple(self._deferred)
    pending_pair = (row_group, row_group_size)
    memo = self._deferred_admission_cache
    if memo is not None and memo[:2] == (deferred_snapshot, pending_pair):
        return memo[2]

    deferred_items = tuple(self._schedulable_task(task) for task in deferred_snapshot)
    deferred_keys: set[str] = set()
    for deferred_item in deferred_items:
        deferred_keys.update(self._localized_deferred_admission_keys(deferred_item))

    candidates = tuple(
        (candidate, self._localized_deferred_admission_keys(candidate))
        for candidate in self._row_group_admission_candidate_tasks(row_group, row_group_size)
    )

    # Blocked columns: outputs of deferred tasks plus outputs of any candidate
    # that shares an admission key with a deferred task.
    blocked_columns: set[str] = set()
    for deferred_item in deferred_items:
        blocked_columns.update(self._task_output_columns(deferred_item.payload))
    for candidate, keys in candidates:
        if keys & deferred_keys:
            blocked_columns.update(self._task_output_columns(candidate.payload))

    independent: list[str] = []
    for candidate, keys in candidates:
        if keys & deferred_keys:
            continue
        payload = candidate.payload
        if self._task_depends_on_any(payload, blocked_columns):
            continue
        # Candidates that are not resource-scoped must also not feed into a
        # blocked column to count as independent.
        if not self._is_resource_scoped_admission_candidate(candidate) and self._task_reaches_any(
            payload, blocked_columns
        ):
            continue
        independent.append(payload.column)
    independent_candidates = tuple(independent)

    blocks = bool(deferred_items) and not independent_candidates
    analysis = _DeferredAdmissionAnalysis(
        blocks=blocks,
        candidate_columns=tuple(candidate.payload.column for candidate, _keys in candidates),
        independent_candidate_columns=independent_candidates,
    )
    self._deferred_admission_cache = (deferred_snapshot, pending_pair, analysis)
    return analysis
1031+
def _row_group_admission_candidate_tasks(
    self,
    row_group: int,
    row_group_size: int,
) -> tuple[SchedulableTask, ...]:
    """Build one representative schedulable task per distinct generator.

    Columns are visited in the graph's topological order; a generator object
    shared by several columns contributes a task only for the first column
    encountered. Cell-by-cell columns are represented by row index 0 and are
    omitted when the row group has no rows; other columns get a
    ``"from_scratch"`` task when listed in ``self._seed_cols`` and a
    ``"batch"`` task otherwise.

    Args:
        row_group: Row group identifier to stamp on each representative task.
        row_group_size: Number of rows in that row group.

    Returns:
        The representative tasks, each wrapped via ``self._schedulable_task``.
    """
    probes: list[SchedulableTask] = []
    visited_generators: set[int] = set()
    for column in self._graph.get_topological_order():
        generator_key = id(self._generators[column])
        if generator_key in visited_generators:
            continue
        visited_generators.add(generator_key)
        if self._graph.get_strategy(column) == GenerationStrategy.CELL_BY_CELL:
            if row_group_size <= 0:
                continue
            row_index, task_type = 0, "cell"
        else:
            row_index = None
            task_type = "from_scratch" if column in self._seed_cols else "batch"
        probe = Task(column=column, row_group=row_group, row_index=row_index, task_type=task_type)
        probes.append(self._schedulable_task(probe))
    return tuple(probes)
1055+
1056+ def _localized_deferred_admission_keys (self , item : SchedulableTask ) -> set [str ]:
1057+ if item .request_resource_key is not None :
1058+ resource = item .request_resource_key
1059+ return {
1060+ f"request_resource:{ _request_resource_label (resource )} " ,
1061+ f"scheduler_resource:{ request_scheduler_resource_key (resource )} " ,
1062+ }
1063+ identity = "/" .join (item .group .key .identity )
1064+ return {f"group:{ item .group .key .kind } :{ identity } " }
1065+
1066+ @staticmethod
1067+ def _is_localized_admission_resource (resource : str ) -> bool :
1068+ return resource .startswith ("request:" )
1069+
1070+ def _is_resource_scoped_admission_candidate (self , item : SchedulableTask ) -> bool :
1071+ return item .request_resource_key is not None or item .group .key .kind != "local"
1072+
1073+ def _task_output_columns (self , task : Task ) -> tuple [str , ...]:
1074+ return self ._task_flow_identity (task ) or (task .column ,)
1075+
1076+ def _task_depends_on_any (self , task : Task , blocked_columns : set [str ]) -> bool :
1077+ return any (self ._column_depends_on_any (column , blocked_columns ) for column in self ._task_output_columns (task ))
1078+
1079+ def _task_reaches_any (self , task : Task , blocked_columns : set [str ]) -> bool :
1080+ return any (self ._column_reaches_any (column , blocked_columns ) for column in self ._task_output_columns (task ))
1081+
1082+ def _column_depends_on_any (self , column : str , blocked_columns : set [str ]) -> bool :
1083+ return bool (self ._transitive_upstream_columns (column ) & blocked_columns )
1084+
1085+ def _column_reaches_any (self , column : str , blocked_columns : set [str ]) -> bool :
1086+ return bool (self ._transitive_downstream_columns (column ) & blocked_columns )
1087+
1088+ def _transitive_upstream_columns (self , column : str ) -> frozenset [str ]:
1089+ cached = self ._transitive_upstream_cache .get (column )
1090+ if cached is not None :
1091+ return cached
1092+ result = self ._walk_graph (column , upstream = True )
1093+ self ._transitive_upstream_cache [column ] = result
1094+ return result
1095+
1096+ def _transitive_downstream_columns (self , column : str ) -> frozenset [str ]:
1097+ cached = self ._transitive_downstream_cache .get (column )
1098+ if cached is not None :
1099+ return cached
1100+ result = self ._walk_graph (column , upstream = False )
1101+ self ._transitive_downstream_cache [column ] = result
1102+ return result
1103+
1104+ def _walk_graph (self , column : str , * , upstream : bool ) -> frozenset [str ]:
1105+ next_columns = self ._graph .get_upstream_columns if upstream else self ._graph .get_downstream_columns
1106+ to_visit = list (next_columns (column ))
1107+ seen : set [str ] = set ()
1108+ while to_visit :
1109+ next_column = to_visit .pop ()
1110+ if next_column in seen :
1111+ continue
1112+ seen .add (next_column )
1113+ to_visit .extend (next_columns (next_column ))
1114+ return frozenset (seen )
1115+
1116+ def _deferred_admission_diagnostics (self ) -> dict [str , object ]:
1117+ deferred_items = tuple (self ._schedulable_task (task ) for task in self ._deferred )
1118+ diagnostics : dict [str , object ] = {
1119+ "count" : len (self ._deferred ),
1120+ "scope" : "localized" if self ._deferred else "none" ,
1121+ "blocks_next_row_group" : False ,
1122+ "columns" : dict (Counter (task .column for task in self ._deferred )),
1123+ "request_resources" : {},
1124+ "scheduler_resources" : {},
1125+ "candidate_columns" : (),
1126+ "independent_candidate_columns" : (),
1127+ }
1128+ if not self ._deferred :
1129+ return diagnostics
1130+ request_resource_counts = Counter (
1131+ label
1132+ for item in deferred_items
1133+ if (label := _request_resource_label (item .request_resource_key )) is not None
1134+ )
1135+ scheduler_resource_counts = Counter (
1136+ resource
1137+ for item in deferred_items
1138+ for resource in item .resource_request .amounts
1139+ if self ._is_localized_admission_resource (resource )
1140+ )
1141+ diagnostics ["request_resources" ] = dict (request_resource_counts )
1142+ diagnostics ["scheduler_resources" ] = dict (scheduler_resource_counts )
1143+ next_row_group = self ._next_unadmitted_row_group ()
1144+ if next_row_group is None :
1145+ return diagnostics
1146+ analysis = self ._deferred_admission_analysis (* next_row_group )
1147+ diagnostics ["blocks_next_row_group" ] = analysis .blocks
1148+ diagnostics ["candidate_columns" ] = analysis .candidate_columns
1149+ diagnostics ["independent_candidate_columns" ] = analysis .independent_candidate_columns
1150+ return diagnostics
9781151
9791152 def _row_group_admission_diagnostics (self , * , reason : str ) -> dict [str , object ]:
9801153 queue_view = self ._fair_queue .view ()
@@ -999,42 +1172,48 @@ def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]:
9991172 "llm_wait_leased" : task_view .leased_resources .get ("llm_wait" , 0 ),
10001173 "llm_wait_available" : task_view .resources_available .get ("llm_wait" , 0 ),
10011174 "blocked_reasons" : dict (self ._row_group_admission_blocked_reasons ),
1175+ "deferred_admission" : self ._deferred_admission_diagnostics (),
10021176 }
10031177
10041178 async def _admit_row_groups (self ) -> None :
10051179 """Admit row groups as semaphore slots become available."""
10061180 all_admitted = True
1007- for rg_id , rg_size in self ._row_groups :
1008- await self ._wait_for_row_group_admission_capacity (rg_size )
1009- if self ._early_shutdown or self ._fatal_worker_error is not None :
1010- all_admitted = False
1011- break
1012- await self ._rg_semaphore .acquire ()
1013- if self ._early_shutdown or self ._fatal_worker_error is not None :
1014- self ._rg_semaphore .release ()
1015- all_admitted = False
1016- break
1017- if not self ._row_group_row_guard_allows (rg_size ):
1018- self ._rg_semaphore .release ()
1181+ try :
1182+ for rg_id , rg_size in self ._row_groups :
1183+ self ._row_group_admission_pending = (rg_id , rg_size )
10191184 await self ._wait_for_row_group_admission_capacity (rg_size )
1185+ if self ._early_shutdown or self ._fatal_worker_error is not None :
1186+ all_admitted = False
1187+ break
10201188 await self ._rg_semaphore .acquire ()
10211189 if self ._early_shutdown or self ._fatal_worker_error is not None :
10221190 self ._rg_semaphore .release ()
10231191 all_admitted = False
10241192 break
1025- self ._rg_states [rg_id ] = _RowGroupState (size = rg_size )
1026-
1027- if self ._buffer_manager is not None :
1028- self ._buffer_manager .init_row_group (rg_id , rg_size )
1029-
1030- await self ._dispatch_seeds (rg_id , rg_size )
1031- self ._emit_scheduler_event (
1032- "row_group_admitted" ,
1033- diagnostics = self ._row_group_admission_diagnostics (reason = "admitted" )
1034- | {"row_group" : rg_id , "row_group_size" : rg_size },
1035- )
1036- self ._emit_scheduler_health_snapshot ("row_group_admitted" )
1037- self ._wake_event .set ()
1193+ if not self ._row_group_row_guard_allows (rg_size ):
1194+ self ._rg_semaphore .release ()
1195+ await self ._wait_for_row_group_admission_capacity (rg_size )
1196+ await self ._rg_semaphore .acquire ()
1197+ if self ._early_shutdown or self ._fatal_worker_error is not None :
1198+ self ._rg_semaphore .release ()
1199+ all_admitted = False
1200+ break
1201+ self ._row_group_admission_pending = None
1202+ self ._rg_states [rg_id ] = _RowGroupState (size = rg_size )
1203+
1204+ if self ._buffer_manager is not None :
1205+ self ._buffer_manager .init_row_group (rg_id , rg_size )
1206+
1207+ await self ._dispatch_seeds (rg_id , rg_size )
1208+ self ._emit_scheduler_event (
1209+ "row_group_admitted" ,
1210+ diagnostics = self ._row_group_admission_diagnostics (reason = "admitted" )
1211+ | {"row_group" : rg_id , "row_group_size" : rg_size },
1212+ )
1213+ self ._emit_scheduler_health_snapshot ("row_group_admitted" )
1214+ self ._wake_event .set ()
1215+ finally :
1216+ self ._row_group_admission_pending = None
10381217 self ._all_rgs_admitted = all_admitted
10391218 self ._wake_event .set ()
10401219
0 commit comments