Skip to content

Commit 968792a

Browse files
Clean Qwen model input segmentation controls
Use model_input_window.max_duration_s as the single model-input safety limit, keep old chunking/planner names as rejection guards only, and refresh Qwen raw docs/tests for the current payload lifecycle path. Signed-off-by: Aaftab V <aaftabv@nvidia.com>
1 parent eb01540 commit 968792a

13 files changed

Lines changed: 1080 additions & 305 deletions

File tree

PR1967_FEATURE_WALKTHROUGH.md

Lines changed: 550 additions & 95 deletions
Large diffs are not rendered by default.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Qwen-Omni In-Process ASR Assets
2+
3+
This folder contains prompt templates used by the Qwen-Omni in-process ASR
4+
adapter.
5+
6+
The executable code path is:
7+
8+
```text
9+
Pipeline
10+
-> ManifestReader
11+
-> AudioPayloadMaterializeStage
12+
-> ASRStage(adapter_target=QwenOmniASRAdapter)
13+
-> PayloadReleaseStage
14+
-> ManifestWriterStage
15+
```
16+
17+
The adapter reads prompt text through `prompt_file`, `en_prompt_file`,
18+
`followup_prompt_file`, or `system_prompt_file`. Curator stage behavior remains
19+
outside the prompt files:
20+
21+
- graph expansion lives in `nemo_curator/pipeline/payload_lifecycle.py`;
22+
- audio decode and payload refs live in `nemo_curator/stages/payload_lifecycle.py`;
23+
- local/windowed ASR model-input segmentation and batching live in
24+
`nemo_curator/stages/audio/inference/asr/stage.py`;
25+
- Qwen/vLLM request construction lives in `nemo_curator/models/asr/qwen_omni.py`.
26+
27+
Prompt files may use `{language}` and `{transcript}` placeholders when the
28+
stage supplies language or reference text columns.

nemo_curator/pipeline/payload_lifecycle.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,23 +82,36 @@ def expand_payload_lifecycle_stages(
8282

8383
reader = _last_manifest_reader(stages[: materialize_idx + 1])
8484
payload_specs = _payload_binding_specs(payload_cfg, stages=stages, consumers=consumers, reader=reader)
85+
_configure_planned_source_segment_inputs(reader, payload_cfg, payload_specs, config)
8586
_validate_payload_consumers(consumers, payload_specs)
87+
_validate_single_segment_planner_owner(
88+
reader,
89+
consumers,
90+
config=config,
91+
)
8692

87-
materializers = [_build_payload_materializer(reader, spec, payload_cfg, config, run_id=run_id) for spec in payload_specs]
93+
materializers = [
94+
_build_payload_materializer(reader, spec, payload_cfg, config, run_id=run_id)
95+
for spec in payload_specs
96+
]
8897
primary_spec = payload_specs[0]
8998
release = payload_release_stage_cls(
9099
name=str(payload_cfg.get("release_stage_name", "payload_release")),
91100
payload_ref_key=primary_spec.ref_key,
92101
waveform_key=primary_spec.waveform_key,
93102
)
94103

104+
assembler = _post_release_payload_lifecycle_stage(config, reader, consumers, primary_spec, run_id=run_id)
105+
95106
expanded: list[ProcessingStage] = []
96107
for idx, stage in enumerate(stages):
97108
expanded.append(stage)
98109
if idx == materialize_idx:
99110
expanded.extend(materializers)
100111
if idx == release_idx:
101112
expanded.append(release)
113+
if assembler is not None:
114+
expanded.append(assembler)
102115
logger.info("Expanded logical graph into payload lifecycle execution graph: {}", " -> ".join(stage.name for stage in expanded))
103116
return expanded
104117

@@ -249,6 +262,108 @@ def _stage_payload_bindings(stage: ProcessingStage) -> list[dict[str, str]]:
249262
return []
250263

251264

265+
def _configure_planned_source_segment_inputs(
266+
reader: ProcessingStage | None,
267+
payload_cfg: dict[str, Any],
268+
payload_specs: list[PayloadBindingSpec],
269+
config: Any,
270+
) -> None:
271+
if reader is None or not bool(getattr(reader, "enable_global_bucketing", False)):
272+
return
273+
scheduler_cfg = _config_section(config, "global_audio_scheduler")
274+
configured = scheduler_cfg.get("segment_input_keys", payload_cfg.get("segment_input_keys"))
275+
segment_input_keys: list[str] = []
276+
if configured is not None:
277+
segment_input_keys.extend(_normalise_string_list(configured, key="global_audio_scheduler.segment_input_keys"))
278+
segment_input_keys.extend(spec.source_key for spec in payload_specs)
279+
setattr(reader, "segment_input_keys", _dedupe_strings(segment_input_keys))
280+
setattr(reader, "run_id", _pipeline_run_id(config))
281+
if "parent_store_actor_name_prefix" in scheduler_cfg:
282+
setattr(reader, "parent_store_actor_name_prefix", str(scheduler_cfg["parent_store_actor_name_prefix"]))
283+
284+
285+
def _validate_single_segment_planner_owner(
286+
reader: ProcessingStage | None,
287+
consumers: list[ProcessingStage],
288+
*,
289+
config: Any,
290+
) -> None:
291+
if reader is None or not bool(getattr(reader, "enable_global_bucketing", False)):
292+
return
293+
owner_stage = _single_selector(getattr(reader, "owner_stage", None), key="global_audio_scheduler.owner_stage")
294+
matching_consumers = [stage for stage in consumers if owner_stage in _stage_match_idents(stage)]
295+
if not matching_consumers:
296+
available = sorted({ident for stage in consumers for ident in _stage_match_idents(stage)})
297+
msg = (
298+
"global_audio_scheduler.owner_stage must select exactly one stage listed in "
299+
"payload_lifecycle.consumers. Global bucketing has a single planning owner; "
300+
f"{owner_stage!r} was not found in payload consumers {available}."
301+
)
302+
raise ValueError(msg)
303+
if len(matching_consumers) > 1:
304+
names = [stage.name for stage in matching_consumers]
305+
msg = f"global_audio_scheduler.owner_stage must select exactly one payload consumer; matched {names}"
306+
raise ValueError(msg)
307+
_validate_planner_owner_has_largest_model_window(reader=reader, owner=matching_consumers[0], consumers=consumers)
308+
setattr(reader, "owner_stage", owner_stage)
309+
310+
311+
def _validate_planner_owner_has_largest_model_window(
312+
*,
313+
reader: ProcessingStage,
314+
owner: ProcessingStage,
315+
consumers: list[ProcessingStage],
316+
) -> None:
317+
owner_max_s = _required_positive_seconds(owner, "max_inference_duration_s")
318+
consumer_max_s = [(stage.name, _required_positive_seconds(stage, "max_inference_duration_s")) for stage in consumers]
319+
larger_consumers = [(name, max_s) for name, max_s in consumer_max_s if max_s > owner_max_s]
320+
if larger_consumers:
321+
details = ", ".join(f"{name}={value:g}s" for name, value in larger_consumers)
322+
msg = (
323+
"global_audio_scheduler.owner_stage must select the payload consumer with the largest "
324+
"max_inference_duration_s because the source planner emits one segment plan. "
325+
f"Selected owner {owner.name!r} has max_inference_duration_s={owner_max_s:g}s, "
326+
f"but larger consumer(s) exist: {details}."
327+
)
328+
raise ValueError(msg)
329+
330+
reader_max_s = _required_positive_seconds(reader, "max_inference_duration_s")
331+
if abs(reader_max_s - owner_max_s) > 1e-6:
332+
msg = (
333+
"ManifestReader(enable_global_bucketing=True).max_inference_duration_s must match the "
334+
"selected owner stage's max_inference_duration_s. "
335+
f"Reader has {reader_max_s:g}s, owner {owner.name!r} has {owner_max_s:g}s."
336+
)
337+
raise ValueError(msg)
338+
339+
340+
def _required_positive_seconds(stage: ProcessingStage, attr: str) -> float:
341+
value = getattr(stage, attr, None)
342+
if value is None:
343+
msg = f"Global bucketing requires stage {stage.name!r} to define positive {attr}"
344+
raise ValueError(msg)
345+
return _positive_seconds(value, label=f"{stage.name}.{attr}")
346+
347+
348+
def _optional_positive_seconds(stage: ProcessingStage, attr: str) -> float | None:
349+
value = getattr(stage, attr, None)
350+
if value is None:
351+
return None
352+
return _positive_seconds(value, label=f"{stage.name}.{attr}")
353+
354+
355+
def _positive_seconds(value: Any, *, label: str) -> float:
356+
try:
357+
seconds = float(value)
358+
except (TypeError, ValueError) as exc:
359+
msg = f"{label} must be a positive number of seconds, got {value!r}"
360+
raise TypeError(msg) from exc
361+
if seconds <= 0:
362+
msg = f"{label} must be > 0 seconds, got {seconds:g}"
363+
raise ValueError(msg)
364+
return seconds
365+
366+
252367
def _build_payload_materializer(
253368
reader: ProcessingStage | None,
254369
spec: PayloadBindingSpec,
@@ -274,6 +389,33 @@ def _build_payload_materializer(
274389
)
275390

276391

392+
def _post_release_payload_lifecycle_stage(
393+
config: Any,
394+
reader: ProcessingStage | None,
395+
consumers: list[ProcessingStage],
396+
primary_spec: PayloadBindingSpec,
397+
*,
398+
run_id: str,
399+
) -> ProcessingStage | None:
400+
if reader is None or not bool(getattr(reader, "enable_global_bucketing", False)):
401+
return None
402+
builder = getattr(reader, "build_payload_lifecycle_post_release_stage", None)
403+
if not callable(builder):
404+
msg = (
405+
"Global bucketing is enabled, but the source/reader stage does not provide "
406+
"build_payload_lifecycle_post_release_stage(). The central payload lifecycle "
407+
"planner only owns generic insertion order; modality-specific assembly must be "
408+
f"provided by the planner stage, got {type(reader).__name__}."
409+
)
410+
raise ValueError(msg)
411+
return builder(
412+
pipeline_config=config,
413+
consumers=consumers,
414+
primary_payload_spec=primary_spec,
415+
run_id=run_id,
416+
)
417+
418+
277419
def _pipeline_run_id(config: Any) -> str:
278420
value = _config_get(config, "_curator_pipeline_run_id")
279421
text = str(value or "").strip()
@@ -352,6 +494,20 @@ def _normalise_string_list(value: Any, *, key: str) -> list[str]:
352494
return result
353495

354496

497+
def _dedupe_strings(values: list[str]) -> list[str]:
498+
result: list[str] = []
499+
seen: set[str] = set()
500+
for value in values:
501+
text = str(value).strip()
502+
if text and text not in seen:
503+
seen.add(text)
504+
result.append(text)
505+
if not result:
506+
msg = "At least one non-empty string is required"
507+
raise ValueError(msg)
508+
return result
509+
510+
355511
def _single_selector(value: Any, *, key: str) -> str:
356512
values = _normalise_string_list(value, key=key)
357513
if len(values) != 1:

nemo_curator/pipeline/pipeline.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,20 +136,26 @@ def build(self) -> None:
136136
self._built = True
137137
self._planned_stage_snapshot = list(self.stages)
138138

139+
def _expand_pipeline_graph(self, stages: list[ProcessingStage]) -> list[ProcessingStage]:
140+
"""Apply generic pipeline-level graph expansion rules."""
141+
from nemo_curator.pipeline.payload_lifecycle import expand_payload_lifecycle_stages
142+
143+
return expand_payload_lifecycle_stages(stages, self.config)
144+
139145
def _sync_public_stage_mutations(self) -> None:
140-
"""Respect direct ``pipeline.stages`` edits made through the public API.
146+
"""Preserve the historical public ``stages`` list mutation behavior.
141147
142-
The logical/execution split keeps graph expansion idempotent, but
143-
``stages`` is still a public list in Curator. If user code mutates that
144-
list directly, treat it as the new logical graph before planning.
148+
``_logical_stages`` is the canonical source for graph expansion, but
149+
existing user code may still mutate ``pipeline.stages`` directly. Treat
150+
those mutations as logical graph edits before planning instead of
151+
silently ignoring them.
145152
"""
146153
if self._built:
147154
if self.stages == self._planned_stage_snapshot:
148155
return
149156
logger.warning(
150-
"Pipeline '{}' execution-stage list was modified after build(); treating the current stages "
151-
"as the new logical graph",
152-
self.name,
157+
"Pipeline.stages was mutated after build(); treating the current public stages list "
158+
"as the new logical graph. Prefer Pipeline.add_stage() for future code."
153159
)
154160
self._clear_default_source_sink_roles()
155161
self._logical_stages = list(self.stages)
@@ -159,19 +165,13 @@ def _sync_public_stage_mutations(self) -> None:
159165

160166
if self.stages != self._logical_stages:
161167
logger.warning(
162-
"Pipeline '{}' stages list was modified directly; syncing it into the logical graph",
163-
self.name,
168+
"Pipeline.stages was mutated directly; syncing it into the logical graph. "
169+
"Prefer Pipeline.add_stage() for future code."
164170
)
165171
self._clear_default_source_sink_roles()
166172
self._logical_stages = list(self.stages)
167173
self._planned_stage_snapshot = []
168174

169-
def _expand_pipeline_graph(self, stages: list[ProcessingStage]) -> list[ProcessingStage]:
170-
"""Apply generic pipeline-level graph expansion rules."""
171-
from nemo_curator.pipeline.payload_lifecycle import expand_payload_lifecycle_stages
172-
173-
return expand_payload_lifecycle_stages(stages, self.config)
174-
175175
def _clear_default_source_sink_roles(self) -> None:
176176
"""Clear source/sink roles that were assigned by a previous build.
177177

0 commit comments

Comments
 (0)