From a6257129999ca074630fabf2439434c1b8dbefb0 Mon Sep 17 00:00:00 2001 From: Mohammad Aaftab Date: Mon, 29 Jun 2026 00:43:44 +0530 Subject: [PATCH 1/2] Complete local Qwen audio lifecycle and bucketing Add backend-visible payload materialization, bounded batched ref resolution, one-call prefetch, Qwen ASR, finite-window duration bucketing, and comprehensive performance/test coverage. Preserve current-main task and backend compatibility and use finite failure-recovery leases for published payloads. Signed-off-by: Mohammad Aaftab --- PR1967_FEATURE_WALKTHROUGH.md | 686 +++++++++++ examples/audio/qwen_omni_inprocess/README.md | 31 + .../prompts/en_qwen3_omni_disfluency_asr.md | 30 + .../en_qwen3_omni_reference_improvement.md | 51 + .../prompts/ml_qwen3_omni_disfluency_asr.md | 1 + .../ml_qwen3_omni_reference_improvement.md | 23 + nemo_curator/backends/base.py | 337 +++-- nemo_curator/backends/perf_identity.py | 445 +++++++ nemo_curator/backends/ray_data/adapter.py | 84 +- nemo_curator/backends/ray_data/executor.py | 41 +- nemo_curator/backends/ray_data/utils.py | 25 +- nemo_curator/backends/utils.py | 79 +- nemo_curator/backends/xenna/adapter.py | 36 +- nemo_curator/backends/xenna/executor.py | 147 ++- nemo_curator/models/asr/__init__.py | 39 + nemo_curator/models/asr/base.py | 123 ++ nemo_curator/models/asr/qwen_omni.py | 624 ++++++++++ nemo_curator/models/vllm_model.py | 126 +- nemo_curator/pipeline/payload_lifecycle.py | 668 ++++++++++ nemo_curator/pipeline/payload_refs.py | 263 ++++ nemo_curator/pipeline/pipeline.py | 111 +- nemo_curator/pipeline/prefetch.py | 88 ++ nemo_curator/pipelines/__init__.py | 15 + nemo_curator/pipelines/audio/__init__.py | 15 + .../pipelines/audio/qwen_omni_inprocess.py | 351 ++++++ nemo_curator/stages/audio/README.md | 432 ++++++- nemo_curator/stages/audio/common.py | 213 +++- .../stages/audio/inference/__init__.py | 17 + .../stages/audio/inference/asr/__init__.py | 24 + .../stages/audio/inference/asr/stage.py | 1092 +++++++++++++++++ .../stages/audio/inference/batch_policy.py | 463 +++++++ .../stages/audio/inference/bucketed_stage.py | 106 ++ .../stages/audio/io/audio_file_reader.py | 324 +++++ .../stages/audio/io/manifest_writer_utils.py | 155 +++ .../stages/audio/io/nemo_tarred_reader.py | 669 ++++++++++ .../audio/io/sharded_manifest_writer.py | 292 +++++ .../stages/audio/io/waveform_utils.py | 91 ++ .../stages/audio/metrics/performance.py | 972 +++++++++++++++ .../stages/audio/metrics/performance_utils.py | 184 +++ nemo_curator/stages/audio/metrics/squim.py | 4 +- .../stages/audio/model_input_segmentation.py | 104 ++ .../audio/preprocessing/mono_conversion.py | 8 +- nemo_curator/stages/base.py | 16 + nemo_curator/stages/payload_lifecycle.py | 1042 ++++++++++++++++ nemo_curator/tasks/task_terminals.py | 169 +++ nemo_curator/utils/gpu_sampler.py | 173 +++ nemo_curator/utils/performance_utils.py | 91 +- .../utils/pipeline_hardware_sampler.py | 185 +++ pyproject.toml | 8 + tests/backends/ray_data/test_utils.py | 179 ++- tests/backends/test_task_id_postprocess.py | 115 +- tests/backends/test_utils.py | 214 +++- tests/backends/xenna/__init__.py | 1 + tests/backends/xenna/test_executor.py | 164 +++ tests/models/asr/__init__.py | 0 tests/models/asr/test_base.py | 107 ++ tests/models/asr/test_package_lazy_import.py | 62 + tests/models/asr/test_qwen_omni.py | 353 ++++++ tests/pipeline/__init__.py | 15 + tests/pipeline/test_payload_refs.py | 90 ++ tests/pipeline/test_prefetch.py | 78 ++ tests/pipelines/audio/__init__.py | 13 + .../audio/test_qwen_omni_inprocess.py | 326 +++++ tests/pipelines/test_pipelines.py | 60 +- tests/stages/audio/inference/asr/__init__.py | 0 .../stages/audio/inference/test_asr_stage.py | 777 ++++++++++++ .../audio/inference/test_batch_policy.py | 501 ++++++++ .../audio/inference/test_bucketed_stage.py | 101 ++ .../stages/audio/io/test_audio_file_reader.py | 152 +++ .../audio/io/test_nemo_tarred_reader.py | 252 ++++ .../audio/io/test_sharded_manifest_writer.py | 166 +++ .../stages/audio/metrics/test_perf_summary.py | 360 ++++++ tests/stages/audio/test_common.py | 93 +- .../audio/test_model_input_segmentation.py | 99 ++ tests/stages/test_payload_lifecycle.py | 644 ++++++++++ tests/stages/text/io/reader/test_jsonl.py | 2 +- tests/stages/text/io/reader/test_parquet.py | 2 +- tests/tasks/test_utils.py | 21 + tests/utils/test_gpu_sampler.py | 32 + tutorials/audio/README.md | 2 + uv.lock | 47 +- 81 files changed, 15933 insertions(+), 368 deletions(-) create mode 100644 PR1967_FEATURE_WALKTHROUGH.md create mode 100644 examples/audio/qwen_omni_inprocess/README.md create mode 100644 examples/audio/qwen_omni_inprocess/prompts/en_qwen3_omni_disfluency_asr.md create mode 100644 examples/audio/qwen_omni_inprocess/prompts/en_qwen3_omni_reference_improvement.md create mode 100644 examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_disfluency_asr.md create mode 100644 examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_reference_improvement.md create mode 100644 nemo_curator/backends/perf_identity.py create mode 100644 nemo_curator/models/asr/__init__.py create mode 100644 nemo_curator/models/asr/base.py create mode 100644 nemo_curator/models/asr/qwen_omni.py create mode 100644 nemo_curator/pipeline/payload_lifecycle.py create mode 100644 nemo_curator/pipeline/payload_refs.py create mode 100644 nemo_curator/pipeline/prefetch.py create mode 100644 nemo_curator/pipelines/__init__.py create mode 100644 nemo_curator/pipelines/audio/__init__.py create mode 100644 nemo_curator/pipelines/audio/qwen_omni_inprocess.py create mode 100644 nemo_curator/stages/audio/inference/asr/stage.py create mode 100644 nemo_curator/stages/audio/inference/batch_policy.py create mode 100644 nemo_curator/stages/audio/inference/bucketed_stage.py create mode 100644 nemo_curator/stages/audio/io/audio_file_reader.py create mode 100644 nemo_curator/stages/audio/io/manifest_writer_utils.py create mode 100644 nemo_curator/stages/audio/io/nemo_tarred_reader.py create mode 100644 nemo_curator/stages/audio/io/sharded_manifest_writer.py create mode 100644 nemo_curator/stages/audio/io/waveform_utils.py create mode 100644 nemo_curator/stages/audio/metrics/performance.py create mode 100644 nemo_curator/stages/audio/metrics/performance_utils.py create mode 100644 nemo_curator/stages/audio/model_input_segmentation.py create mode 100644 nemo_curator/stages/payload_lifecycle.py create mode 100644 nemo_curator/tasks/task_terminals.py create mode 100644 nemo_curator/utils/gpu_sampler.py create mode 100644 nemo_curator/utils/pipeline_hardware_sampler.py create mode 100644 tests/backends/xenna/__init__.py create mode 100644 tests/backends/xenna/test_executor.py create mode 100644 tests/models/asr/__init__.py create mode 100644 tests/models/asr/test_base.py create mode 100644 tests/models/asr/test_package_lazy_import.py create mode 100644 tests/models/asr/test_qwen_omni.py create mode 100644 tests/pipeline/__init__.py create mode 100644 tests/pipeline/test_payload_refs.py create mode 100644 tests/pipeline/test_prefetch.py create mode 100644 tests/pipelines/audio/__init__.py create mode 100644 tests/pipelines/audio/test_qwen_omni_inprocess.py create mode 100644 tests/stages/audio/inference/asr/__init__.py create mode 100644 tests/stages/audio/inference/test_asr_stage.py create mode 100644 tests/stages/audio/inference/test_batch_policy.py create mode 100644 tests/stages/audio/inference/test_bucketed_stage.py create mode 100644 tests/stages/audio/io/test_audio_file_reader.py create mode 100644 tests/stages/audio/io/test_nemo_tarred_reader.py create mode 100644 tests/stages/audio/io/test_sharded_manifest_writer.py create mode 100644 tests/stages/audio/metrics/test_perf_summary.py create mode 100644 tests/stages/audio/test_model_input_segmentation.py create mode 100644 tests/stages/test_payload_lifecycle.py create mode 100644 tests/utils/test_gpu_sampler.py diff --git a/PR1967_FEATURE_WALKTHROUGH.md b/PR1967_FEATURE_WALKTHROUGH.md new file mode 100644 index 0000000000..36e4766250 --- /dev/null +++ b/PR1967_FEATURE_WALKTHROUGH.md @@ -0,0 +1,686 @@ +# PR1967 Local/Windowed Branch: Current-Code Walkthrough + +This document describes the current files on disk for +`aaftabv/qwen1967-local-noio-control`. It is a reviewer guide to the code as it +exists now: public contracts, execution flow, task shape, scheduling, memory +ownership, failure behavior, observability, and tests. It intentionally does +not narrate discarded designs or earlier benchmark revisions. + +The branch provides four connected capabilities: + +1. a generic pipeline rule that inserts payload materialization and release + stages around independently scheduled consumers; +2. a Ray-backed payload reference lifecycle that decodes large audio payloads + once and keeps waveform tensors out of ordinary task rows; +3. model-input segmentation and duration-aware bucketing inside the + backend-visible ASR stage's finite input window; +4. a pluggable ASR stage and Qwen-Omni adapter with model-call controls and + detailed performance metrics. + +The logical Qwen graph is: + +```text +ManifestReader -> ASRStage -> ManifestWriterStage +``` + +`Pipeline.build()` expands it to: + +```text +ManifestReader + -> AudioPayloadMaterializeStage + -> ASRStage + -> PayloadReleaseStage + -> ManifestWriterStage +``` + +All five execution stages remain separately visible to Ray Data or Xenna. The +payload lifecycle does not combine the reader and GPU stage, and it does not +change the GPU stage's worker or resource contract. + +The local and global branches share every executable implementation file except +`nemo_curator/stages/audio/common.py`. The local file contains ordinary +manifest reading and writing. The global file additionally contains +full-manifest segment planning, parent-row storage, and parent assembly. +Payload handling, ASR, Qwen, backends, task contracts, and performance code are +otherwise byte-identical in the current worktrees. + +## 1. Terms Used By The Code + +- **Logical stage**: a stage listed by the user in `Pipeline.stages` or Hydra + `stages:`. +- **Execution stage**: a concrete `ProcessingStage` after graph expansion and + composite-stage decomposition. +- **Backend-visible stage**: an execution stage independently presented to Ray + Data or Xenna. +- **Payload**: a large object managed outside ordinary task serialization. The + audio materializer stores a decoded waveform tensor. +- **Payload binding**: the mapping from a source field such as + `audio_filepath` to ref, waveform, sample-rate, sample-count, duration, and + materializer fields. +- **PayloadRef**: the lightweight handle in `task.data` that identifies the + object store, admission state, producer node, byte count, sample metadata, + lease settings, and Ray namespace. +- **Payload consumer**: a backend-visible stage that declares payload bindings + and resolves refs only while processing a batch. +- **Parent row**: one complete manifest record delivered to the local ASR + stage. +- **Model-input segment**: one contiguous waveform slice no longer than + `max_inference_duration_s`. +- **Local/windowed bucketing**: duration grouping over model-input segments + available in the current backend-provided ASR window. +- **Terminal row**: a row with `_curator_terminal_*` ownership fields that a + downstream terminal consumer must receive exactly once. The local Qwen graph + does not create terminal segment rows, but the generic task contract is + available to other stages. + +ASR-internal symbols retain `chunk` in names such as `_ChunkSpec`, +`_build_chunk_specs()`, and `_curator_asr_chunk_*`. Those symbols implement +ASR-local stitch-back. The public planning term is **segment**, the public +model boundary is `max_inference_duration_s`, and the shared planner type is +`AudioSegment`. + +## 2. Pipeline Construction And Graph Expansion + +### 2.1 Pipeline state + +[`nemo_curator/pipeline/pipeline.py`](nemo_curator/pipeline/pipeline.py) keeps: + +- `_logical_stages`, the canonical user graph; +- `stages`, the public list and built execution graph. + +`Pipeline.__init__()` preserves the caller-visible config mapping, matching +main-branch behavior, and separately creates a private +`_curator_pipeline_run_id`. The id is copied only into the ephemeral graph +expansion config. Payload actor names include it, preventing independent +pipeline objects from attaching to one another without mutating the caller's +configuration. + +`Pipeline.add_stage()` updates the logical graph and invalidates the built plan. +`_sync_public_stage_mutations()` also accepts direct mutations to +`pipeline.stages`. `_clear_default_source_sink_roles()` removes roles assigned +automatically by a prior build; explicit roles remain subject to the one-source +and one-sink validation in `_assign_source_sink_roles()`. + +`Pipeline.build()`: + +1. synchronizes public stage changes; +2. applies pipeline-level graph expansion; +3. decomposes composite stages; +4. assigns source and sink roles; +5. stores the execution graph for idempotent repeated builds. + +### 2.2 Payload lifecycle rule + +[`nemo_curator/pipeline/payload_lifecycle.py`](nemo_curator/pipeline/payload_lifecycle.py) +implements the backend-neutral `expand_payload_lifecycle_stages()` rule. + +The rule validates: + +- one `materialize_after` stage and one `release_after` stage; +- all consumers lie within that lifecycle range; +- no materialize/release helper is explicitly listed in the logical graph; +- all consumers are payload-aware; +- consumer ref/waveform bindings match the materialized bindings; +- source, ref, and waveform keys are unique. + +Selectors match Hydra stage id, stage name, class name, or fully qualified class +name. + +The preferred multiple-input form is `payload_lifecycle.payloads`, with one +mapping per payload. The single-input shorthand derives fields from +`payload_keys` and consumer attributes. + +The central rule calls `ManifestReader.build_payload_materialize_stage()` to +construct the audio materializer. The post-release extension path is inactive +because this reader does not enable global planning, so no parent assembler is +inserted. + +### 2.3 Expanded graph config + +```yaml +payload_lifecycle: + enabled: true + materialize_after: reader + payload_keys: [audio_filepath] + ref_key: waveform_ref + consumers: [qwen_omni] + release_after: qwen_omni + target_sample_rate: 16000 + target_nchannels: 1 + node_memory_fraction: 0.80 +``` + +For several consumers, list every payload-aware stage in `consumers` and put +`release_after` on the final one. Each consumer retains its own `Resources`, +`batch_size`, worker count, and backend actor. + +## 3. Payload Reference Lifecycle + +The handle API is in +[`nemo_curator/pipeline/payload_refs.py`](nemo_curator/pipeline/payload_refs.py). +Audio materialization and consumer support are in +[`nemo_curator/stages/payload_lifecycle.py`](nemo_curator/stages/payload_lifecycle.py). + +### 3.1 PayloadRef + +`PayloadRef` carries: + +| Field | Meaning | +| --- | --- | +| `payload_id` | object id in the payload store | +| `owner_node_id` | node that decoded and stored the object | +| `store_actor_name` | node-local object store actor | +| `admission_actor_name` | cluster admission actor | +| `amount_bytes` | actual stored byte count | +| `sample_rate` / `num_samples` | consumer-visible waveform metadata | +| `lease_ttl_s` | heartbeat setting | +| `actor_namespace` | Ray namespace for actor lookup | + +`resolve_payload_ref()` refreshes actor state and returns one payload. +`resolve_payload_refs_batched()` groups handles by actor, issues +`heartbeat_many`, `pin_many`, and `get_many` RPCs, preserves caller order, and +splits work by an optional byte bound. Each actor-side bulk method performs one +expiry-reap pass for the whole request, so actor work stays linear in the +number of handles. +`release_payload_ref()` removes the object and releases admission bytes. +`strip_payload_refs()` recursively removes handles from nested containers. + +### 3.2 Admission and stores + +`_PayloadAdmissionState` tracks per-node budgets, aggregate cluster usage, and +reservations. The default cluster budget is the sum of registered node budgets; +`max_cluster_payload_bytes` can set an explicit limit. A row larger than either +applicable budget fails immediately. Temporary lack of capacity waits until a +payload is released, bounded by `admission_wait_timeout_s` (four hours by +default). A timeout reports the requested bytes and the actor's node/cluster +usage snapshot rather than polling forever. + +`_PayloadStoreState` owns actual objects. Store actors are node-affined. Store +and admission actors use the pipeline run id and active Ray namespace, are +detached across backend worker lifetimes, and are killed by executor cleanup. + +Materialized payloads use a longer finite `materialized_lease_ttl_s` while they +wait between stages (four hours by default). `_PayloadLeaseKeeper` switches to +active `lease_ttl_s` renewals while a consumer performs long model work. +Explicit release is the normal fast path; finite expiry lets admission and +store actors reclaim a payload whose row is lost before release. + +### 3.3 AudioPayloadMaterializeStage + +For each row the materializer: + +1. reads the configured metadata duration; +2. estimates waveform bytes and acquires admission capacity with a finite + `lease_ttl_s` lease; +3. decodes the local file through `AudioFileReaderStage`; +4. removes the waveform from normal task data; +5. measures actual tensor bytes and resizes the reservation; +6. stores the tensor in the node-local actor; +7. writes `PayloadRef`, estimated bytes, actual bytes, and producer node id; +8. converts the completed reservation to the finite + `materialized_lease_ttl_s`, giving queued rows a long handoff window while + bounding orphan retention. + +Duration must be positive and numeric. Byte-limit strings accept integer, `k`, +`m`, and `g` forms and reject invalid values. If actual bytes cannot fit, the +stage releases its reservation and raises. If materialization fails after the +store insert, it also removes the stored object. If a worker dies before the +reservation is committed, its finite materialization lease can be reaped. + +When `skip_on_read_error` is enabled, a reader error yields a skipped row with +no payload ref. The reservation and zero-length waveform are removed. + +Metrics include admission wait, poll count, estimated/reserved/stored bytes, +node and cluster budgets, and materialization count. + +### 3.4 Payload-aware consumers + +`PayloadAwareStageMixin.payload_bindings()` provides a single-waveform default. +A multi-input stage overrides it with one binding per payload. + +`resolve_payload_refs_for_batch()` resolves handles with actor-grouped, +byte-bounded bulk RPCs, restores sample metadata, records same-node and +cross-node resolution metrics, and starts batched heartbeats. +`drop_resolved_payloads()` stops the heartbeat thread and removes temporary +waveform fields. + +The Qwen config opts into `BoundedOneAheadPrefetchIterator` from +`pipeline/prefetch.py`. ASR can plan exact model calls from `PayloadRef` +metadata (`num_samples`, `sample_rate`, and `amount_bytes`) before loading the +waveform. `_PayloadCallMaterializer` then: + +1. groups and resolves only the unique parent refs required by one adapter + call; +2. caches a resolved parent while contiguous calls still need its segments; +3. slices each requested model-input segment locally; +4. allows one byte-bounded successor call to resolve while the current call is + on the GPU; and +5. drops actor-local waveform references as soon as their call is complete. + +The payload-store actor continues to own the original waveform until +`PayloadReleaseStage`; prefetch changes only the ASR actor's temporary working +set. This path is opt-in. Existing payload-aware consumers retain the eager, +batched resolver path. + +The lightweight ref remains in task data after one consumer returns. Later +configured consumers can resolve the same waveform without another file read. + +### 3.5 Release and exception paths + +`PayloadReleaseStage` finds all nested refs, deduplicates payload ids, releases +objects and byte reservations, strips handles, removes waveform and payload +bookkeeping keys, and returns the task. It supports rows without refs and keeps +the existing task-data mapping object intact. + +`BaseStageAdapter` performs payload scanning only for stages marked by the +lifecycle expander. On those stages it releases all input refs on an exception +and releases refs that disappear because a stage filtered a row. Stages in +ordinary pipelines do not pay that recursive scan cost. + +## 4. Local Audio Decode + +[`nemo_curator/stages/audio/io/audio_file_reader.py`](nemo_curator/stages/audio/io/audio_file_reader.py) +defines the single raw audio-byte I/O implementation used by the materializer. + +The reader: + +- accepts local paths and rejects URI-style paths; +- checks for ffmpeg in per-node setup; +- decodes to float32 PCM at configured sample rate and channel count; +- supports `segment_start_s` and `segment_duration_s` when supplied by another + stage; +- returns channels-first contiguous tensors; +- writes waveform, sample rate, sample count, duration, mono status, and + `audio_item_id`; +- converts decode errors into skipped rows when configured. + +The local manifest reader does not create segment offsets. It emits complete +parent rows, so the materializer decodes each full source row once. ASR slices +that in-memory waveform only when the model input ceiling requires it. + +## 5. Local Manifest Reading + +`ManifestReader` in +[`nemo_curator/stages/audio/common.py`](nemo_curator/stages/audio/common.py) is a +`CompositeStage` that decomposes into: + +```text +FilePartitioningStage -> ManifestReaderStage +``` + +`FilePartitioningStage` discovers manifest files. `ManifestReaderStage` streams +each JSONL file line by line with fsspec and emits one `AudioTask` per non-empty +line. It copies task metadata/performance and derives child task ids from the +partition task. It is a one-worker fanout stage. + +Rows keep the complete original manifest dictionary and input order. No +full-manifest duration plan, parent-data actor, segment terminal fields, or +assembler exists in this branch. + +## 6. Model-Input Segmentation + +[`nemo_curator/stages/audio/model_input_segmentation.py`](nemo_curator/stages/audio/model_input_segmentation.py) +contains the shared safety primitive. + +`resolve_max_model_input_duration()` validates the positive model input +ceiling. `plan_audio_segments()` converts actual sample count, sample rate, and +the ceiling into contiguous `AudioSegment` records. Each record contains index, +count, start sample, stop sample, and duration. + +Properties enforced by the implementation: + +- an input at the ceiling produces one segment; +- an input just over the ceiling produces one full segment and one tail; +- no overlap or padding is introduced; +- zero samples remain representable as one empty segment; +- invalid sample rates raise; +- metadata duration conversion uses ceiling sample math. + +The local branch applies this helper inside `ASRStage` to the decoded waveform. +With bucketing enabled, the bounded segments are bucketed. With bucketing +disabled, the same segmentation remains the model/OOM safety boundary. + +## 7. ASR Stage And Windowed Bucketing + +[`nemo_curator/stages/audio/inference/asr/stage.py`](nemo_curator/stages/audio/inference/asr/stage.py) +defines `ASRStage`. Model adapters conform to +[`ASRAdapter`](nemo_curator/models/asr/base.py). + +### 7.1 Stage contract + +The stage accepts waveform or `waveform_ref`, sample rate, optional source +language, and optional reference text. It writes configured primary text, +optional secondary text, and a skip key. + +`setup_on_node()` prefetches model weights with one CPU and zero GPUs. +Worker `setup()` constructs the adapter under the stage's GPU resource +allocation. `teardown()` releases adapter state. + +### 7.2 Independent batch controls + +| Control | Scope | +| --- | --- | +| `ASRStage.batch_size` | parent-row candidate window passed by Ray Data/Xenna | +| `BatchPolicy.max_items_per_batch_by_bucket` | model-work grouping per duration bucket | +| `adapter_batch_size` | fallback items per adapter call | +| `BatchPolicy.bucketed_inference_batch_size` | per-duration-bucket adapter-call cap | +| `BatchPolicy.max_audio_sec_per_batch` | aggregate cost cap for one bucketed batch | + +These controls do not govern payload RAM. Payload memory is admitted in bytes +by the materializer. + +### 7.3 Process flow + +`ASRStage.process_batch()`: + +1. uses eager bulk payload resolution unless prefetch is explicitly enabled; +2. in prefetch mode, plans segments and exact adapter calls from payload + metadata before waveform resolution; +3. resolves the current call and overlaps one bounded next call with current + GPU inference; +4. builds adapter items with language, reference text, task id, duration, and + stitch-back indices; +5. applies duration-aware policy when enabled; +6. splits each bucket by its adapter-call cap; +7. invokes the adapter and realigns results; +8. joins per-parent text in segment order; +9. drops current-call waveform references; the payload actor retains the + original until `PayloadReleaseStage`. + +The finite candidate set is the current backend-provided `process_batch()` +window. Local bucketing cannot inspect rows outside that window. + +The stage emits `adapter_inference_calls` and `adapter_inference_items`, plus +input, processed, skipped, generated-segment, audio-duration, waveform-byte, +output-character, token, and inference-time metrics. + +### 7.4 Code-derived 5h example + +Using the same real benchmark manifest as the global guide: + +```text +/home/aaftabv/grananary-v2/realdata_5h_yt_alm_part2_20260613/manifest_5h_stratified_duration_tails.jsonl +``` + +the local reader emits 89 complete parent rows in source order. It does not +create segment rows. For the first two records: + +| Source index | Parent duration | Materialized object at 16 kHz mono float32 | ASR model-input segments | +| ---: | ---: | ---: | --- | +| 0 | 7513.3335 s | about 480,853,344 bytes | 2400, 2400, 2400, 313.3335 s | +| 1 | 2756.4135 s | about 176,410,464 bytes | 2400, 356.4135 s | + +The difference from global is where segmentation happens. Local materializes +the complete parent waveform once, stores one parent `PayloadRef`, and sends +that row into the backend-provided ASR window. `_build_chunk_specs()` calls the +shared segment planner against the resolved parent sample count. In prefetch +mode, ASR plans those descriptors from ref metadata, resolves a parent once, +and reuses the cached tensor for its contiguous segment calls. + +For all 89 rows, model safety is the same as global: no adapter input exceeds +2,400 seconds. Packing scope is different. Local duration-aware bucketing can +combine only segments whose parent rows are present in the current +`process_batch()` window. It cannot reorder the complete 89-row manifest before +materialization, and payload admission accounts for complete parent tensors +rather than globally planned segment tensors. + +### 7.5 BatchPolicy + +[`nemo_curator/stages/audio/inference/batch_policy.py`](nemo_curator/stages/audio/inference/batch_policy.py) +defines a generic cost-bucket policy. + +`BatchPolicy` validates strictly increasing edges starting at zero, per-bucket +item caps, optional adapter caps, total-cost cap, candidate window, and flush +interval. `BucketQueueScheduler` flushes on item capacity, cost capacity, timer, +or drain. Finite planning orders ready batches by descending total cost and +returns original indices for result alignment. + +`run_bucketed()` exposes the same bucket-dispatch-and-realign loop to other +inference stages through caller-supplied cost and execution functions. + +## 8. Qwen-Omni Adapter + +[`nemo_curator/models/asr/qwen_omni.py`](nemo_curator/models/asr/qwen_omni.py) +implements `QwenOmniASRAdapter` using `VLLMBase` from +[`nemo_curator/models/vllm_model.py`](nemo_curator/models/vllm_model.py). + +Install this adapter with `uv sync --extra audio_qwen`. The Qwen-only extra +composes the unchanged `audio_cuda12` and `vllm` extras with +`qwen-omni-utils`, so existing audio installations keep their main-branch +dependency selection. + +It supports: + +- inline and file-backed default, English, follow-up, and system prompts; +- per-item `{language}` and `{transcript}` interpolation; +- waveform normalization and 16 kHz resampling; +- threaded multimodal request preparation; +- one-turn or two-turn Qwen inference; +- stable one-result-per-input ordering with skipped placeholders; +- tensor parallelism, model/token sequence limits, GPU memory utilization, + prefix caching, multimodal limits, sampling, seed, and output-token settings; +- preparation, generation, valid/skipped input, and output-token metrics. + +`ASRStage` passes `waveform`, `sample_rate`, `language`, `language_code`, +`reference_text`, `task_id`, `audio_seconds`, and stitch-back indices to the +adapter. Adapter-specific conversion and Qwen request construction remain out +of the stage. + +## 9. Backend Scheduling And Autoscaling + +The backend implementation is modality-neutral and shared with the global +branch. + +### 9.1 Base adapter + +[`nemo_curator/backends/base.py`](nemo_curator/backends/base.py) wraps one +`stage.process_batch()` invocation with timing, task-id postprocessing, custom +metrics, and the existing validation flow. Payload-ref scans and terminal +tombstones are enabled only on stages marked by payload lifecycle expansion. +Worker identity and invocation-window GPU/VRAM sampling are enabled only when +`extended_performance_metrics` is explicitly set; the Qwen entrypoint opts in, +while existing pipelines retain the compact main-compatible record shape. + +The backend does not own duration bucketing or payload prefetch. Ray Data and +Xenna deliver their normal finite `process_batch()` windows, and ASR performs +segmentation, bucketing, bulk resolution, and optional one-call lookahead +inside its existing actor. Ray Actor Pool remains outside this Qwen execution +path; payload lifecycle does not add a second scheduling policy to it. + +### 9.2 Ray Data + +[`nemo_curator/backends/ray_data/adapter.py`](nemo_curator/backends/ray_data/adapter.py) +maps every backend-visible stage with `Dataset.map_batches()`. + +- `stage.batch_size` sets the process-batch window. +- Stages with setup/GPU/actor requirements use actors; stateless CPU helpers can + use tasks. +- `stage.num_workers()` fixes an actor pool when set. +- Otherwise optional `min_workers`, `max_workers`, and `initial_workers` values + come from `ray_stage_spec()`; without those bounds Ray Data controls + actor-pool scaling subject to each actor's declared resources. +- Fanout stages repartition output blocks to one row each. + +For this graph, materialize and release are CPU task stages, ASR is a GPU actor +stage, and writer is a one-worker actor stage. + +### 9.3 Xenna + +[`nemo_curator/backends/xenna/executor.py`](nemo_curator/backends/xenna/executor.py) +creates one Xenna `StageSpec` per execution stage. It forwards stage resources, +batch size, runtime environment, retry/lifetime settings, and worker sizing. + +Cluster-wide worker count comes from `stage.num_workers()`. +`xenna_stage_spec()["num_workers_per_node"]` is the Xenna-specific alternative; +setting both is rejected. `xenna_stage_spec()["num_workers"]` is rejected with +a message directing stage authors to `num_workers()`. A stage without either +fixed sizing remains under Xenna allocation/autoscaling. Streaming and batch +modes use the same stage graph. + +### 9.4 Setup resources + +`ProcessingStage.setup_on_node_resources()` defaults to the stage's processing +resources. `execute_setup_on_node()` submits setup for every stage on every +alive Ray node with those resources. `ASRStage` explicitly requests CPU-only +prefetch setup; model construction remains a GPU-worker operation. + +## 10. Tasks And Terminal Rows + +[`nemo_curator/tasks/tasks.py`](nemo_curator/tasks/tasks.py) makes the base +`Task.task_id` framework-owned (`init=False`). `BaseStageAdapter` overwrites it +at every derivable stage boundary: one-to-many outputs use parent plus output +index/content id, positional many-to-many uses each matching parent, and an +ambiguous many-to-different-count fanout receives an `r` fallback. +`AudioTask` retains its audio-specific constructor field, but normal backend +postprocessing still derives the stage-boundary id. + +`EmptyTask` is a payload-less class rooted at task id `"0"`; source execution +constructs it with `EmptyTask()`. + +[`nemo_curator/tasks/task_terminals.py`](nemo_curator/tasks/task_terminals.py) +defines generic `_curator_terminal_*` ownership and tombstone fields. Normal +local Qwen rows have no terminal ownership metadata, so ordinary filtering +still removes them. The helper activates only for rows that explicitly carry a +terminal contract. + +## 11. Performance And Resource Observability + +### 11.1 Stage metrics + +`BaseStageAdapter` attaches one `StagePerfStats` record per stage invocation. +Its public `to_dict()` and numeric `items()` schema remains main-compatible. +When `extended_performance_metrics` is enabled, the record also carries an +invocation id, expected resources, node/worker/actor identity, and per-GPU +utilization/VRAM observations. Audio aggregation deduplicates records by +invocation id because one invocation record may be attached to several output +tasks. + +[`nemo_curator/backends/perf_identity.py`](nemo_curator/backends/perf_identity.py) +normalizes Ray and Xenna identity into common node, worker, actor, hostname, GPU +index, GPU UUID, and allocation fields. + +### 11.2 Pipeline hardware sampler + +When `pipeline_hardware_sampler_enabled` is true, +[`nemo_curator/utils/pipeline_hardware_sampler.py`](nemo_curator/utils/pipeline_hardware_sampler.py) +starts one sampler actor per alive node for the executor lifetime. It observes +every GPU independently of stage ownership. The generic executor default is +off; the Qwen entrypoint opts in by default. Executors attach the resulting +`pipeline_hardware_sampler` record without using it for placement decisions. + +### 11.3 Audio performance summary + +[`nemo_curator/stages/audio/metrics/performance.py`](nemo_curator/stages/audio/metrics/performance.py) +aggregates stage totals/percentiles, per-actor and per-GPU views, payload wait +and locality metrics, adapter calls/items, audio throughput, writer timing, and +hardware samples. Shared invocation ids prevent duplicated counts when the same +perf record is attached to several output rows. + +## 12. Manifest Output + +`ManifestWriterStage` in +[`nemo_curator/stages/audio/common.py`](nemo_curator/stages/audio/common.py) is a +single-worker actor stage. Driver setup truncates the output once; per-node +setup only creates directories. + +[`manifest_writer_utils.py`](nemo_curator/stages/audio/io/manifest_writer_utils.py) +applies an explicit serialization policy. By default it writes task data as-is, +matching the existing writer contract. `drop_manifest_keys` and +`drop_array_like_values` opt into omission, and non-JSON values otherwise fail +with the offending key. In the Qwen graph, `PayloadReleaseStage` removes refs +and waveform bookkeeping before the writer; the Qwen writer config also opts +into array/key filtering. `write_perf_stats` defaults off for compatibility and +is enabled by the benchmark config to refresh `perf_summary.json` and merge the +executor's external pipeline-hardware record. + +`ShardedManifestWriterStage` and `NemoTarredAudioReader` provide separate +sharded-output and tarred-input APIs. They do not alter the raw Qwen lifecycle +graph. + +## 13. Extending The Primitives + +### 13.1 Another payload modality + +A source stage implements `build_payload_materialize_stage()`. Its materializer +creates the configured handle, and consumers implement +`resolve_payload_refs_for_batch()`. Central graph insertion remains independent +of the payload's modality. + +### 13.2 Multiple payloads or consumers + +Use one `payload_lifecycle.payloads` mapping per source and override +`payload_bindings()` in consumers. All materializers are inserted after the +source. One release stage recursively frees every nested ref after the final +consumer. + +### 13.3 Another inference model + +An ASR model implements `ASRAdapter`. Another modality can use `BatchPolicy` +and `run_bucketed()` without adopting ASR. Each model stage retains its own +backend resources and worker count. A model adapter can expose +`estimate_item_cost()` for encoder-token or VRAM-aware cost in place of raw +duration. + +## 14. Test Map + +- [`tests/pipelines/test_pipelines.py`](tests/pipelines/test_pipelines.py): + logical/execution graph state, rebuilds, and source/sink roles; +- [`tests/pipelines/audio/test_qwen_omni_inprocess.py`](tests/pipelines/audio/test_qwen_omni_inprocess.py): + lifecycle expansion, multiple consumers, multiple payloads, and helper-stage + rejection; +- [`tests/stages/test_payload_lifecycle.py`](tests/stages/test_payload_lifecycle.py): + byte admission, stores, explicit-release lifetime, namespaces, batched actor + methods, heartbeat, nested release, read-error rows, and actor cleanup; +- [`tests/pipeline/test_payload_refs.py`](tests/pipeline/test_payload_refs.py) + and [`tests/pipeline/test_prefetch.py`](tests/pipeline/test_prefetch.py): + actor-grouped resolution, stable ref order, byte bounds, cache behavior, and + one-successor overlap; +- [`tests/stages/audio/test_model_input_segmentation.py`](tests/stages/audio/test_model_input_segmentation.py): + validation, exact 2400-second boundary, zero samples, and tail segments; +- [`tests/stages/audio/inference/test_asr_stage.py`](tests/stages/audio/inference/test_asr_stage.py): + payload-backed inputs, segmentation, language/reference fields, result + ordering, skip behavior, adapter calls, and metrics; +- [`tests/stages/audio/inference/test_batch_policy.py`](tests/stages/audio/inference/test_batch_policy.py): + bucket edges, caps, cost scheduling, adapter batches, ordering, and generic + scheduler hooks; +- [`tests/backends/ray_data/test_utils.py`](tests/backends/ray_data/test_utils.py): + actor sizing and backend batch delivery; +- [`tests/backends/xenna/test_executor.py`](tests/backends/xenna/test_executor.py): + StageSpec construction, `num_workers()`/per-node sizing conflicts, and + verbosity; +- [`tests/stages/audio/metrics/test_perf_summary.py`](tests/stages/audio/metrics/test_perf_summary.py) + and [`tests/utils/test_gpu_sampler.py`](tests/utils/test_gpu_sampler.py): + summary and GPU metrics. + +## 15. Reviewer File Map + +| Concern | Primary files | +| --- | --- | +| pipeline planning | `nemo_curator/pipeline/pipeline.py`, `pipeline/payload_lifecycle.py` | +| payload handles/prefetch | `pipeline/payload_refs.py`, `pipeline/prefetch.py`, `stages/payload_lifecycle.py` | +| local manifest reader/writer | `stages/audio/common.py` | +| local audio decode | `stages/audio/io/audio_file_reader.py` | +| model-input segmentation | `stages/audio/model_input_segmentation.py` | +| ASR and batching | `stages/audio/inference/asr/stage.py`, `batch_policy.py`, `bucketed_stage.py` | +| Qwen adapter | `models/asr/base.py`, `models/asr/qwen_omni.py`, `models/vllm_model.py` | +| backend execution | `backends/base.py`, `backends/ray_data/*`, `backends/xenna/*`; `backends/ray_actor_pool/*` remains current main | +| task contracts | `tasks/tasks.py`, `tasks/task_terminals.py`, `tasks/sentinels.py` | +| performance | `backends/perf_identity.py`, `utils/gpu_sampler.py`, `utils/pipeline_hardware_sampler.py`, `stages/audio/metrics/*` | +| output safety | `stages/audio/io/manifest_writer_utils.py`, `stages/audio/common.py` | +| Hydra entry point | `pipelines/audio/qwen_omni_inprocess.py` | + +## 16. Core Invariants To Verify + +1. Each input audio file is decoded once by the materializer. +2. The waveform tensor lives in a payload actor between consumers. +3. Every configured consumer can resolve the same ref without another file read. +4. Release removes the stored tensor and its byte reservation. +5. Every GPU consumer remains a separate backend-visible stage. +6. Ray Data and Xenna use each stage's normal resources and worker contract. +7. Model inputs never exceed `max_inference_duration_s`. +8. Bucket-on groups only model-safe segments from the current backend window. +9. Bucket-off retains segmentation as its long-row model safety boundary. +10. ASR results are restored to original parent order. +11. Output manifests contain neither waveform tensors nor `PayloadRef` objects. +12. Performance summaries contain adapter-level calls/items and both + invocation-window and pipeline-wide GPU/VRAM observations. diff --git a/examples/audio/qwen_omni_inprocess/README.md b/examples/audio/qwen_omni_inprocess/README.md new file mode 100644 index 0000000000..589dad02a3 --- /dev/null +++ b/examples/audio/qwen_omni_inprocess/README.md @@ -0,0 +1,31 @@ +# Qwen-Omni In-Process ASR Assets + +This folder contains prompt templates used by the Qwen-Omni in-process ASR +adapter. + +Install the runtime with `uv sync --extra audio_qwen`. The dedicated extra +keeps Qwen/vLLM dependencies out of existing `audio_cuda12` installations. + +The executable code path is: + +```text +Pipeline + -> ManifestReader + -> AudioPayloadMaterializeStage + -> ASRStage(adapter_target=QwenOmniASRAdapter) + -> PayloadReleaseStage + -> ManifestWriterStage +``` + +The adapter reads prompt text through `prompt_file`, `en_prompt_file`, +`followup_prompt_file`, or `system_prompt_file`. Curator stage behavior remains +outside the prompt files: + +- graph expansion lives in `nemo_curator/pipeline/payload_lifecycle.py`; +- audio decode and payload refs live in `nemo_curator/stages/payload_lifecycle.py`; +- local/windowed ASR model-input segmentation and batching live in + `nemo_curator/stages/audio/inference/asr/stage.py`; +- Qwen/vLLM request construction lives in `nemo_curator/models/asr/qwen_omni.py`. + +Prompt files may use `{language}` and `{transcript}` placeholders when the +stage supplies language or reference text columns. diff --git a/examples/audio/qwen_omni_inprocess/prompts/en_qwen3_omni_disfluency_asr.md b/examples/audio/qwen_omni_inprocess/prompts/en_qwen3_omni_disfluency_asr.md new file mode 100644 index 0000000000..c7ca1107d0 --- /dev/null +++ b/examples/audio/qwen_omni_inprocess/prompts/en_qwen3_omni_disfluency_asr.md @@ -0,0 +1,30 @@ +You receive audio in English. + +MAIN GOAL: faithfully transcribe audio as is spoken in the audio with all disfluencies present in the audio. +- Do NOT remove, correct, or "clean up" any speech artifacts. +- Do NOT paraphrase, edit grammar, or make the speech more polished. + +FILLER WORDS: +- Include hesitation markers like "um", "uh", "hm", "ah" etc as is spoken in the audio. + +REPETITIONS: +- Consecutive instances of the same word or short phrase spoken unintentionally — keep all repetitions as-is. + - Example: "I I think", "the the problem" + +FALSE STARTS: +- Incomplete words or phrases the speaker abandons, mark with a hyphen — keep them as-is. + - Example: "I was go going to the store." → "I was go- going to the store." + +COLLOQUIAL REDUCTIONS: +- Preserve forms such as "wanna", "gonna", "kinda", "lemme", "lotta", "outta", "Imma", "sorta", "ya", "m'kay", "finna", "tryna", etc exactly as spoken. Do NOT expand them into standard forms. + +WRONG GRAMMAR: +- Grammatical errors should be faithfully captured in the transcript — do NOT correct them. +- You MUST NOT fix subject-verb agreement, tense errors, or any other grammatical issues. + +NUMERICALS: +- Keep numbers as is spoken in words. Do NOT convert them to numbers. like "oh eleven" should be "oh eleven" as spoken in the audio not "zero eleven" etc + +Output format: +- Return ONLY the transcription text. +- No explanations, no JSON, no lists. diff --git a/examples/audio/qwen_omni_inprocess/prompts/en_qwen3_omni_reference_improvement.md b/examples/audio/qwen_omni_inprocess/prompts/en_qwen3_omni_reference_improvement.md new file mode 100644 index 0000000000..7fb70a8744 --- /dev/null +++ b/examples/audio/qwen_omni_inprocess/prompts/en_qwen3_omni_reference_improvement.md @@ -0,0 +1,51 @@ +You receive English audio and a reference transcript. The reference may be cleaned, partially wrong, or missing speech artifacts. The audio is the ground truth. + +REFERENCE TRANSCRIPT: +"{transcript}" + +MAIN GOAL: Listen carefully to the audio and revise the reference so it faithfully reflects exactly what is spoken, including all disfluencies present in the audio. +- Use the reference as a starting point; do not ignore it. +- When the reference matches the audio, keep it unchanged. +- When the reference conflicts with the audio, follow the audio. +- Do NOT invent words or content not spoken in the audio. +- Do NOT remove substantive content that is spoken in the audio (remove reference words only if they are not spoken). +- Do NOT paraphrase, polish grammar or rewrite sentences that already match the audio. +- Prefer minimal edits: fix mismatches and insert missing speech artifacts. +- Preserve named entities from the reference in their exact written form. +- Normalize numbers to their written form. + +ENTITIES (names, places, brands, titles, etc.): +- Keep every named entity from the reference in its exact written form: spelling, casing, script, and punctuation. This includes names, places, brands, titles, acronyms, and other proper nouns. +- Do not ever transliterate, translate, re-spell, normalize, or "correct" an entity into another script. +- If enetities are part code switched data it should stay the same. + +KEEP REFERENCE DISFLUENCIES: +- If the reference already has fillers, repetitions, false starts, colloquial reductions, or grammatical errors, keep them. +- Add hesitation markers and fillers natural to English wherever they are spoken in the audio but missing from the reference. +- Do NOT clean up, normalize, or remove disfluencies that are already in the reference and are spoken in the audio. +- Add consecutive instances of the same word or short phrase when spoken unintentionally. + - Example: reference "I think" → "I I think" if that is what is spoken. + + +BACKGROUND / QUIET / OVERLAPPING SPEECH: +- Keep all audible speech in the reference, including quieter, distant, or overlapping voices — not just the loudest speaker. +- Add background or secondary speech that is audible but missing; do not drop words because they sound like background. + +FALSE STARTS: +- Add incomplete words or phrases the speaker abandons, marked with a hyphen. + - Example: "I was go- going to the store." +- Do NOT remove false starts already in the reference if they are spoken in the audio. + +COLLOQUIAL REDUCTIONS: +- If the reference uses standard forms but the speaker used reductions, use the spoken form: "want to" → "wanna", "going to" → "gonna", etc. +- Preserve forms such as "wanna", "gonna", "kinda", "lemme", "lotta", "outta", "Imma", "sorta", "ya", "m'kay", "finna", "tryna", etc. Do NOT expand them. + +WRONG GRAMMAR: +- Keep grammatical errors as spoken. Do NOT correct subject-verb agreement, tense errors, or other grammar issues. + +NUMERICALS: +- Keep numbers as spoken in words. Do NOT convert them to digits. + - Example: keep "oh eleven" or "zero eleven". + +Output format: +- Return ONLY the revised transcription text. diff --git a/examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_disfluency_asr.md b/examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_disfluency_asr.md new file mode 100644 index 0000000000..61728d231a --- /dev/null +++ b/examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_disfluency_asr.md @@ -0,0 +1 @@ +Transcribe the {language} audio into text exactly as the speaker says it. Write numbers as spoken words. diff --git a/examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_reference_improvement.md b/examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_reference_improvement.md new file mode 100644 index 0000000000..1a12d72be4 --- /dev/null +++ b/examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_reference_improvement.md @@ -0,0 +1,23 @@ +You receive: +1) An audio file, +2) A Ground Truth Transcription of the audio {transcript}. + +Goal: To normalize numbers from the text and add any disfluencies that are present in the audio. + +ALLOWED ONLY: +1) Normalize numeric expressions into words exactly as they are SPOKEN in the audio. +- Mixed format is forbidden: + Bad: "5 percent", "2 zeros" + Good: "five percent", "two zeros" +- Normalize: percentages, currencies, units, ranges, decimals, dates/years — ONLY if they are spoken. +- If a unit (for example “percent”) is NOT spoken, do not add it. +2) Add any disfluencies present in the audio. +- Disfluencies as "um", "uh" that are present in the audio should be added to the text. +- If word is repeated in the audio but missing from ground truth add it to the text. + +ENTITIES (names, places, brands, titles, etc.) should be the same as inGround Truth Transcription: +- Keep every named entity from the reference in its exact written form: spelling, casing, script, and punctuation. This includes names, places, brands, titles, acronyms, and other proper nouns. + +OUTPUT FORMAT: +- Return only the final text. +- No explanations, no JSON, no lists. diff --git a/nemo_curator/backends/base.py b/nemo_curator/backends/base.py index 94236f2cfc..c54b10037f 100644 --- a/nemo_curator/backends/base.py +++ b/nemo_curator/backends/base.py @@ -12,24 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import time import uuid from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any +from loguru import logger + +from nemo_curator.backends.perf_identity import apply_worker_perf_identity, read_worker_metadata_identity from nemo_curator.core.utils import ignore_ray_head_node -from nemo_curator.tasks import Task -from nemo_curator.utils.performance_utils import StageTimer +from nemo_curator.tasks.task_terminals import preserve_dropped_terminal_tasks +from nemo_curator.utils.performance_utils import StagePerfStats, StageTimer if TYPE_CHECKING: from nemo_curator.stages.base import ProcessingStage + from nemo_curator.tasks import Task @dataclass class NodeInfo: - """Generic node information for setup_on_node calls across backends. - Simplified to match Xenna's structure. - """ + """Generic node information for setup_on_node calls across backends.""" node_id: str = "" @@ -37,12 +42,21 @@ class NodeInfo: @dataclass class WorkerMetadata: """Generic worker metadata for setup_on_node calls across backends. - Simplified to match Xenna's structure. The allocation field can contain - backend-specific allocation information. + + Backends stamp ``actor_id``/``node_id``/``gpu_id`` at setup; perf records + copy them verbatim (see ``backends/perf_identity.py``). """ worker_id: str = "" - allocation: Any = None # Backend-specific allocation info + allocation: Any = None # Backend-specific allocation info (Xenna) + actor_id: str = "" + node_id: str = "" + gpu_id: str = "" + physical_address: str = "" + pod_ip: str = "" + hostname: str = "" + gpu_indices: list[int] = field(default_factory=list) + gpu_uuids: list[str] = field(default_factory=list) class BaseExecutor(ABC): @@ -53,94 +67,237 @@ def __init__(self, config: dict[str, Any] | None = None, ignore_head_node: bool self.ignore_head_node = ignore_head_node or ignore_ray_head_node() @abstractmethod - def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | None = None) -> None: + def execute(self, stages: list[ProcessingStage], initial_tasks: list[Task] | None = None) -> None: """Execute the pipeline.""" + def _cleanup_stage_run_resources(self, stages: list[ProcessingStage]) -> None: + """Release run-scoped resources created by pipeline helper stages. + + Some helpers intentionally create named Ray actors so payload handles can + cross backend-visible stage boundaries. Executors own the run lifecycle, + so cleanup belongs here rather than in one row-processing stage. + """ + for stage in reversed(stages): + cleanup = getattr(stage, "cleanup_run_resources", None) + if not callable(cleanup): + continue + try: + cleanup() + except Exception as exc: # noqa: BLE001 + logger.warning(f"Run-scoped cleanup failed for stage {stage}: {exc}") + + def _start_pipeline_hardware_sampler(self) -> list[Any]: + # Observability is opt-in so existing pipelines keep main's actor count, + # timings, and terminal performance-record shape. + if not bool(self.config.get("pipeline_hardware_sampler_enabled", False)): + return [] + try: + from nemo_curator.utils.pipeline_hardware_sampler import start_pipeline_hardware_samplers + + interval_s = float(self.config.get("pipeline_hardware_sampler_interval_s", 0.5)) + startup_timeout_s = float(self.config.get("pipeline_hardware_sampler_startup_timeout_s", 5.0)) + return start_pipeline_hardware_samplers(interval_s=interval_s, startup_timeout_s=startup_timeout_s) + except Exception as exc: # noqa: BLE001 + logger.debug("Pipeline hardware sampler disabled: {}", exc) + return [] + + def _stop_pipeline_hardware_sampler(self, sampler_actors: list[Any]) -> StagePerfStats | None: + if not sampler_actors: + return None + try: + from nemo_curator.utils.pipeline_hardware_sampler import stop_pipeline_hardware_samplers + + stop_timeout_s = float(self.config.get("pipeline_hardware_sampler_stop_timeout_s", 10.0)) + metrics = stop_pipeline_hardware_samplers(sampler_actors, stop_timeout_s=stop_timeout_s) + except Exception as exc: # noqa: BLE001 + logger.debug("Pipeline hardware sampler stop failed: {}", exc) + return None + wall_time_s = float(metrics.pop("pipeline_hardware_wall_time_s", 0.0)) + return StagePerfStats( + stage_name="pipeline_hardware_sampler", + process_time=wall_time_s, + num_items_processed=1, + custom_metrics=metrics, + ) + + @staticmethod + def _attach_pipeline_hardware_perf(tasks: list[Task], perf_stats: StagePerfStats | None) -> None: + if perf_stats is None: + return + for task in tasks: + task.add_stage_perf(perf_stats) + + def _publish_external_perf(self, stages: list[ProcessingStage], perf_stats: StagePerfStats | None) -> None: + """Publish a run-level perf record to the terminal artifact writer when one exists.""" + if perf_stats is None: + return + for stage in reversed(stages): + recorder = getattr(stage, "record_external_stage_perf", None) + if not callable(recorder): + continue + try: + recorder(perf_stats) + except Exception as exc: # noqa: BLE001 + logger.debug("External perf publish failed for stage {}: {}", stage, exc) + return + class BaseStageAdapter: """Adapts ProcessingStage to an execution backend, if needed.""" - def __init__(self, stage: "ProcessingStage"): + def __init__(self, stage: ProcessingStage): self.stage = stage - def process_batch(self, tasks: list[Task]) -> list[Task]: - """Process a batch of tasks. + @staticmethod + def _stage_resource_expectation_metrics(stage: ProcessingStage) -> dict[str, float]: + """Return non-summing resource expectations attached by wrapper stages.""" + metrics: dict[str, float] = {} + for attr_name, metric_name in ( + ("_curator_expected_stage_gpu_count", "expected_stage_gpu_count"), + ("_curator_expected_stage_worker_count", "expected_stage_worker_count"), + ("_curator_expected_worker_gpu_count", "expected_worker_gpu_count"), + ): + value = getattr(stage, attr_name, None) + if isinstance(value, bool) or value is None: + continue + try: + numeric = float(value) + except (TypeError, ValueError): + continue + if numeric > 0: + metrics[metric_name] = numeric + return metrics + + def _cache_perf_identity(self) -> None: + """Copy backend-stamped identity from ``WorkerMetadata`` (fixed per worker).""" + worker_metadata = getattr(self, "_worker_metadata", None) + self._perf_identity = read_worker_metadata_identity(str(self.stage.name), worker_metadata) - Args: - tasks (list[Task]): List of tasks to process - - Returns: - list[Task]: List of processed tasks - """ - # Lazy initialize timer if needed + def process_batch(self, tasks: list[Task]) -> list[Task]: + """Process a batch of tasks, timing and stamping perf stats on outputs.""" if not hasattr(self, "_timer") or self._timer is None: self._timer = StageTimer(self.stage) - # Calculate input data size for timer input_size = sum(task.num_items for task in tasks) - # Initialize performance timer for this batch self._timer.reinit(input_size) - - with self._timer.time_process(input_size): - # Use the batch processing logic - results = self.stage.process_batch(tasks) + tracks_payload_refs = bool(getattr(self.stage, "_curator_tracks_payload_refs", False)) + input_payload_refs = self._collect_payload_refs(tasks) if tracks_payload_refs else {} + extended_metrics = bool(getattr(self.stage, "extended_performance_metrics", False)) + + window_start = time.time() if extended_metrics else 0.0 + try: + with self._timer.time_process(input_size): + results = self.stage.process_batch(tasks) + except Exception: + self._release_payload_refs(input_payload_refs.values()) + raise + window_end = time.time() if extended_metrics else 0.0 + if bool(getattr(self.stage, "_curator_preserves_terminal_tasks", False)): + results = preserve_dropped_terminal_tasks(self.stage, tasks, results) + if input_payload_refs: + self._release_dropped_payload_refs(input_payload_refs, results) # Guarantee every emitted task has a task_id (derived id, or uuid fallback). results = self._post_process_task_ids(tasks, results) - # Log performance stats and add to result tasks + self._attach_stage_perf(results, window_start, window_end, extended_metrics=extended_metrics) + return results + + def _attach_stage_perf( + self, + results: list[Task], + window_start: float, + window_end: float, + *, + extended_metrics: bool, + ) -> None: + """Attach one invocation record, with optional extended diagnostics.""" _, stage_perf_stats = self._timer.log_stats() - # Consume and attach any custom metrics recorded by the stage during this call + # Unique id per invocation: the same record is attached to every output + # task, so downstream accumulators dedup on it (N tasks count once). + if extended_metrics: + stage_perf_stats.invocation_id = uuid.uuid4().hex custom_metrics = self.stage._consume_custom_metrics() if custom_metrics: stage_perf_stats.custom_metrics.update(custom_metrics) + if extended_metrics: + stage_perf_stats.custom_metrics.update(self._stage_resource_expectation_metrics(self.stage)) + # Fold in windowed GPU utilization (no-op for CPU / no NVML). Namespaced + # per physical device UUID (``gpu_util_pct::``) so the summary can + # attribute it to a GPU index and roll the actor up from its devices. + if extended_metrics: + self._add_gpu_sampler_metrics(stage_perf_stats, window_start, window_end) + # Identity is resolved once per worker in setup() and stamped on WorkerMetadata. + if extended_metrics: + if not hasattr(self, "_perf_identity") or self._perf_identity is None: + self._cache_perf_identity() + apply_worker_perf_identity(stage_perf_stats, self._perf_identity) for task in results: task.add_stage_perf(stage_perf_stats) - return results + def _add_gpu_sampler_metrics( + self, stage_perf_stats: StagePerfStats, window_start: float, window_end: float + ) -> None: + """Add optional per-device diagnostics for one invocation window.""" + sampler = getattr(self, "_gpu_sampler", None) + if sampler is not None: + diagnostics = getattr(sampler, "diagnostics", None) + if callable(diagnostics): + stage_perf_stats.custom_metrics.update(diagnostics()) + for uuid_key, metrics in sampler.window_stats(window_start, window_end).items(): + for metric, value in metrics.items(): + stage_perf_stats.custom_metrics[f"{metric}::{uuid_key}"] = value + + def _collect_payload_refs(self, tasks: list[Task]) -> dict[str, object]: + refs: dict[str, object] = {} + if not bool(getattr(self.stage, "_curator_tracks_payload_refs", False)): + return refs + try: + from nemo_curator.pipeline.payload_refs import task_payload_refs + except ImportError: + return refs + for task in tasks: + for payload_ref in task_payload_refs(task): + payload_id = getattr(payload_ref, "payload_id", None) + if payload_id: + refs[str(payload_id)] = payload_ref + return refs + + def _release_dropped_payload_refs(self, input_refs: dict[str, object], output_tasks: list[Task]) -> None: + if not input_refs: + return + try: + from nemo_curator.pipeline.payload_refs import task_payload_refs + except ImportError: + return + output_ids: set[str] = set() + for task in output_tasks: + if task is None: + continue + for payload_ref in task_payload_refs(task): + payload_id = getattr(payload_ref, "payload_id", None) + if payload_id: + output_ids.add(str(payload_id)) + dropped = [payload_ref for payload_id, payload_ref in input_refs.items() if payload_id not in output_ids] + self._release_payload_refs(dropped) + + @staticmethod + def _release_payload_refs(payload_refs: object) -> None: + if not payload_refs: + return + try: + from nemo_curator.pipeline.payload_refs import PayloadRef, release_payload_ref + except ImportError: + return + for payload_ref in payload_refs: + if isinstance(payload_ref, PayloadRef): + release_payload_ref(payload_ref) def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]: - """Assign a deterministic ``task_id`` to every emitted task. - - This is the single place task ids are assigned — it runs for every - stage on every backend (all backend adapters subclass this), so it - makes no difference whether a stage defines ``process`` or overrides - ``process_batch``. ``task_id`` is the task's id path (parents + own segment); ids are - re-derived at each stage boundary so the same object passing through - N stages gets N ids. - - The input→output mapping decides each output's PARENT; whether the - stage is a source decides each output's SEGMENT (content id vs index) - — the two are independent. ``None`` outputs (Curator's "return None to - filter") are NOT removed before the length check — keeping them in - place preserves positional alignment for filter stages — and are then - dropped from the returned list. - - - single input → every output is its child (fan-out): ``parent_`` - - ``len(output) == len(input)`` → positional 1:1: each ``parent_i_``; - a ``None`` slot just means input ``i`` was filtered. - - any other (ambiguous) cardinality across a batch → a random ``uuid`` - prefixed with ``"r"`` (e.g. ``"r3f9a…"``), so ``task_id`` is never - empty even when a derived id is not possible. The ``"r"`` prefix flags - the id as non-deterministic / ancestry-not-tracked (see - ``Task.task_id`` docstring). - - ``seg`` is the output's content id (``Task.get_deterministic_id()``) - for a source stage when available, else the positional index — so a - source partition keeps a stable id across reorderings regardless of - whether the source is 1→N or N→N. - - Note: a stage that BOTH filters and fans out within a single batch - (returning a flat list rather than a per-input slot) cannot be mapped - positionally; if its length happens to equal the input length the 1:1 - assumption may misattribute parents. That combination is unsupported - until per-slot sentinels (NoneTask/FailedTask) land in a later PR. - """ + """Assign a deterministic ``task_id`` to every emitted task.""" is_source = getattr(self.stage, "is_source_stage", False) if len(input_tasks) == 1: - # Fan-out (incl. a source reading from EmptyTask): every non-None - # output is a child of the single input. parent_id = input_tasks[0].task_id out: list[Task] = [t for t in output_tasks if t is not None] for i, task in enumerate(out): @@ -149,8 +306,6 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas return out if len(output_tasks) == len(input_tasks): - # Positional 1:1. None is kept above so a filtered slot still lines - # up with its own parent; drop the None slots from the result. out = [] for parent, task in zip(input_tasks, output_tasks, strict=True): if task is None: @@ -160,33 +315,45 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas out.append(task) return out - # Ambiguous cardinality across a batch: a derived id is not possible. Use a - # random "r"-prefixed uuid so task_id is non-empty but clearly flagged - # non-deterministic. out = [t for t in output_tasks if t is not None] for task in out: task.task_id = "r" + uuid.uuid4().hex return out def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None: - """Setup the stage on a node. - - Args: - node_info (NodeInfo, optional): Information about the node - worker_metadata (WorkerMetadata, optional): Information about the worker - """ - # Call the underlying stage's setup_on_node method - # Some backends may provide node/worker info, others may not + """Setup the stage on a node (node/worker info may be absent on some backends).""" self.stage.setup_on_node(node_info, worker_metadata) def setup(self, worker_metadata: WorkerMetadata | None = None) -> None: - """Setup the stage once per actor. - - Args: - worker_metadata (WorkerMetadata, optional): Information about the worker - """ + """Setup the stage once per actor.""" + self._worker_metadata = worker_metadata + if bool(getattr(self.stage, "extended_performance_metrics", False)): + self._cache_perf_identity() + else: + self._perf_identity = None self.stage.setup(worker_metadata) + self._gpu_sampler = self._maybe_start_gpu_sampler() + + def _maybe_start_gpu_sampler(self) -> object | None: + """Start a background NVML sampler for GPU stages (else ``None``).""" + if not bool(getattr(self.stage, "extended_performance_metrics", False)): + return None + resources = getattr(self.stage, "resources", None) + if resources is None or not getattr(resources, "requires_gpu", False): + return None + try: + from nemo_curator.utils.gpu_sampler import GpuUtilSampler + + gpu_uuids = tuple(getattr(self._perf_identity, "gpu_uuids", ()) or ()) + sampler = GpuUtilSampler(gpu_uuids=gpu_uuids, sample_all_visible=True) + sampler.start() + except Exception: # noqa: BLE001 + return None + return sampler def teardown(self) -> None: """Teardown the stage once per actor.""" + sampler = getattr(self, "_gpu_sampler", None) + if sampler is not None: + sampler.stop() self.stage.teardown() diff --git a/nemo_curator/backends/perf_identity.py b/nemo_curator/backends/perf_identity.py new file mode 100644 index 0000000000..88d1ceeadb --- /dev/null +++ b/nemo_curator/backends/perf_identity.py @@ -0,0 +1,445 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: S110, S112, SIM105 + +"""Backend-specific perf identity labels. + +Each backend resolves ``WorkerPerfIdentity`` once at worker setup from its own +APIs; ``BaseStageAdapter`` copies the values stamped on ``WorkerMetadata``. + +Metrics aggregate by ``actor_id``. ``physical_address`` (``:``) +is the canonical, backend-independent GPU identifier on each GPU actor's block; +the remaining fields are additive cluster-location metadata for debugging. +""" + +from __future__ import annotations + +import os +import socket +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from nemo_curator.backends.base import WorkerMetadata + from nemo_curator.utils.performance_utils import StagePerfStats + + +@dataclass(frozen=True) +class WorkerPerfIdentity: + """Perf identity resolved once per worker at backend setup.""" + + actor_id: str = "" + node_id: str = "" + gpu_id: str = "" + physical_address: str = "" + pod_ip: str = "" + hostname: str = "" + gpu_indices: tuple[int, ...] = () + gpu_uuids: tuple[str, ...] = () + + +def _format_gpu_label(node_label: str, gpu_index: object) -> str: + idx_str = str(gpu_index).strip() + if not idx_str: + return "" + return f"{node_label}:{idx_str}" if node_label else idx_str + + +def _format_actor_label(stage_name: str, worker_or_actor_id: str) -> str: + wid = (worker_or_actor_id or "").strip() + if not wid: + return stage_name + return f"{stage_name}:actor-{wid[:8]}" + + +def _resolve_hostname() -> str: + try: + return (socket.gethostname() or "").strip() + except OSError: + return "" + + +def _resolve_pod_ip() -> str: + for key in ("POD_IP", "STATUS_POD_IP"): + value = (os.environ.get(key) or "").strip() + if value: + return value + return "" + + +def _resolve_host_ip() -> str: + pod_ip = _resolve_pod_ip() + if pod_ip: + return pod_ip + try: + import ray + + return (ray.util.get_node_ip_address() or "").strip() + except Exception: # noqa: BLE001 + return "" + + +def _allocation_gpu_indices(allocation: object | None, requires_gpu: bool) -> tuple[int, ...]: + if not requires_gpu or allocation is None: + return () + gpus = getattr(allocation, "gpus", None) or [] + indices: list[int] = [] + for gpu in gpus: + idx = getattr(gpu, "index", None) + if idx is not None: + indices.append(int(idx)) + return tuple(indices) + + +def _visible_gpu_ordinals(gpu_indices: tuple[int, ...], visible_count: int) -> list[int]: + """Translate physical CUDA indices to torch's *visible* ordinals. + + ``gpu_indices`` are physical ids but torch enumerates only the devices in + ``CUDA_VISIBLE_DEVICES`` as ordinals ``0..visible_count-1``. Map via the env + when it lists integer ids; under per-worker isolation every visible ordinal + belongs to this worker. + """ + env = os.environ.get("CUDA_VISIBLE_DEVICES") + if not env: + # No mask -> all GPUs visible -> physical index == torch ordinal. + return [i for i in gpu_indices if 0 <= i < visible_count] + phys_to_ordinal: dict[int, int] = {} + for ordinal, token in enumerate(t.strip() for t in env.split(",") if t.strip()): + try: + phys_to_ordinal[int(token)] = ordinal + except ValueError: + phys_to_ordinal = {} # UUID-style mask -> no positional int mapping + break + mapped = [phys_to_ordinal[i] for i in gpu_indices if i in phys_to_ordinal] + if mapped: + return mapped + # Isolated worker (or unmappable ids): the visible set *is* this worker's. + return list(range(visible_count)) + + +def _collect_gpu_uuids(gpu_indices: tuple[int, ...]) -> tuple[str, ...]: + if not gpu_indices: + return () + try: + import torch + + if not torch.cuda.is_available(): + return () + visible_count = torch.cuda.device_count() + except Exception: # noqa: BLE001 + return () + uuids: list[str] = [] + # Per-ordinal guard: one bad index must not wipe the rest. + for ordinal in _visible_gpu_ordinals(gpu_indices, visible_count): + try: + uuid = str(getattr(torch.cuda.get_device_properties(ordinal), "uuid", "") or "").strip() + except Exception: # noqa: BLE001 + continue + if uuid: + uuids.append(uuid) + return tuple(uuids) + + +def _format_physical_address(host_token: str, gpu_indices: tuple[int, ...]) -> str: + """Canonical physical GPU address: ``:``. + + ``host_token`` degrades to ``node`` so a GPU worker always gets a non-empty, + backend-independent identifier. Returns ``""`` only when it holds no GPUs. + """ + if not gpu_indices: + return "" + host = (host_token or "").strip() or "node" + idx_part = ",".join(str(idx) for idx in gpu_indices) + return f"{host}:{idx_part}" + + +def build_xenna_perf_identity( + stage_name: str, + *, + worker_id: str, + node_id: str, + allocation: object | None, + requires_gpu: bool, +) -> WorkerPerfIdentity: + """Identity from Xenna ``WorkerMetadata`` + ``NodeInfo`` (allocation-first GPU). + + GPU index comes only from ``allocation.gpus[0].index``; node label falls back + ``node_id`` -> MPI rank env -> ``allocation.node``. + """ + node_label = (node_id or "").strip() + if not node_label: + rank = os.environ.get("OMPI_COMM_WORLD_RANK") + if rank not in (None, ""): + node_label = f"node-{rank}" + elif allocation is not None: + node_label = str(getattr(allocation, "node", "") or "").strip() + + actor_label = _format_actor_label(stage_name, worker_id) + + gpu_label = "" + gpu_indices: tuple[int, ...] = () + if requires_gpu and allocation is not None: + gpu_indices = _allocation_gpu_indices(allocation, requires_gpu=True) + if gpu_indices: + gpu_label = _format_gpu_label(node_label, gpu_indices[0]) + + hostname = _resolve_hostname() + physical_address = _format_physical_address(_resolve_host_ip() or hostname or node_label, gpu_indices) + gpu_uuids = _collect_gpu_uuids(gpu_indices) if gpu_indices else () + + return WorkerPerfIdentity( + actor_id=actor_label, + node_id=node_label, + gpu_id=gpu_label, + physical_address=physical_address, + pod_ip=_resolve_pod_ip(), + hostname=hostname, + gpu_indices=gpu_indices, + gpu_uuids=gpu_uuids, + ) + + +def _ray_node_label(ctx: object) -> str: + try: + node_hex = getattr(ctx, "get_node_id", lambda: "")() + if node_hex: + return f"node-{str(node_hex)[:8]}" + except Exception: # noqa: BLE001 + return "" + return "" + + +def _ray_worker_short_id(ctx: object) -> str: + short_id = "" + try: + short_id = (getattr(ctx, "get_actor_id", lambda: "")() or "") if hasattr(ctx, "get_actor_id") else "" + except Exception: # noqa: BLE001 + short_id = "" + if short_id: + return short_id + try: + return getattr(ctx, "get_worker_id", lambda: "")() or "" + except Exception: # noqa: BLE001 + return "" + + +def _gpu_assignment_tokens(values: object) -> tuple[str, ...]: + """Return non-empty GPU assignment tokens from Ray/env values.""" + if values is None: + return () + if isinstance(values, str): + iterable = values.split(",") + else: + try: + iterable = list(values) # type: ignore[arg-type] + except TypeError: + iterable = [values] + return tuple(token for token in (str(value).strip() for value in iterable) if token) + + +def _parse_int_indices(values: object) -> tuple[int, ...]: + """Best-effort int-index parse; silently drops non-integer (e.g. UUID) ids.""" + out: list[int] = [] + for value in _gpu_assignment_tokens(values): + try: + out.append(int(str(value).strip())) + except (TypeError, ValueError): + continue # UUID-style assignment -> no positional index + return tuple(out) + + +def _normalize_gpu_uuid(value: object) -> str: + try: + text = value.decode() if isinstance(value, bytes) else str(value) + except Exception: # noqa: BLE001 + text = "" + text = text.strip().lower() + return text.removeprefix("gpu-") + + +def _uuid_gpu_assignment(tokens: tuple[str, ...]) -> tuple[tuple[int, ...], tuple[str, ...]]: + """Map Ray/CUDA UUID assignment tokens back to physical GPU indices with NVML.""" + wanted = tuple((token, _normalize_gpu_uuid(token)) for token in tokens if _normalize_gpu_uuid(token)) + if not wanted: + return (), () + try: + import pynvml + + pynvml.nvmlInit() + except Exception: # noqa: BLE001 + return (), () + + matches: dict[str, int] = {} + try: + for index in range(int(pynvml.nvmlDeviceGetCount())): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(index) + uuid = _normalize_gpu_uuid(pynvml.nvmlDeviceGetUUID(handle)) + except Exception: # noqa: BLE001 + continue + if uuid: + matches[uuid] = index + finally: + try: + pynvml.nvmlShutdown() + except Exception: # noqa: BLE001 + pass + + indices: list[int] = [] + uuids: list[str] = [] + for token, normalized in wanted: + if normalized in matches: + indices.append(matches[normalized]) + uuids.append(token) + return tuple(indices), tuple(uuids) + + +def _gpu_assignment_from_tokens(tokens: tuple[str, ...]) -> tuple[tuple[int, ...], tuple[str, ...]]: + indices = _parse_int_indices(tokens) + if indices: + return indices, _collect_gpu_uuids(indices) + return _uuid_gpu_assignment(tokens) + + +def _ray_gpu_assignment(requires_gpu: bool) -> tuple[tuple[int, ...], tuple[str, ...]]: + if not requires_gpu: + return (), () + try: + import ray + + tokens = _gpu_assignment_tokens(ray.get_gpu_ids()) + except Exception: # noqa: BLE001 + tokens = () + + indices, uuids = _gpu_assignment_from_tokens(tokens) + if indices: + return indices, uuids + + # Ray may leave CUDA_VISIBLE_DEVICES set to this worker's assigned slice + # when get_gpu_ids() is empty. Support both integer and UUID masks. + env_tokens = _gpu_assignment_tokens(os.environ.get("CUDA_VISIBLE_DEVICES")) + return _gpu_assignment_from_tokens(env_tokens) + + +def _ray_gpu_indices(requires_gpu: bool) -> tuple[int, ...]: + return _ray_gpu_assignment(requires_gpu)[0] + + +def _ray_gpu_label(node_label: str, requires_gpu: bool) -> str: + gpu_indices = _ray_gpu_indices(requires_gpu) + if gpu_indices: + return _format_gpu_label(node_label, gpu_indices[0]) + return "" + + +def build_ray_perf_identity( + stage_name: str, + *, + requires_gpu: bool, +) -> WorkerPerfIdentity: + """Identity from Ray runtime context (Ray Data / Ray Actor Pool). + + GPU assignment comes from ``ray.get_gpu_ids()``, falling back to + ``CUDA_VISIBLE_DEVICES`` when Ray returns no ids. Supports both integer and + UUID assignments so Ray Data actors still start GPU utilization sampling. + """ + blank = WorkerPerfIdentity() + + try: + import ray + + if hasattr(ray, "is_initialized") and not ray.is_initialized(): + return blank + ctx = ray.get_runtime_context() + except Exception: # noqa: BLE001 + return blank + + worker_id = _ray_worker_short_id(ctx) + node_label = _ray_node_label(ctx) + if not (worker_id or node_label): + return blank + + actor_label = _format_actor_label(stage_name, worker_id) + gpu_indices, gpu_uuids = _ray_gpu_assignment(requires_gpu) + gpu_label = _format_gpu_label(node_label, gpu_indices[0]) if gpu_indices else "" + host_ip = "" + try: + import ray + + host_ip = (ray.util.get_node_ip_address() or "").strip() + except Exception: # noqa: BLE001 + host_ip = "" + hostname = _resolve_hostname() + physical_address = _format_physical_address(host_ip or hostname or node_label, gpu_indices) + if gpu_indices and not gpu_uuids: + gpu_uuids = _collect_gpu_uuids(gpu_indices) + + return WorkerPerfIdentity( + actor_id=actor_label, + node_id=node_label, + gpu_id=gpu_label, + physical_address=physical_address, + pod_ip=_resolve_pod_ip(), + hostname=hostname, + gpu_indices=gpu_indices, + gpu_uuids=gpu_uuids, + ) + + +def read_worker_metadata_identity( + stage_name: str, + worker_metadata: WorkerMetadata | None, +) -> WorkerPerfIdentity: + """Return perf labels previously stamped on ``WorkerMetadata`` by the backend.""" + if worker_metadata is None: + return WorkerPerfIdentity() + actor_id = (worker_metadata.actor_id or "").strip() + node_id = (worker_metadata.node_id or "").strip() + gpu_id = (worker_metadata.gpu_id or "").strip() + if not (actor_id or node_id or gpu_id): + return WorkerPerfIdentity() + return WorkerPerfIdentity( + actor_id=actor_id or stage_name, + node_id=node_id, + gpu_id=gpu_id, + physical_address=(worker_metadata.physical_address or "").strip(), + pod_ip=(worker_metadata.pod_ip or "").strip(), + hostname=(worker_metadata.hostname or "").strip(), + gpu_indices=tuple(worker_metadata.gpu_indices or ()), + gpu_uuids=tuple(worker_metadata.gpu_uuids or ()), + ) + + +def stamp_worker_metadata(worker_metadata: WorkerMetadata, identity: WorkerPerfIdentity) -> None: + """Copy a resolved identity onto generic ``WorkerMetadata``.""" + worker_metadata.actor_id = identity.actor_id + worker_metadata.node_id = identity.node_id + worker_metadata.gpu_id = identity.gpu_id + worker_metadata.physical_address = identity.physical_address + worker_metadata.pod_ip = identity.pod_ip + worker_metadata.hostname = identity.hostname + worker_metadata.gpu_indices = list(identity.gpu_indices) + worker_metadata.gpu_uuids = list(identity.gpu_uuids) + + +def apply_worker_perf_identity(stage_perf_stats: StagePerfStats, identity: WorkerPerfIdentity) -> None: + """Copy resolved worker identity onto a ``StagePerfStats`` record.""" + stage_perf_stats.actor_id = identity.actor_id + stage_perf_stats.node_id = identity.node_id + stage_perf_stats.gpu_id = identity.gpu_id + stage_perf_stats.physical_address = identity.physical_address + stage_perf_stats.pod_ip = identity.pod_ip + stage_perf_stats.hostname = identity.hostname + stage_perf_stats.gpu_indices = list(identity.gpu_indices) + stage_perf_stats.gpu_uuids = list(identity.gpu_uuids) diff --git a/nemo_curator/backends/ray_data/adapter.py b/nemo_curator/backends/ray_data/adapter.py index d4e5ec64d9..6ac52f8e1e 100644 --- a/nemo_curator/backends/ray_data/adapter.py +++ b/nemo_curator/backends/ray_data/adapter.py @@ -20,10 +20,19 @@ from ray.data import Dataset, TaskPoolStrategy from nemo_curator.backends.base import BaseStageAdapter -from nemo_curator.backends.utils import RayStageSpecKeys, get_worker_metadata_and_node_id +from nemo_curator.backends.utils import ( + RayStageSpecKeys, + get_worker_metadata_and_node_id, + get_worker_metadata_and_node_id_with_perf, +) from nemo_curator.stages.base import ProcessingStage -from .utils import get_actor_compute_strategy_for_stage, get_configured_actor_pool_sizing_keys, is_actor_stage +from .utils import ( + coerce_batch_tasks, + get_actor_compute_strategy_for_stage, + get_configured_actor_pool_sizing_keys, + is_actor_stage, +) CURATOR_MANAGED_MAP_BATCHES_KWARGS = {"compute", "max_calls", "num_cpus", "num_gpus"} @@ -67,7 +76,7 @@ def _process_batch_internal(self, batch: dict[str, Any]) -> dict[str, Any]: Returns: Dictionary with arrays/lists representing processed Task objects """ - tasks = batch["item"] + tasks = coerce_batch_tasks(batch["item"]) results = self.process_batch(tasks) # Return the results as Ray Data expects them # For Task objects, we return them in the 'item' column @@ -99,15 +108,34 @@ def process_dataset(self, dataset: Dataset) -> Dataset: Returns: Dataset: Processed Ray Data dataset """ - ray_stage_spec = self.stage.ray_stage_spec() - stage_is_actor = ray_stage_spec.get(RayStageSpecKeys.IS_ACTOR_STAGE, is_actor_stage(self.stage)) + is_actor_stage_ = self.stage.ray_stage_spec().get(RayStageSpecKeys.IS_ACTOR_STAGE, is_actor_stage(self.stage)) + + map_batches_fn, concurrency_kwargs = self._map_batches_fn_and_kwargs( + is_actor_stage=is_actor_stage_, + ) + + # Calculate concurrency based on available resources + logger.info(f"{self.stage.__class__.__name__} {is_actor_stage_=} with {concurrency_kwargs=}") + + processed_dataset = dataset.map_batches(map_batches_fn, batch_size=self.batch_size, **concurrency_kwargs) # type: ignore[reportArgumentType] + + if self.stage.ray_stage_spec().get(RayStageSpecKeys.IS_FANOUT_STAGE, False): + processed_dataset = processed_dataset.repartition(target_num_rows_per_block=1) + + return processed_dataset - if stage_is_actor: + def _map_batches_fn_and_kwargs( + self, + *, + is_actor_stage: bool, + ) -> tuple[Any, dict[str, Any]]: + ray_stage_spec = self.stage.ray_stage_spec() + if is_actor_stage: map_batches_fn = create_actor_from_stage(self.stage) - map_batches_kwargs = {"compute": get_actor_compute_strategy_for_stage(self.stage)} + concurrency_kwargs = {"compute": get_actor_compute_strategy_for_stage(self.stage)} else: map_batches_fn = create_task_from_stage(self.stage) - map_batches_kwargs = {} + concurrency_kwargs = {} actor_pool_sizing_keys = get_configured_actor_pool_sizing_keys(ray_stage_spec) if actor_pool_sizing_keys: @@ -118,18 +146,15 @@ def process_dataset(self, dataset: Dataset) -> Dataset: num_workers = self.stage.num_workers() if num_workers is not None and num_workers > 0: - map_batches_kwargs["compute"] = TaskPoolStrategy(size=num_workers) + concurrency_kwargs["compute"] = TaskPoolStrategy(size=num_workers) - max_calls = ray_stage_spec.get(RayStageSpecKeys.MAX_CALLS_PER_WORKER) + max_calls = ray_stage_spec.get(RayStageSpecKeys.MAX_CALLS_PER_WORKER, None) if max_calls is not None: - map_batches_kwargs["max_calls"] = max_calls + concurrency_kwargs["max_calls"] = max_calls - map_batches_kwargs.update(self._build_resource_kwargs(ray_stage_spec)) + concurrency_kwargs.update(self._build_resource_kwargs(ray_stage_spec)) - # Per-stage ray_remote_args (e.g. runtime_env with different pip versions per stage). ray_remote_args = copy.deepcopy(ray_stage_spec.get(RayStageSpecKeys.RAY_REMOTE_ARGS) or {}) - # If the stage declares runtime_env, forward it directly to Ray so Ray creates and - # caches an isolated virtualenv for this stage's workers. if self.stage.runtime_env: ray_remote_args["runtime_env"] = self.stage.runtime_env @@ -141,20 +166,13 @@ def process_dataset(self, dataset: Dataset) -> Dataset: ) raise ValueError(msg) - map_batches_kwargs.update(ray_remote_args) - - # Let Ray Data apply the selected compute strategy and resource requirements. - logger.info(f"{self.stage.__class__.__name__} stage_is_actor={stage_is_actor} with {map_batches_kwargs=}") - - processed_dataset = dataset.map_batches(map_batches_fn, batch_size=self.batch_size, **map_batches_kwargs) # type: ignore[reportArgumentType] - - if ray_stage_spec.get(RayStageSpecKeys.IS_FANOUT_STAGE, False): - processed_dataset = processed_dataset.repartition(target_num_rows_per_block=1) - - return processed_dataset + concurrency_kwargs.update(ray_remote_args) + return map_batches_fn, concurrency_kwargs -def create_actor_from_stage(stage: ProcessingStage) -> type[RayDataStageAdapter]: +def create_actor_from_stage( + stage: ProcessingStage, +) -> type[RayDataStageAdapter]: """Create a StageProcessor class with the proper stage name for display.""" class RayDataStageActorAdapter(RayDataStageAdapter): @@ -164,7 +182,13 @@ def __init__(self): """Initialize the stage processor.""" super().__init__(stage) self.setup_done = False - node_info, worker_metadata = get_worker_metadata_and_node_id() + requires_gpu = bool(getattr(getattr(stage, "resources", None), "requires_gpu", False)) + if bool(getattr(stage, "extended_performance_metrics", False)): + node_info, worker_metadata = get_worker_metadata_and_node_id_with_perf( + str(stage.name), requires_gpu=requires_gpu + ) + else: + node_info, worker_metadata = get_worker_metadata_and_node_id() self.setup_on_node(node_info, worker_metadata) self.setup(worker_metadata) @@ -179,7 +203,9 @@ def __call__(self, batch: dict[str, Any]) -> dict[str, Any]: return RayDataStageActorAdapter -def create_task_from_stage(stage: ProcessingStage) -> Callable[[dict[str, Any]], dict[str, Any]]: +def create_task_from_stage( + stage: ProcessingStage, +) -> Callable[[dict[str, Any]], dict[str, Any]]: """Create a named Ray Data stage adapter function. This creates a standalone function that wraps the stage processing logic diff --git a/nemo_curator/backends/ray_data/executor.py b/nemo_curator/backends/ray_data/executor.py index 85b06fa861..d7657dd3cf 100644 --- a/nemo_curator/backends/ray_data/executor.py +++ b/nemo_curator/backends/ray_data/executor.py @@ -18,7 +18,9 @@ from loguru import logger from ray.data import DataContext, Dataset -from nemo_curator.backends.base import BaseExecutor +from nemo_curator.backends.base import ( + BaseExecutor, +) from nemo_curator.backends.utils import execute_setup_on_node, register_loguru_serializer from nemo_curator.tasks import EmptyTask, Task @@ -32,22 +34,13 @@ class RayDataExecutor(BaseExecutor): """Ray Data-based executor for pipeline execution. This executor: - 1. Executes setup on Ray nodes for all stages + 1. Executes setup on all nodes for all stages 2. Converts initial tasks to Ray Data dataset 3. Applies each stage as a Ray Data transformation (as a task or actor in map_batches) 4. Returns final results as a list of tasks """ def __init__(self, config: dict[str, Any] | None = None, ignore_head_node: bool = False): - """Initialize the executor. - - Args: - config (dict[str, Any], optional): Configuration dictionary. - ignore_head_node (bool, optional): Whether to skip the Ray head node for - ``setup_on_node``. Ray Data controls ``map_batches`` task/actor placement - through Ray's scheduler; this flag does not cap actor-pool size or force - Ray Data workers away from the head node. - """ super().__init__(config, ignore_head_node) def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | None = None) -> list[Task]: @@ -69,6 +62,8 @@ def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | N # Initialize with initial tasks if provided, otherwise start with EmptyTask tasks: list[Task] = initial_tasks or [EmptyTask()] output_tasks: list[Task] = [] + hardware_sampler: list[Any] = [] + hardware_perf = None # When runtime_env with pip is used, Ray's pip plugin sets up per-stage virtualenvs # lazily on first task dispatch by cloning the current virtualenv. The NeMo Curator # container's /opt/venv is created with `uv venv --seed` so pip is available in clones. @@ -78,6 +73,7 @@ def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | N ray.init( ignore_reinit_error=True, runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": ""}} ) + hardware_sampler = self._start_pipeline_hardware_sampler() # Convert tasks to dataset current_dataset = self._tasks_to_dataset(tasks) @@ -91,12 +87,7 @@ def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | N # TODO: add pipeline level config for verbosity logger.info(f"Processing stage {i + 1}/{len(stages)}: {stage}") logger.info(f" CPU cores: {stage.resources.cpus}, GPU ratio: {stage.resources.gpus}") - - # Create adapter for this stage - adapter = RayDataStageAdapter(stage) - - # Apply stage transformation - current_dataset = adapter.process_dataset(current_dataset) + current_dataset = self._process_stage_dataset(stage, current_dataset) except Exception as e: logger.error(f"Error during pipeline execution: {e}") raise @@ -104,12 +95,26 @@ def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | N # Convert final dataset back to tasks # TODO: add pipeline configuration to check if user wants to return last stages output to driver output_tasks = self._dataset_to_tasks(current_dataset) + hardware_perf = self._stop_pipeline_hardware_sampler(hardware_sampler) + hardware_sampler = [] + self._attach_pipeline_hardware_perf(output_tasks, hardware_perf) + self._publish_external_perf(stages, hardware_perf) logger.info(f"Pipeline completed. Final results: {len(output_tasks)} tasks") finally: # This ensures we unset all the env vars set above during initialize and kill the pending actors. - ray.shutdown() + try: + if hardware_sampler: + self._stop_pipeline_hardware_sampler(hardware_sampler) + self._cleanup_stage_run_resources(stages) + finally: + ray.shutdown() return output_tasks + def _process_stage_dataset(self, stage: "ProcessingStage", dataset: Dataset) -> Dataset: + """Process one stage as a Ray Data transform.""" + adapter = RayDataStageAdapter(stage) + return adapter.process_dataset(dataset) + def _tasks_to_dataset(self, tasks: list[Task]) -> Dataset: """Convert list of tasks to Ray Data dataset. diff --git a/nemo_curator/backends/ray_data/utils.py b/nemo_curator/backends/ray_data/utils.py index d3b631209e..2b3df48ff5 100644 --- a/nemo_curator/backends/ray_data/utils.py +++ b/nemo_curator/backends/ray_data/utils.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ruff: noqa: ANN401 from collections.abc import Mapping +from typing import Any from loguru import logger from ray.data import ActorPoolStrategy @@ -27,6 +29,22 @@ ) +def coerce_batch_tasks(batch_items: Any) -> list[Any]: + """Normalize a Ray Data ``map_batches`` column to a Python ``list``. + + Ray Data delivers batches as column arrays; coercing to ``list[Task]`` keeps + stage code backend-agnostic. + """ + if batch_items is None: + return [] + if isinstance(batch_items, list): + return batch_items + try: + return list(batch_items) + except TypeError: + return [batch_items] + + def get_configured_actor_pool_sizing_keys(ray_stage_spec: Mapping[str, object]) -> list[str]: """Return actor-pool sizing keys configured in a ray stage spec.""" stage_spec_keys = {key.value if isinstance(key, RayStageSpecKeys) else key for key in ray_stage_spec} @@ -34,12 +52,7 @@ def get_configured_actor_pool_sizing_keys(ray_stage_spec: Mapping[str, object]) def get_actor_compute_strategy_for_stage(stage: ProcessingStage) -> ActorPoolStrategy: - """Get the Ray Data actor-pool compute strategy for a processing stage. - - Explicit stage ``num_workers`` requests a fixed-size actor pool. Otherwise, - actor stages use Ray Data's autoscaling pool and can optionally override - min/max/initial workers through ``ray_stage_spec``. - """ + """Get the Ray Data actor-pool compute strategy for a processing stage.""" num_workers = stage.num_workers() if num_workers is not None and num_workers > 0: actor_pool_sizing_keys = get_configured_actor_pool_sizing_keys(stage.ray_stage_spec()) diff --git a/nemo_curator/backends/utils.py b/nemo_curator/backends/utils.py index 6231c52da0..2e91f39041 100644 --- a/nemo_curator/backends/utils.py +++ b/nemo_curator/backends/utils.py @@ -22,6 +22,7 @@ from loguru import logger from nemo_curator.backends.base import NodeInfo, WorkerMetadata +from nemo_curator.backends.perf_identity import build_ray_perf_identity, stamp_worker_metadata from nemo_curator.stages.base import ProcessingStage from nemo_curator.utils.ray_utils import get_head_node_id, submit_on_each_node @@ -38,13 +39,11 @@ def _logger_custom_serializer( def _logger_custom_deserializer( _: None, ) -> "loguru.Logger": - # Initialize a default logger return logger def register_loguru_serializer() -> None: - """Initialize a new local Ray cluster or connects to an existing one.""" - # Turn off serization for loguru. This is needed as loguru is not serializable in general. + """Register a no-op (de)serializer for loguru (not serializable in general).""" ray.util.register_serializer( logger.__class__, serializer=_logger_custom_serializer, @@ -53,29 +52,15 @@ def register_loguru_serializer() -> None: def merge_executor_configs(base_config: dict | None, override_config: dict | None) -> dict: - """ - Recursively merge two executor configs with deep merging of nested dicts. + """Recursively deep-merge two executor configs (override wins, inputs untouched). Args: base_config: Base configuration dictionary - override_config: Configuration to merge on top of base_config + override_config: Configuration merged on top of base_config Returns: - Merged configuration dictionary with all nested dicts recursively merged - - Notes: - - Recursively merges all nested dictionaries - - Non-dict values in override_config will overwrite base_config - - Handles None values gracefully - - Does not modify original inputs (uses deep copy) - - Examples: - >>> base = {"runtime_env": {"env_vars": {"A": "1", "B": "2"}}} - >>> override = {"runtime_env": {"env_vars": {"B": "3", "C": "4"}}} - >>> merge_executor_configs(base, override) - {"runtime_env": {"env_vars": {"A": "1", "B": "3", "C": "4"}}} + Merged config with nested dicts merged recursively """ - # Handle None cases if base_config is None and override_config is None: return {} if base_config is None: @@ -83,20 +68,15 @@ def merge_executor_configs(base_config: dict | None, override_config: dict | Non if override_config is None: return deepcopy(base_config) - # Deep copy to avoid modifying originals merged_config = deepcopy(base_config) - # Recursively merge each key from override_config for key, value in override_config.items(): if isinstance(value, dict): if key not in merged_config or not isinstance(merged_config[key], dict): - # If key doesn't exist or isn't a dict, just use the override value merged_config[key] = deepcopy(value) else: - # Recursively merge nested dicts merged_config[key] = merge_executor_configs(merged_config[key], value) else: - # For non-dict values, overwrite merged_config[key] = value return merged_config @@ -143,16 +123,27 @@ def get_worker_metadata_and_node_id() -> tuple[NodeInfo, WorkerMetadata]: return NodeInfo(node_id=ray_context.get_node_id()), WorkerMetadata(worker_id=ray_context.get_worker_id()) +def get_worker_metadata_and_node_id_with_perf( + stage_name: str, + *, + requires_gpu: bool = False, +) -> tuple[NodeInfo, WorkerMetadata]: + """Get worker metadata with opt-in Ray-resolved performance identity.""" + node_info, worker_metadata = get_worker_metadata_and_node_id() + identity = build_ray_perf_identity(stage_name, requires_gpu=requires_gpu) + stamp_worker_metadata(worker_metadata, identity) + return node_info, worker_metadata + + def get_available_cpu_gpu_resources( init_and_shutdown: bool = False, ignore_head_node: bool = False ) -> tuple[int, int]: """Get available CPU and GPU resources from Ray.""" if init_and_shutdown: ray.init(ignore_reinit_error=True) - time.sleep(0.2) # ray.available_resources() returns might have a lag - # available resources can be different from total resources, however curator assumes - # entire cluster is available for use and only one pipeline is being run at a time. - # therefore available resources should match total resources. + time.sleep(0.2) # ray.available_resources() can lag + # Curator assumes the whole cluster is free (one pipeline at a time), so + # available resources should match total resources. available_resources = ray.available_resources() available_cpus = available_resources.get("CPU", 0) available_gpus = available_resources.get("GPU", 0) @@ -175,12 +166,10 @@ def get_available_cpu_gpu_resources( def check_total_gpu_capacity(gpus_needed: int, *, ignore_head_node: bool = False) -> None: - """Raise if the cluster doesn't have enough GPUs to satisfy aggregate demand. + """Raise if the cluster lacks enough GPUs for aggregate demand. - Intended as a coarse pre-check before submitting placement groups: Ray's - PG scheduler can hang indefinitely on ``pg.ready()`` when demand exceeds - capacity, so a fast, explicit error with the actual numbers is friendlier - than waiting on a timeout. + Coarse pre-check: Ray's placement-group scheduler can hang on ``pg.ready()`` + when demand exceeds capacity, so fail fast with the actual numbers. """ _, available_gpus = get_available_cpu_gpu_resources(ignore_head_node=ignore_head_node) available = int(available_gpus) @@ -191,13 +180,11 @@ def check_total_gpu_capacity(gpus_needed: int, *, ignore_head_node: bool = False @ray.remote def _setup_stage_on_node(stage: ProcessingStage) -> None: - """Ray remote function to execute setup_on_node for a stage. + """Run ``setup_on_node`` for a stage as a Ray task. - This runs as a Ray remote task (not an actor). - vLLM's auto-detection only forces the spawn multiprocessing method inside Ray actors, - not in Ray tasks. Without this override, vLLM defaults to fork in tasks and hits - RuntimeError: Cannot re-initialize CUDA in forked subprocess. - We explicitly set the environment variable to spawn to prevent this. + Force vLLM's spawn method: it auto-sets spawn only inside Ray actors, not + tasks, so without this fork would hit "Cannot re-initialize CUDA in forked + subprocess". """ os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") node_id = ray.get_runtime_context().get_node_id() @@ -207,10 +194,9 @@ def _setup_stage_on_node(stage: ProcessingStage) -> None: def execute_setup_on_node(stages: list[ProcessingStage], ignore_head_node: bool = False) -> None: """Execute ``setup_on_node`` for every stage on every alive Ray node. - All ``(stage, node)`` setup tasks are submitted up front and awaited with a single - ``ray.get``, so total wall-clock time is bounded by the slowest stage rather than - the sum of per-stage times — important when setup is heavy (model downloads, weight - loads) and stages don't contend for the same resources. + All ``(stage, node)`` tasks are submitted up front and awaited with one + ``ray.get``, so wall-clock time is bounded by the slowest stage (matters when + setup is heavy: model downloads, weight loads). """ head_node_id = get_head_node_id() if ignore_head_node else None for node in ray.nodes(): @@ -223,13 +209,14 @@ def execute_setup_on_node(stages: list[ProcessingStage], ignore_head_node: bool refs: list = [] for stage in stages: + setup_resources = stage.setup_on_node_resources() refs.extend( submit_on_each_node( _setup_stage_on_node, stage, ignore_head_node=ignore_head_node, - num_cpus=stage.resources.cpus if stage.resources is not None else 1, - num_gpus=stage.resources.gpus if stage.resources is not None else 0, + num_cpus=setup_resources.cpus, + num_gpus=setup_resources.gpus, ) ) ray.get(refs) diff --git a/nemo_curator/backends/xenna/adapter.py b/nemo_curator/backends/xenna/adapter.py index 0d4e17d3d3..ab9cce5a77 100644 --- a/nemo_curator/backends/xenna/adapter.py +++ b/nemo_curator/backends/xenna/adapter.py @@ -19,9 +19,13 @@ from cosmos_xenna.pipelines.private.resources import NodeInfo as XennaNodeInfo from cosmos_xenna.pipelines.private.resources import Resources as XennaResources from cosmos_xenna.pipelines.private.resources import WorkerMetadata as XennaWorkerMetadata -from loguru import logger -from nemo_curator.backends.base import BaseStageAdapter, NodeInfo, WorkerMetadata +from nemo_curator.backends.base import ( + BaseStageAdapter, + NodeInfo, + WorkerMetadata, +) +from nemo_curator.backends.perf_identity import build_xenna_perf_identity, stamp_worker_metadata from nemo_curator.stages.base import ProcessingStage from nemo_curator.tasks import Task @@ -68,7 +72,6 @@ def __init__( @property def required_resources(self) -> XennaResources: """Get the resources required for this stage.""" - logger.info(f"Resources: {self.processing_stage.resources}") return XennaResources( cpus=self.processing_stage.resources.cpus, gpus=self.processing_stage.resources.gpus, @@ -78,7 +81,7 @@ def required_resources(self) -> XennaResources: def stage_batch_size(self) -> int: """Get the batch size for this stage.""" batch_size = self.processing_stage.batch_size - return batch_size if batch_size is not None else 1 + return 1 if batch_size is None else int(batch_size) @property def env_info(self) -> pipelines_v1.RuntimeEnv | None: @@ -98,7 +101,6 @@ def process_data(self, tasks: list[Task]) -> list[Task] | None: Returns: List of processed tasks or None """ - # Use the base stage's monitoring capability return self.process_batch(tasks) def setup_on_node(self, node_info: XennaNodeInfo, worker_metadata: XennaWorkerMetadata) -> None: @@ -111,10 +113,20 @@ def setup_on_node(self, node_info: XennaNodeInfo, worker_metadata: XennaWorkerMe """ # Convert Xenna's types to our generic types (simplified) generic_node_info = NodeInfo(node_id=node_info.node_id) + requires_gpu = bool(getattr(getattr(self.processing_stage, "resources", None), "requires_gpu", False)) generic_worker_metadata = WorkerMetadata( worker_id=worker_metadata.worker_id, - allocation=worker_metadata.allocation, # Keep the original allocation object + allocation=worker_metadata.allocation, ) + if bool(getattr(self.processing_stage, "extended_performance_metrics", False)): + identity = build_xenna_perf_identity( + str(self.processing_stage.name), + worker_id=worker_metadata.worker_id, + node_id=node_info.node_id, + allocation=worker_metadata.allocation, + requires_gpu=requires_gpu, + ) + stamp_worker_metadata(generic_worker_metadata, identity) super().setup_on_node(generic_node_info, generic_worker_metadata) def setup(self, worker_metadata: XennaWorkerMetadata) -> None: @@ -125,10 +137,20 @@ def setup(self, worker_metadata: XennaWorkerMetadata) -> None: worker_metadata: Xenna's WorkerMetadata object """ # Convert Xenna's WorkerMetadata to our generic type + requires_gpu = bool(getattr(getattr(self.processing_stage, "resources", None), "requires_gpu", False)) generic_worker_metadata = WorkerMetadata( worker_id=worker_metadata.worker_id, - allocation=worker_metadata.allocation, # Keep the original allocation object + allocation=worker_metadata.allocation, ) + if bool(getattr(self.processing_stage, "extended_performance_metrics", False)): + identity = build_xenna_perf_identity( + str(self.processing_stage.name), + worker_id=worker_metadata.worker_id, + node_id="", + allocation=worker_metadata.allocation, + requires_gpu=requires_gpu, + ) + stamp_worker_metadata(generic_worker_metadata, identity) super().setup(generic_worker_metadata) diff --git a/nemo_curator/backends/xenna/executor.py b/nemo_curator/backends/xenna/executor.py index fd86afb354..b9a0808aca 100644 --- a/nemo_curator/backends/xenna/executor.py +++ b/nemo_curator/backends/xenna/executor.py @@ -19,9 +19,13 @@ from cosmos_xenna.utils.verbosity import VerbosityLevel from loguru import logger -from nemo_curator.backends.base import BaseExecutor +from nemo_curator.backends.base import ( + BaseExecutor, +) from nemo_curator.backends.utils import register_loguru_serializer -from nemo_curator.backends.xenna.adapter import create_named_xenna_stage_adapter +from nemo_curator.backends.xenna.adapter import ( + create_named_xenna_stage_adapter, +) from nemo_curator.stages.base import ProcessingStage from nemo_curator.tasks import EmptyTask, Task @@ -57,6 +61,11 @@ def __init__(self, config: dict[str, Any] | None = None, ignore_head_node: bool "execution_mode": "streaming", "cpu_allocation_percentage": 0.95, "autoscale_interval_s": 180, + "actor_pool_verbosity_level": "INFO", + "monitoring_verbosity_level": "INFO", + "autoscaler_verbosity_level": "INFO", + "executor_verbosity_level": "INFO", + "log_worker_allocation_layout": True, } def execute(self, stages: list[ProcessingStage], initial_tasks: list[Task] | None = None) -> list[Task]: @@ -69,49 +78,22 @@ def execute(self, stages: list[ProcessingStage], initial_tasks: list[Task] | Non Returns: list[Task]: List of output tasks from the pipeline """ - # Convert stages to Xenna stage specs - stage_specs = [] - - # Initialize with initial tasks if provided, otherwise start with EmptyTask initial_tasks = initial_tasks if initial_tasks else [EmptyTask()] + return self._run_xenna_pipeline(stages, initial_tasks) - for stage in stages: - # Get stage configuration - stage_config = stage.xenna_stage_spec() - if "num_workers" in stage_config: - msg = f"Stage {stage.name} sets num_workers in xenna_stage_spec(). Use num_workers() instead." - raise ValueError(msg) - - num_workers = stage.num_workers() - num_workers_per_node = stage_config.get("num_workers_per_node") - if num_workers is not None and num_workers_per_node is not None: - msg = ( - f"Stage {stage.name} sets both num_workers() and " - "xenna_stage_spec()['num_workers_per_node']. Use only one worker sizing option." - ) - raise ValueError(msg) - - # Create Xenna stage adapter with the original stage's name - xenna_stage = create_named_xenna_stage_adapter( - stage=stage, - ) + def _run_xenna_pipeline( + self, + stages: list[ProcessingStage], + initial_tasks: list[Any], + ) -> list[Any]: + if not stages: + return initial_tasks - # Create stage spec with configuration from stage - stage_spec = pipelines_v1.StageSpec( - stage=xenna_stage, - num_workers=num_workers, - num_workers_per_node=num_workers_per_node, - num_setup_attempts_python=stage_config.get("num_setup_attempts_python"), - num_run_attempts_python=stage_config.get("num_run_attempts_python"), - ignore_failures=stage_config.get("ignore_failures"), - reset_workers_on_failure=stage_config.get("reset_workers_on_failure"), - slots_per_actor=stage_config.get("slots_per_actor"), - worker_max_lifetime_m=stage_config.get("worker_max_lifetime_m"), - worker_restart_interval_m=stage_config.get("worker_restart_interval_m"), - max_setup_failure_percentage=stage_config.get("max_setup_failure_percentage"), - ) + # Convert stages to Xenna stage specs + stage_specs = [] - stage_specs.append(stage_spec) + for stage in stages: + stage_specs.append(self._build_stage_spec(stage)) # Determine execution mode exec_mode = pipelines_v1.ExecutionMode.STREAMING @@ -123,21 +105,21 @@ def execute(self, stages: list[ProcessingStage], initial_tasks: list[Task] | Non if exec_mode == pipelines_v1.ExecutionMode.STREAMING: streaming_config = pipelines_v1.StreamingSpecificSpec( autoscale_interval_s=self._get_pipeline_config("autoscale_interval_s"), - autoscaler_verbosity_level=VerbosityLevel.INFO, # TODO: Move this to pipeline config - executor_verbosity_level=VerbosityLevel.INFO, + autoscaler_verbosity_level=self._get_verbosity_config("autoscaler_verbosity_level"), + executor_verbosity_level=self._get_verbosity_config("executor_verbosity_level"), ) # Create pipeline configuration pipeline_config = pipelines_v1.PipelineConfig( execution_mode=exec_mode, logging_interval_s=self._get_pipeline_config("logging_interval"), - log_worker_allocation_layout=True, + log_worker_allocation_layout=bool(self._get_pipeline_config("log_worker_allocation_layout")), return_last_stage_outputs=True, ignore_failures=self._get_pipeline_config("ignore_failures"), cpu_allocation_percentage=self._get_pipeline_config("cpu_allocation_percentage"), mode_specific=streaming_config, - actor_pool_verbosity_level=VerbosityLevel.INFO, # TODO: Move this to pipeline config - monitoring_verbosity_level=VerbosityLevel.INFO, + actor_pool_verbosity_level=self._get_verbosity_config("actor_pool_verbosity_level"), + monitoring_verbosity_level=self._get_verbosity_config("monitoring_verbosity_level"), ) # Create pipeline specification @@ -146,6 +128,7 @@ def execute(self, stages: list[ProcessingStage], initial_tasks: list[Task] | Non # Log pipeline configuration logger.info(f"Execution mode: {exec_mode.name}") + hardware_sampler: list[Any] = [] try: register_loguru_serializer() # Prevent Ray from overriding accelerator env vars when num_gpus=0, letting Xenna manage them instead. @@ -158,17 +141,87 @@ def execute(self, stages: list[ProcessingStage], initial_tasks: list[Task] | Non } }, ) + hardware_sampler = self._start_pipeline_hardware_sampler() # Run the pipeline (this will re-initialize ray but that'll be a no-op and the ray.init above will take precedence) results = pipelines_v1.run_pipeline(pipeline_spec) + hardware_perf = self._stop_pipeline_hardware_sampler(hardware_sampler) + hardware_sampler = [] + if results: + self._attach_pipeline_hardware_perf(results, hardware_perf) + self._publish_external_perf(stages, hardware_perf) logger.info(f"Pipeline completed successfully with {len(results) if results else 0} output tasks") except Exception as e: logger.error(f"Pipeline execution failed: {e}") raise finally: # This ensures we unset all the env vars set above during initialize and kill the pending actors. - ray.shutdown() + try: + if hardware_sampler: + self._stop_pipeline_hardware_sampler(hardware_sampler) + self._cleanup_stage_run_resources(stages) + finally: + ray.shutdown() return results if results else [] + def _build_stage_spec(self, stage: ProcessingStage) -> pipelines_v1.StageSpec: + """Create a Xenna StageSpec from a Curator stage.""" + stage_config = stage.xenna_stage_spec() + num_workers, num_workers_per_node = self._resolve_stage_worker_sizing(stage, stage_config) + xenna_stage = create_named_xenna_stage_adapter(stage=stage) + + return pipelines_v1.StageSpec( + stage=xenna_stage, + num_workers=num_workers, + num_workers_per_node=num_workers_per_node, + num_setup_attempts_python=stage_config.get("num_setup_attempts_python"), + num_run_attempts_python=stage_config.get("num_run_attempts_python"), + ignore_failures=stage_config.get("ignore_failures"), + reset_workers_on_failure=stage_config.get("reset_workers_on_failure"), + slots_per_actor=stage_config.get("slots_per_actor"), + worker_max_lifetime_m=stage_config.get("worker_max_lifetime_m"), + worker_restart_interval_m=stage_config.get("worker_restart_interval_m"), + max_setup_failure_percentage=stage_config.get("max_setup_failure_percentage"), + ) + + @staticmethod + def _resolve_stage_worker_sizing( + stage: ProcessingStage, stage_config: dict[str, Any] + ) -> tuple[int | None, int | None]: + """Resolve Xenna worker sizing with the main-branch contract.""" + if "num_workers" in stage_config: + msg = f"Stage {stage.name} sets num_workers in xenna_stage_spec(). Use num_workers() instead." + raise ValueError(msg) + num_workers = stage.num_workers() + num_workers_per_node = stage_config.get("num_workers_per_node") + if num_workers is not None and num_workers_per_node is not None: + msg = ( + f"Stage {stage.name} sets both num_workers() and " + "xenna_stage_spec()['num_workers_per_node']. Use only one worker sizing option." + ) + raise ValueError(msg) + return num_workers, num_workers_per_node + def _get_pipeline_config(self, key: str) -> Any: # noqa: ANN401 """Get configuration value with fallback to defaults.""" return self.config.get(key, self._default_pipeline_config.get(key)) + + def _get_verbosity_config(self, key: str) -> VerbosityLevel: + """Get Xenna verbosity level from enum, integer, or string config.""" + value = self._get_pipeline_config(key) + if value is None: + value = self._default_pipeline_config.get(key, "INFO") + if isinstance(value, VerbosityLevel): + return value + if isinstance(value, str): + try: + return VerbosityLevel[value.upper()] + except KeyError as exc: + valid = ", ".join(level.name for level in VerbosityLevel) + msg = f"Invalid Xenna verbosity config {key}={value!r}; expected one of: {valid}" + raise ValueError(msg) from exc + try: + return VerbosityLevel(value) + except ValueError as exc: + valid = ", ".join(level.name for level in VerbosityLevel) + msg = f"Invalid Xenna verbosity config {key}={value!r}; expected one of: {valid}" + raise ValueError(msg) from exc diff --git a/nemo_curator/models/asr/__init__.py b/nemo_curator/models/asr/__init__.py new file mode 100644 index 0000000000..8319c954d4 --- /dev/null +++ b/nemo_curator/models/asr/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ASR model adapters. + +Exposes :class:`ASRAdapter` / :class:`ASRResult` (always importable) from +``base``. Concrete adapters (e.g. :class:`QwenOmniASRAdapter`) live in +submodules, resolved via YAML ``adapter_target`` or lazy attribute access +(PEP 562) to avoid importing heavy GPU deps eagerly. +""" + +from nemo_curator.models.asr.base import ASRAdapter, ASRResult + +_LAZY: dict[str, str] = { + "QwenOmniASRAdapter": ".qwen_omni", +} + +__all__ = ["ASRAdapter", "ASRResult", "QwenOmniASRAdapter"] + + +def __getattr__(name: str) -> object: + if name in _LAZY: + import importlib + + mod = importlib.import_module(_LAZY[name], package=__name__) + return getattr(mod, name) + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/nemo_curator/models/asr/base.py b/nemo_curator/models/asr/base.py new file mode 100644 index 0000000000..f5a86dd24d --- /dev/null +++ b/nemo_curator/models/asr/base.py @@ -0,0 +1,123 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage-adapter contract for audio speech-recognition. + +Mirrors the diarization/LID/VAD contract: ``ASRStage`` owns Curator-side glue +(``task.data`` reads, batching, ISO language mapping, ``_skip_me``, metrics), +while ``ASRAdapter`` (this module) owns the model-side call (prefetch, setup, +generation, packing into ``ASRResult``). The split lets the stage swap models +via a single YAML ``adapter_target:`` line. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + + +@dataclass +class ASRResult: + """Canonical per-utterance ASR adapter output. + + Identical across every adapter so the stage's schema-mutation path stays + constant when the adapter is swapped. + + Attributes: + text: Primary transcription (Turn-1 / sole output). Empty if skipped. + secondary_text: Optional Turn-2 / disfluency-preserved output; + ``None`` for single-turn or skipped Turn-2. Written to + ``task.data`` only when ``ASRStage.disfluency_text_key`` is set. + skipped: True when the item could not be processed (e.g. empty/corrupt + waveform); the stage then sets ``skip_me_key = "empty_audio"``. + model_id: Identifier of the model actually run (populated by adapter). + extras: Adapter-specific diagnostics outside the canonical shape; the + stage never reads inside this dict. + """ + + text: str + secondary_text: str | None = None + skipped: bool = False + model_id: str = "" + extras: dict[str, Any] = field(default_factory=dict) + + +@runtime_checkable +class ASRAdapter(Protocol): + """Structural protocol every ASR adapter must implement. + + Constructor contract: the stage builds adapters as + ``cls(model_id=..., revision=..., **adapter_kwargs)``, so every adapter + must accept ``model_id`` and ``revision`` keyword args plus its own knobs. + + Per-batch contract: ``transcribe_batch`` receives a list of per-task dicts + (unpacked from ``task.data``) and returns one ``ASRResult`` per input, in + order. Expected per-item keys (stage-populated): + + * ``waveform``: canonical Curator waveform object from the stage + (typically a torch tensor shaped ``(channels, samples)``); adapters own + any model-specific conversion such as squeezing to 1-D numpy. + * ``sample_rate`` (``int``): source rate; adapter handles any resampling. + * ``language`` (``str | None``): human-readable name (e.g. ``"English"``). + * ``language_code`` (``str | None``): original language code from the + configured stage input column. + * ``reference_text`` (``str | None``): optional transcript/reference text + from ``ASRStage.reference_text_key`` for adapters or prompts that need + row-level text context. + * ``task_id`` (``str | None``): carried through for diagnostics. + * ``audio_seconds`` (``float``): estimated item duration for metrics or + adapter-side policy decisions. + * ``chunk_idx`` / ``chunk_count`` (``int | None``): chunk position metadata + when the stage split a long parent row. Adapters may ignore any metadata + keys they do not need. + + Attributes: + model_id: Identifier of the underlying model checkpoint. + last_metrics: Scalar metrics from the last ``transcribe_batch`` call; + the stage merges these under ``model_`` aliases. + + Optional methods: + ``estimate_item_cost(item) -> float | None`` may be implemented by an + adapter to provide a scheduler cost better than raw duration, such as + estimated encoder tokens or approximate VRAM units. ``ASRStage`` falls + back to duration when the method is absent or returns ``None``. + """ + + model_id: str + last_metrics: dict[str, float] + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + """Download weights to local cache without allocating a GPU. + + Classmethod so the stage can call it (once per node) without + instantiating the adapter or importing heavy GPU libraries. + """ + ... + + def setup(self) -> None: + """Load the model into the worker process (once per worker).""" + ... + + def teardown(self) -> None: + """Release GPU memory and worker-local state.""" + ... + + def transcribe_batch(self, items: list[dict[str, Any]]) -> list[ASRResult]: + """Run inference on a batch of per-task dicts. + + Returns one ``ASRResult`` per input, in order; skipped items must + still appear with ``skipped=True`` to preserve task ordering. + """ + ... diff --git a/nemo_curator/models/asr/qwen_omni.py b/nemo_curator/models/asr/qwen_omni.py new file mode 100644 index 0000000000..87771c76dd --- /dev/null +++ b/nemo_curator/models/asr/qwen_omni.py @@ -0,0 +1,624 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen3-Omni ASR adapter (in-process vLLM). + +Implements the :class:`~nemo_curator.models.asr.ASRAdapter` protocol on the +in-process vLLM thinker-only path. Two-turn (Turn-1 transcribe, Turn-2 +disfluency/refinement) when ``followup_prompt`` is set; single-turn otherwise. + +Engine plumbing is inherited from +:class:`nemo_curator.models.vllm_model.VLLMBase`; this module adds the +Qwen-Omni surface (multimodal preprocessing, prompt construction, prep thread +pool, adapter protocol methods). Both turns share ``_infer_turn`` and +``_pack_vllm_inputs``, differing only in prompt and output list. +""" + +from __future__ import annotations + +import time +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import numpy as np +from huggingface_hub import snapshot_download +from loguru import logger + +from nemo_curator.models.asr.base import ASRResult +from nemo_curator.models.vllm_model import VLLM_AVAILABLE, VLLMBase +from nemo_curator.utils.gpu_utils import get_gpu_count + +try: + from qwen_omni_utils import process_mm_info +except ImportError: + process_mm_info = None # type: ignore[assignment,misc] + +try: + from transformers import Qwen3OmniMoeProcessor +except ImportError: + Qwen3OmniMoeProcessor = None # type: ignore[assignment,misc] + + +def _require_audio_qwen_stack(*, context: str) -> None: + """Raise a single ImportError listing missing audio_qwen-only deps.""" + missing: list[str] = [] + if not VLLM_AVAILABLE: + missing.append("vllm") + if process_mm_info is None: + missing.append("qwen-omni-utils") + if Qwen3OmniMoeProcessor is None: + missing.append("transformers (Qwen3OmniMoeProcessor)") + if missing: + msg = ( + f"QwenOmniASRAdapter {context} requires the audio_qwen extra. " + f"Missing: {', '.join(missing)}. Install with: uv sync --extra audio_qwen" + ) + raise ImportError(msg) + + +_QWEN3_OMNI_MODEL_ID = "Qwen/Qwen3-Omni-30B-A3B-Instruct" +_QWEN_SAMPLE_RATE = 16000 +_MIN_QWEN_AUDIO_SAMPLES = 1600 +_WAVEFORM_2D_NDIM = 2 +_FOLLOWUP_PROMPT_DEFAULT = ( + "Now listen to the audio again and add any false starts, filler words " + "and preserve colloquial words (like lemme, gonna, wanna, etc) as is " + "spoken in the audio." +) + + +@dataclass +class QwenOmniASRAdapter(VLLMBase): + """Qwen3-Omni in-process vLLM adapter (thinker-only path). + + Stages construct adapters via + ``cls(model_id=..., revision=..., **adapter_kwargs)``, so every field + below is a keyword-only knob settable from the YAML ``adapter_kwargs``. + + Resource expectations: + * ~40 GB VRAM for Qwen3-Omni-30B-A3B (FP8): one A100-80GB or two + A100-40GB with ``tensor_parallel_size=2``. + * ~50-80 audio-seconds/GPU-second on A100-80GB at ``batch_size=32``. + * ~15 GB cached weights on first run (HuggingFace Hub). + + Notable Args (most are plain vLLM/sampling knobs): + prompt_text / *_file: Turn-1 user prompt; ``{language}`` and + ``{transcript}`` are interpolated per-item when the stage supplies + language and reference text values. ``*_file`` variants load text + from a UTF-8 file at ``__post_init__`` time. + en_prompt_text / en_prompt_file: override used when language is + ``"English"``. + followup_prompt / *_file: when set, enables Turn-2 inference. + system_prompt / *_file: optional system message for both turns. + tensor_parallel_size: ``None`` -> auto-detect from visible GPUs. + enable_prefix_caching: default ``True`` since prompts repeat across + requests; disable for highly variable prompts. + limit_mm_per_prompt_audio: per-prompt audio cap; ``2`` covers the + two-turn flow, ``1`` for strictly single-turn. This audio adapter + passes image/video multimodal caps as ``1`` for Qwen/vLLM + compatibility even though ASR requests only attach audio payloads. + max_num_batched_tokens: optional vLLM scheduler/encoder-cache budget. + Long single audio items can exceed the default multimodal encoder + cache even when ``max_model_len`` is large enough; set this to at + least the observed audio feature length for 40-50 minute probes. + seed: exposed so reproducibility / bit-exactness tests can override. + """ + + model_id: str = _QWEN3_OMNI_MODEL_ID + revision: str | None = None + + prompt_text: str = "Transcribe the audio." + prompt_file: str | None = None + en_prompt_text: str | None = None + en_prompt_file: str | None = None + followup_prompt: str | None = None + followup_prompt_file: str | None = None + system_prompt: str | None = None + system_prompt_file: str | None = None + max_model_len: int = 32768 + max_num_batched_tokens: int | None = None + max_num_seqs: int = 32 + gpu_memory_utilization: float = 0.95 + tensor_parallel_size: int | None = None + max_output_tokens: int = 256 + temperature: float = 0.0 + top_k: int = 1 + prep_workers: int = 8 + + enable_prefix_caching: bool = True + prefix_caching_hash_algo: str = "xxhash" + limit_mm_per_prompt_audio: int = 2 + seed: int = 1234 + + last_metrics: dict[str, float] = field(default_factory=dict) + + def __post_init__(self) -> None: + self.prompt_text = self._load_text(self.prompt_text, self.prompt_file) or "" + self.en_prompt_text = self._load_text(self.en_prompt_text, self.en_prompt_file) + self.followup_prompt = self._load_text(self.followup_prompt, self.followup_prompt_file) + self.system_prompt = self._load_text(self.system_prompt, self.system_prompt_file) + + if self.max_num_batched_tokens is not None and self.max_num_batched_tokens <= 0: + msg = "max_num_batched_tokens must be positive when set" + raise ValueError(msg) + if self.limit_mm_per_prompt_audio <= 0: + msg = "limit_mm_per_prompt_audio must be positive" + raise ValueError(msg) + + self._processor: Any = None + self._prep_pool: ThreadPoolExecutor | None = None + + @staticmethod + def _load_text(text: str | None, file_path: str | None) -> str | None: + if file_path: + path = Path(file_path) + if not path.exists(): + msg = f"QwenOmniASRAdapter prompt file not found: {path}" + raise FileNotFoundError(msg) + return path.read_text(encoding="utf-8").strip() + return text + + @classmethod + def prefetch_weights(cls, model_id: str, revision: str | None = None) -> None: + """Cache the model snapshot on local disk without touching the GPU.""" + kwargs: dict[str, Any] = {} + if revision is not None: + kwargs["revision"] = revision + snapshot_download(model_id, **kwargs) + + def setup(self) -> None: + if self._llm is not None: + return + _require_audio_qwen_stack(context="setup()") + + tp_size = self.tensor_parallel_size or get_gpu_count() + logger.info( + f"Loading QwenOmni model={self.model_id} tp={tp_size} " + f"max_model_len={self.max_model_len} max_num_seqs={self.max_num_seqs}" + + ( + f" max_num_batched_tokens={self.max_num_batched_tokens}" + if self.max_num_batched_tokens is not None + else "" + ) + + (f" revision={self.revision}" if self.revision is not None else "") + ) + + model_kwargs: dict[str, Any] = { + "model": self.model_id, + "trust_remote_code": True, + "gpu_memory_utilization": self.gpu_memory_utilization, + "tensor_parallel_size": tp_size, + "limit_mm_per_prompt": {"image": 1, "video": 1, "audio": int(self.limit_mm_per_prompt_audio)}, + "max_num_seqs": self.max_num_seqs, + "max_model_len": self.max_model_len, + "seed": int(self.seed), + "enable_prefix_caching": bool(self.enable_prefix_caching), + "prefix_caching_hash_algo": str(self.prefix_caching_hash_algo), + } + if self.max_num_batched_tokens is not None: + model_kwargs["max_num_batched_tokens"] = int(self.max_num_batched_tokens) + if self.revision is not None: + model_kwargs["revision"] = self.revision + + sampling_kwargs: dict[str, Any] = { + "temperature": self.temperature, + "top_k": self.top_k, + "max_tokens": self.max_output_tokens, + } + + try: + self._init_engine(model_kwargs, sampling_kwargs) + + proc_kwargs: dict[str, Any] = {} + if self.revision is not None: + proc_kwargs["revision"] = self.revision + self._processor = Qwen3OmniMoeProcessor.from_pretrained(self.model_id, **proc_kwargs) + self._prep_pool = ThreadPoolExecutor(max_workers=self.prep_workers) + except Exception: + self.teardown() + raise + + def teardown(self) -> None: + if self._prep_pool is not None: + self._prep_pool.shutdown(wait=False) + self._prep_pool = None + self._processor = None + self._cleanup_gpu() + + def estimate_item_cost(self, item: dict[str, Any]) -> float | None: + """Return an optional scheduler cost for one prepared ASR item.""" + + for key in ("estimated_vram_units", "estimated_encoder_tokens", "audio_seconds"): + value = item.get(key) + if value is not None: + return float(value) + return None + + def transcribe_batch(self, items: list[dict[str, Any]]) -> list[ASRResult]: + """Run batched two-turn inference over per-task dicts. + + Skipped items (empty / unprocessable waveforms) round-trip as + ``ASRResult(text="", skipped=True)`` to preserve ordering. + """ + if not items: + return [] + waveforms = [it["waveform"] for it in items] + sample_rates = [it["sample_rate"] for it in items] + languages = [it.get("language") for it in items] + reference_texts = [it.get("reference_text") for it in items] + pred_texts, disfl_texts, skipped_indices = self._run_two_turn( + waveforms, + sample_rates, + languages, + reference_texts, + ) + has_t2 = bool(self.followup_prompt) + return [ + ASRResult( + text=pred, + secondary_text=(disfl if has_t2 else None), + skipped=(i in skipped_indices), + model_id=self.model_id, + ) + for i, (pred, disfl) in enumerate(zip(pred_texts, disfl_texts, strict=True)) + ] + + # Input preparation + + @staticmethod + def _to_mono_numpy_1d(waveform: object) -> np.ndarray: + """Normalize Curator waveform objects to Qwen's 1-D mono numpy input.""" + if waveform is None: + return np.asarray([], dtype=np.float32) + if hasattr(waveform, "detach"): + waveform = waveform.detach().cpu().numpy() + arr = np.asarray(waveform, dtype=np.float32) + if arr.size == 0: + return arr.reshape(0) + if arr.ndim == 0: + return arr.reshape(1) + if arr.ndim == 1: + return np.ascontiguousarray(arr) + + squeezed = np.squeeze(arr) + if squeezed.ndim == 1: + return np.ascontiguousarray(squeezed.astype(np.float32, copy=False)) + if squeezed.ndim == _WAVEFORM_2D_NDIM: + # Curator's canonical waveform is channels-first (C, T). If an + # adapter caller supplies channel-last (T, C), average over the + # smaller channel-looking axis. + axis = 0 if squeezed.shape[0] <= squeezed.shape[1] else 1 + return np.ascontiguousarray(squeezed.mean(axis=axis).astype(np.float32, copy=False)) + + msg = f"Expected 1-D or 2-D waveform, got shape {arr.shape}" + raise ValueError(msg) + + @staticmethod + def _resample(waveform: np.ndarray, orig_sr: int, target_sr: int = _QWEN_SAMPLE_RATE) -> np.ndarray: + if orig_sr == target_sr: + return waveform + import librosa + + return librosa.resample(waveform, orig_sr=orig_sr, target_sr=target_sr) + + def _resolve_prompt(self, template: str, language: str | None, reference_text: str | None = None) -> str: + result = template + if language and "{language}" in result: + result = result.replace("{language}", language) + if reference_text is not None and "{transcript}" in result: + result = result.replace("{transcript}", reference_text) + return result + + def _get_prompt_text(self, language: str | None, reference_text: str | None = None) -> str: + if language == "English" and self.en_prompt_text: + return self._resolve_prompt(self.en_prompt_text, language, reference_text) + return self._resolve_prompt(self.prompt_text, language, reference_text) + + def _build_audio_prompt_messages( + self, + waveform: np.ndarray, + language: str | None = None, + reference_text: str | None = None, + ) -> list[dict[str, Any]]: + prompt = self._get_prompt_text(language, reference_text) + messages: list[dict[str, Any]] = [] + if self.system_prompt: + sys_prompt = self._resolve_prompt(self.system_prompt, language) + messages.append({"role": "system", "content": [{"type": "text", "text": sys_prompt}]}) + messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "audio", "audio": waveform}, + ], + } + ) + return messages + + def _build_messages( + self, + waveform: np.ndarray, + language: str | None = None, + reference_text: str | None = None, + ) -> list[dict[str, Any]]: + return self._build_audio_prompt_messages(waveform, language, reference_text) + + def _build_turn2_messages( + self, + waveform: np.ndarray, + pred_text: str, + language: str | None = None, + reference_text: str | None = None, + ) -> list[dict[str, Any]]: + followup = self._resolve_prompt(self.followup_prompt or _FOLLOWUP_PROMPT_DEFAULT, language, reference_text) + messages = self._build_audio_prompt_messages(waveform, language, reference_text) + messages.append({"role": "assistant", "content": [{"type": "text", "text": pred_text}]}) + messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": followup}, + ], + } + ) + return messages + + def _pack_vllm_inputs(self, messages: list[dict[str, Any]]) -> dict[str, Any]: + """Render chat ``messages`` into a vLLM request dict (shared by both turns).""" + text = self._processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + audios, images, videos = process_mm_info(messages, use_audio_in_video=False) + inputs: dict[str, Any] = { + "prompt": text, + "multi_modal_data": {}, + "mm_processor_kwargs": {"use_audio_in_video": False}, + } + if audios is not None: + inputs["multi_modal_data"]["audio"] = audios + if images is not None: + inputs["multi_modal_data"]["image"] = images + if videos is not None: + inputs["multi_modal_data"]["video"] = videos + return inputs + + def _prepare_single( + self, + waveform: object, + sample_rate: int, + language: str | None = None, + reference_text: str | None = None, + ) -> tuple[dict[str, Any], np.ndarray] | None: + try: + waveform_1d = self._to_mono_numpy_1d(waveform) + if waveform_1d.size == 0: + logger.warning("Skipping empty waveform") + return None + if waveform_1d.size < _MIN_QWEN_AUDIO_SAMPLES: + logger.warning("Skipping too-short waveform ({} samples)", waveform_1d.size) + return None + waveform_16k = self._resample(waveform_1d, sample_rate) + messages = self._build_messages(waveform_16k, language, reference_text) + inputs = self._pack_vllm_inputs(messages) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to preprocess audio, skipping (waveform shape={}, sr={}): {}", + getattr(waveform, "shape", None), + sample_rate, + exc, + ) + return None + + return inputs, waveform_16k + + def _prepare_batch( + self, + waveforms: list[object], + sample_rates: list[int], + languages: list[str | None] | None = None, + reference_texts: list[str | None] | None = None, + ) -> list[tuple[dict[str, Any], np.ndarray] | None]: + langs = languages or [None] * len(waveforms) + refs = reference_texts or [None] * len(waveforms) + if self._prep_pool is None: + return [ + self._prepare_single(w, sr, lang, ref) + for w, sr, lang, ref in zip(waveforms, sample_rates, langs, refs, strict=False) + ] + return list(self._prep_pool.map(self._prepare_single, waveforms, sample_rates, langs, refs)) + + def _prepare_turn2_single( + self, + waveform_16k: np.ndarray, + pred_text: str, + language: str | None = None, + reference_text: str | None = None, + ) -> dict[str, Any] | None: + try: + messages = self._build_turn2_messages(waveform_16k, pred_text, language, reference_text) + inputs = self._pack_vllm_inputs(messages) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to preprocess Turn 2 audio (shape={}): {}", + getattr(waveform_16k, "shape", None), + exc, + ) + return None + + return inputs + + def _prepare_turn2_batch( + self, + waveforms_16k: list[np.ndarray], + pred_texts: list[str], + languages: list[str | None] | None = None, + reference_texts: list[str | None] | None = None, + ) -> list[dict[str, Any] | None]: + langs = languages or [None] * len(waveforms_16k) + refs = reference_texts or [None] * len(waveforms_16k) + if self._prep_pool is None: + return [ + self._prepare_turn2_single(w, pt, lang, ref) + for w, pt, lang, ref in zip(waveforms_16k, pred_texts, langs, refs, strict=False) + ] + return list(self._prep_pool.map(self._prepare_turn2_single, waveforms_16k, pred_texts, langs, refs)) + + @staticmethod + def _count_output_tokens(outputs: list[Any]) -> float: + total = 0.0 + for output in outputs: + sequences = getattr(output, "outputs", None) or [] + if not sequences: + continue + token_ids = getattr(sequences[0], "token_ids", None) + if token_ids is not None: + total += float(len(token_ids)) + return total + + @staticmethod + def _first_output_text(output: Any) -> str: # noqa: ANN401 + sequences = getattr(output, "outputs", None) or [] + if not sequences: + return "" + return (getattr(sequences[0], "text", "") or "").strip() + + def _infer_turn( + self, + inputs: list[dict[str, Any]], + indices: list[int], + n: int, + ) -> tuple[list[str], float, float]: + """Run one vLLM turn and scatter its texts back to input order. + + ``indices[k]`` is the position in the length-``n`` batch that + ``inputs[k]`` came from. Returns + ``(texts_of_len_n, generation_time_s, output_token_count)``. + """ + t0 = time.perf_counter() + outputs = self._generate(inputs) + generation_time_s = time.perf_counter() - t0 + output_tokens = self._count_output_tokens(outputs) + texts: list[str] = [""] * n + # strict=True: a count mismatch means a broken engine contract; fail + # loud rather than silently emit empty text with skipped=False. + for idx, out in zip(indices, outputs, strict=True): + texts[idx] = self._first_output_text(out) + return texts, generation_time_s, output_tokens + + def _run_vllm_turn( + self, + inputs: list[dict[str, Any]], + indices: list[int], + n: int, + metrics: dict[str, float], + turn_name: str, + ) -> list[str]: + texts, generation_s, output_tokens = self._infer_turn(inputs, indices, n) + metrics[f"{turn_name}_generation_time_s"] = generation_s + metrics[f"{turn_name}_output_tokens"] = output_tokens + metrics["output_tokens"] += output_tokens + return texts + + def _run_two_turn( + self, + waveforms: list[object], + sample_rates: list[int], + languages: list[str | None] | None = None, + reference_texts: list[str | None] | None = None, + ) -> tuple[list[str], list[str], set[int]]: + """Run batched two-turn inference on in-memory waveforms. + + Returns ``(pred_texts, disfluency_texts, skipped_indices)``. + ``disfluency_texts`` is all empty strings when ``followup_prompt`` + is not set. + """ + n = len(waveforms) + # audio_duration_s / waveform_bytes are deliberately omitted: the stage + # (ASRStage.assemble) owns those canonical, adapter-agnostic counters. + metrics: dict[str, float] = { + "utterances_input": float(n), + "turn1_prep_time_s": 0.0, + "turn1_generation_time_s": 0.0, + "turn2_prep_time_s": 0.0, + "turn2_generation_time_s": 0.0, + "turn1_valid_inputs": 0.0, + "turn2_valid_inputs": 0.0, + "utterances_skipped_preprocess": 0.0, + "utterances_skipped_empty_output": 0.0, + "output_tokens": 0.0, + "turn1_output_tokens": 0.0, + "turn2_output_tokens": 0.0, + } + self.last_metrics = metrics + + # -- Turn 1 ---------------------------------------------------------- + prep_t0 = time.perf_counter() + prepared = self._prepare_batch(waveforms, sample_rates, languages, reference_texts) + metrics["turn1_prep_time_s"] = time.perf_counter() - prep_t0 + valid_indices = [i for i, p in enumerate(prepared) if p is not None] + valid_inputs = [prepared[i][0] for i in valid_indices] + waveforms_16k: dict[int, np.ndarray] = {i: prepared[i][1] for i in valid_indices} + skipped_indices = set(range(n)) - set(valid_indices) + metrics["turn1_valid_inputs"] = float(len(valid_inputs)) + metrics["utterances_skipped_preprocess"] = float(len(skipped_indices)) + + if not valid_inputs: + logger.warning(f"All {n} audio samples in batch failed preprocessing") + return [""] * n, [""] * n, skipped_indices + + if len(valid_inputs) < n: + logger.warning(f"Skipped {n - len(valid_inputs)}/{n} corrupt audio samples") + + pred_texts = self._run_vllm_turn(valid_inputs, valid_indices, n, metrics, "turn1") + empty_output_indices = {i for i in valid_indices if not pred_texts[i]} + if empty_output_indices: + skipped_indices.update(empty_output_indices) + metrics["utterances_skipped_empty_output"] = float(len(empty_output_indices)) + logger.warning( + "Skipping {}/{} audio samples with empty Turn 1 vLLM output", + len(empty_output_indices), + len(valid_indices), + ) + + # -- Turn 2 (disfluency refinement) ----------------------------------- + if not self.followup_prompt: + return pred_texts, [""] * n, skipped_indices + + t2_indices = [i for i in valid_indices if i not in skipped_indices and pred_texts[i]] + if not t2_indices: + return pred_texts, [""] * n, skipped_indices + + langs = languages or [None] * n + refs = reference_texts or [None] * n + t2_prep_t0 = time.perf_counter() + t2_prepared = self._prepare_turn2_batch( + [waveforms_16k[i] for i in t2_indices], + [pred_texts[i] for i in t2_indices], + [langs[i] for i in t2_indices], + [refs[i] for i in t2_indices], + ) + metrics["turn2_prep_time_s"] = time.perf_counter() - t2_prep_t0 + + t2_valid = [(i, p) for i, p in zip(t2_indices, t2_prepared, strict=False) if p is not None] + metrics["turn2_valid_inputs"] = float(len(t2_valid)) + if not t2_valid: + logger.warning("All Turn 2 samples failed preprocessing") + return pred_texts, [""] * n, skipped_indices + + t2_valid_indices = [i for i, _ in t2_valid] + t2_inputs = [p for _, p in t2_valid] + disfluency_texts = self._run_vllm_turn(t2_inputs, t2_valid_indices, n, metrics, "turn2") + + return pred_texts, disfluency_texts, skipped_indices diff --git a/nemo_curator/models/vllm_model.py b/nemo_curator/models/vllm_model.py index 09e1e67ccf..b28029840b 100644 --- a/nemo_curator/models/vllm_model.py +++ b/nemo_curator/models/vllm_model.py @@ -12,8 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""vLLM model wrappers. + +- :class:`VLLMBase` - shared engine management (creation, generation, GPU + cleanup). ``_generate`` accepts text prompts *or* multimodal prompt dicts + so audio/vision adapters reuse the same plumbing. +- :class:`VLLMModel` - generic text-generation :class:`ModelInterface`. +""" + +from __future__ import annotations + +import gc +import time from typing import Any +import torch from loguru import logger from nemo_curator.models.base import ModelInterface @@ -33,7 +46,65 @@ class SamplingParams: pass -class VLLMModel(ModelInterface): +class VLLMBase: + """Shared vLLM engine management for text and multimodal generation. + + Holds the loaded ``LLM`` engine and ``SamplingParams`` and exposes + protected helpers for engine creation, generation, and GPU cleanup. Not + for direct instantiation. ``_generate`` returns raw ``RequestOutput`` + objects so callers can read text *and* token-level metadata. + """ + + _llm: LLM | None = None + _sampling_params: SamplingParams | None = None + + def _init_engine(self, model_kwargs: dict[str, Any], sampling_kwargs: dict[str, Any]) -> None: + """Create the vLLM ``LLM`` engine and ``SamplingParams``. + + Args forward to ``vllm.LLM`` / ``vllm.SamplingParams``. Constructor + exceptions propagate unchanged, matching the public main behavior. + """ + start_time = time.perf_counter() + self._llm = LLM(**model_kwargs) + logger.info("vLLM engine loaded in {:.3f}s", time.perf_counter() - start_time) + self._sampling_params = SamplingParams(**sampling_kwargs) + + def _generate(self, prompts: list, *, use_tqdm: bool = False) -> list: + """Run generation and return raw ``RequestOutput`` objects, one per prompt. + + ``prompts`` are text strings or multimodal prompt dicts. Raises + ``RuntimeError`` if the engine is uninitialized or generation fails. + """ + if self._llm is None or self._sampling_params is None: + msg = "vLLM engine not initialized. Call setup() first." + raise RuntimeError(msg) + try: + return self._llm.generate(prompts, sampling_params=self._sampling_params, use_tqdm=use_tqdm) + except (RuntimeError, ValueError, TypeError) as e: + msg = f"Error generating text: {e}" + raise RuntimeError(msg) from e + + def _cleanup_gpu(self) -> None: + """Release the engine and GPU memory. + + vLLM owns its tensor-parallel process group, so we do not call + ``torch.distributed.destroy_process_group()`` here: that destroys the + default/global group and would corrupt any other component (another + stage, Ray primitives) sharing it in this process. + """ + if self._llm is not None: + del self._llm + self._llm = None + self._sampling_params = None + gc.collect() + try: + torch.cuda.empty_cache() + torch.cuda.synchronize() + except Exception as e: # noqa: BLE001 + logger.debug("CUDA cache clear skipped: {}", e) + + +class VLLMModel(VLLMBase, ModelInterface): """Generic vLLM language model wrapper for text generation.""" def __init__( # noqa: PLR0913 @@ -49,23 +120,15 @@ def __init__( # noqa: PLR0913 max_tokens: int | None = None, cache_dir: str | None = None, ): - """ - Initialize the vLLM model wrapper. + """Initialize the vLLM model wrapper. Args: - model: Model identifier (e.g., "microsoft/phi-4") - max_model_len: Maximum model context length. If not specified, - will be auto-detected from HuggingFace AutoConfig. - tensor_parallel_size: Number of GPUs for tensor parallelism. - If not specified, auto-detects available GPUs. - max_num_batched_tokens: Maximum tokens per batch. Defaults to - 4096. - temperature: Sampling temperature. Defaults to 0.7. - top_p: Top-p sampling parameter. Defaults to 0.8. - top_k: Top-k sampling parameter. Defaults to 20. - min_p: Min-p sampling parameter (for Qwen3). Defaults to 0.0. - max_tokens: Maximum tokens to generate. Defaults to None. - cache_dir: Cache directory for model weights. Defaults to None. + model: Model identifier (e.g., "microsoft/phi-4"). + max_model_len: Context length; auto-detected from HF AutoConfig + when ``None``. + tensor_parallel_size: TP GPU count; auto-detected when ``None``. + min_p: Min-p sampling (Qwen3 only). + cache_dir: Model weight cache directory. """ self.model = model self.max_model_len = max_model_len @@ -77,8 +140,6 @@ def __init__( # noqa: PLR0913 self.min_p = min_p self.max_tokens = max_tokens self.cache_dir = cache_dir - self._llm: LLM | None = None - self._sampling_params: SamplingParams | None = None self._final_max_model_len: int | None = None self._is_qwen3: bool = False @@ -93,16 +154,13 @@ def setup(self) -> None: msg = "vLLM is required for VLLMModel. Please install it: pip install vllm" raise ImportError(msg) - # Fetch max_model_len from user param or auto-detect from HuggingFace AutoConfig if self.max_model_len is not None: final_max_model_len = self.max_model_len else: final_max_model_len = get_max_model_len_from_config(self.model, cache_dir=self.cache_dir) - # Set tensor_parallel_size as user param or auto-detect from GPU count final_tp_size = self.tensor_parallel_size if self.tensor_parallel_size is not None else get_gpu_count() - # Set max_num_batched_tokens as user param or use default final_max_batched = self.max_num_batched_tokens llm_kwargs: dict[str, Any] = { @@ -160,33 +218,15 @@ def generate( self, prompts: list[str], ) -> list[str]: - """ - Generate text from prompts. + """Generate text from prompt strings (or chat message dicts). - Args: - prompts: List of prompt strings or list of message dicts - (for chat template). - - Returns: - List of generated text strings. - - Raises: - RuntimeError: If the model is not set up or generation fails. + Raises ``RuntimeError`` if the model is not set up or generation fails. """ if self._llm is None or self._sampling_params is None: msg = "Model not initialized. Call setup() first." raise RuntimeError(msg) - - try: - outputs = self._llm.generate( - prompts, - sampling_params=self._sampling_params, - use_tqdm=False, - ) - return [out.outputs[0].text if out.outputs else "" for out in outputs] - except (RuntimeError, ValueError, TypeError) as e: - msg = f"Error generating text: {e}" - raise RuntimeError(msg) from e + outputs = self._generate(prompts) + return [out.outputs[0].text if out.outputs else "" for out in outputs] def get_tokenizer(self) -> Any: # noqa: ANN401 """Get the tokenizer from the LLM instance.""" diff --git a/nemo_curator/pipeline/payload_lifecycle.py b/nemo_curator/pipeline/payload_lifecycle.py new file mode 100644 index 0000000000..b0df8c1ac0 --- /dev/null +++ b/nemo_curator/pipeline/payload_lifecycle.py @@ -0,0 +1,668 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: ANN401, ARG001, BLE001, PLR0913, PLR2004, S110, TRY004 + +"""Generic pipeline expansion for payload handle lifecycles.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from loguru import logger + +if TYPE_CHECKING: + from nemo_curator.stages.base import ProcessingStage + + +@dataclass(frozen=True) +class PayloadBindingSpec: + source_key: str + ref_key: str + waveform_key: str + sample_rate_key: str + num_samples_key: str + duration_key: str + materialize_stage_name: str + + +def expand_payload_lifecycle_stages( + stages: list[ProcessingStage], + config: Any, +) -> list[ProcessingStage]: + """Insert payload materialize/release helpers around logical stages. + + The lifecycle config is pipeline-level and backend-neutral. It keeps compute + stages visible to executors while adding only the mechanical stages that own + payload I/O and cleanup. + """ + + payload_cfg = _config_section(config, "payload_lifecycle") + if not bool(payload_cfg.get("enabled", False)): + return stages + + payload_release_stage_cls = _payload_release_stage_class() + materialize_idx, release_idx, consumers = _payload_lifecycle_positions(stages, payload_cfg) + run_id = _pipeline_run_id(config) + + reader = _last_manifest_reader(stages[: materialize_idx + 1]) + payload_specs = _payload_binding_specs(payload_cfg, stages=stages, consumers=consumers, reader=reader) + _configure_planned_source_segment_inputs(reader, payload_cfg, payload_specs, config) + _validate_payload_consumers(consumers, payload_specs) + _validate_single_segment_planner_owner( + reader, + consumers, + config=config, + ) + + materializers = [ + _build_payload_materializer(reader, spec, payload_cfg, config, run_id=run_id) for spec in payload_specs + ] + primary_spec = payload_specs[0] + release = payload_release_stage_cls( + name=str(payload_cfg.get("release_stage_name", "payload_release")), + payload_ref_key=primary_spec.ref_key, + waveform_key=primary_spec.waveform_key, + ) + + assembler = _post_release_payload_lifecycle_stage(config, reader, consumers, primary_spec, run_id=run_id) + execution_source = _payload_lifecycle_source_stage(reader) + extended_metrics = any(bool(getattr(stage, "extended_performance_metrics", False)) for stage in stages) + for helper in [*materializers, release, *([assembler] if assembler is not None else [])]: + helper.extended_performance_metrics = extended_metrics + + # Keep lifecycle bookkeeping out of unrelated pipelines. Only stages that + # can actually receive refs pay the recursive ref-scan cost, and only the + # global segmented path preserves terminal rows for downstream assembly. + for stage in stages[materialize_idx + 1 : release_idx + 1]: + stage._curator_tracks_payload_refs = True + if assembler is not None: + for stage in [*materializers, *stages[materialize_idx + 1 : release_idx + 1]]: + stage._curator_preserves_terminal_tasks = True + + expanded: list[ProcessingStage] = [] + for idx, stage in enumerate(stages): + expanded.append(execution_source if stage is reader else stage) + if idx == materialize_idx: + expanded.extend(materializers) + if idx == release_idx: + expanded.append(release) + if assembler is not None: + expanded.append(assembler) + logger.info( + "Expanded logical graph into payload lifecycle execution graph: {}", + " -> ".join(stage.name for stage in expanded), + ) + return expanded + + +def _payload_lifecycle_source_stage(reader: ProcessingStage | None) -> ProcessingStage | None: + """Let a modality-owned planner replace its logical reader at execution time.""" + if reader is None: + return None + builder = getattr(reader, "build_payload_lifecycle_source_stage", None) + if not callable(builder): + return reader + source_stage = builder() + if source_stage is None: + msg = f"{type(reader).__name__}.build_payload_lifecycle_source_stage() returned None" + raise TypeError(msg) + return source_stage + + +def _payload_lifecycle_positions( + stages: list[ProcessingStage], payload_cfg: dict[str, Any] +) -> tuple[int, int, list[ProcessingStage]]: + """Validate selectors and resolve lifecycle stage positions.""" + helpers = [stage.name for stage in stages if bool(getattr(stage, "_curator_pipeline_helper_stage", False))] + if helpers: + msg = ( + "Payload lifecycle configs must declare logical stages only. Do not list payload " + "materialization, payload release, or other pipeline helper stages explicitly. " + f"Remove implementation helper stage(s): {helpers}" + ) + raise ValueError(msg) + + materialize_after = _single_selector( + payload_cfg.get("materialize_after"), key="payload_lifecycle.materialize_after" + ) + release_after = _single_selector(payload_cfg.get("release_after"), key="payload_lifecycle.release_after") + consumer_selectors = _normalise_string_list(payload_cfg.get("consumers"), key="payload_lifecycle.consumers") + + materialize_idx = _find_stage_index(stages, materialize_after, key="payload_lifecycle.materialize_after") + release_idx = _find_stage_index(stages, release_after, key="payload_lifecycle.release_after") + if release_idx <= materialize_idx: + msg = "payload_lifecycle.release_after must come after payload_lifecycle.materialize_after" + raise ValueError(msg) + consumer_indices = [ + _find_stage_index(stages, selector, key="payload_lifecycle.consumers") for selector in consumer_selectors + ] + if any(idx <= materialize_idx or idx > release_idx for idx in consumer_indices): + msg = ( + "payload_lifecycle.consumers must appear after materialize_after and no later than release_after; " + f"got consumer indices {consumer_indices}, materialize_after={materialize_idx}, release_after={release_idx}" + ) + raise ValueError(msg) + return materialize_idx, release_idx, [stages[idx] for idx in consumer_indices] + + +def _payload_binding_specs( + payload_cfg: dict[str, Any], + *, + stages: list[ProcessingStage], + consumers: list[ProcessingStage], + reader: ProcessingStage | None, +) -> list[PayloadBindingSpec]: + payloads = payload_cfg.get("payloads") + if payloads: + payload_entries = _as_container(payloads) + if not isinstance(payload_entries, list): + msg = "payload_lifecycle.payloads must be a list of mappings" + raise TypeError(msg) + specs = [ + _payload_spec_from_mapping(dict(entry), payload_cfg=payload_cfg, index=idx) + for idx, entry in enumerate(payload_entries) + ] + else: + payload_keys = _normalise_string_list( + payload_cfg.get("payload_keys", ["audio_filepath"]), key="payload_lifecycle.payload_keys" + ) + specs = [ + _payload_spec_from_legacy_key( + source_key=source_key, + payload_cfg=payload_cfg, + stages=stages, + consumers=consumers, + reader=reader, + index=idx, + total=len(payload_keys), + ) + for idx, source_key in enumerate(payload_keys) + ] + _validate_unique([spec.source_key for spec in specs], "payload source key") + _validate_unique([spec.ref_key for spec in specs], "payload ref key") + _validate_unique([spec.waveform_key for spec in specs], "payload waveform key") + return specs + + +def _payload_spec_from_mapping( + entry: dict[str, Any], *, payload_cfg: dict[str, Any], index: int +) -> PayloadBindingSpec: + source_key = str( + entry.get("source_key") or entry.get("payload_key") or entry.get("audio_filepath_key") or "" + ).strip() + if not source_key: + msg = f"payload_lifecycle.payloads[{index}] requires source_key" + raise ValueError(msg) + return PayloadBindingSpec( + source_key=source_key, + ref_key=str(entry.get("ref_key") or _derived_key(source_key, "ref")), + waveform_key=str(entry.get("waveform_key") or _derived_key(source_key, "waveform")), + sample_rate_key=str(entry.get("sample_rate_key") or _derived_key(source_key, "sample_rate")), + num_samples_key=str(entry.get("num_samples_key") or _derived_key(source_key, "num_samples")), + duration_key=str(entry.get("duration_key") or payload_cfg.get("duration_key", "duration")), + materialize_stage_name=str( + entry.get("materialize_stage_name") or _materialize_stage_name(payload_cfg, index, source_key) + ), + ) + + +def _payload_spec_from_legacy_key( + *, + source_key: str, + payload_cfg: dict[str, Any], + stages: list[ProcessingStage], + consumers: list[ProcessingStage], + reader: ProcessingStage | None, + index: int, + total: int, +) -> PayloadBindingSpec: + if total == 1: + ref_key = str( + payload_cfg.get("ref_key") or _consumer_payload_key(consumers, "waveform_ref_key", "waveform_ref") + ) + waveform_key = str( + payload_cfg.get("waveform_key") or _consumer_payload_key(consumers, "waveform_key", "waveform") + ) + sample_rate_key = str( + payload_cfg.get("sample_rate_key") or _consumer_payload_key(consumers, "sample_rate_key", "sample_rate") + ) + num_samples_key = str( + payload_cfg.get("num_samples_key") or _first_attr(stages, "num_samples_key", "num_samples") + ) + duration_key = str( + payload_cfg.get("duration_key") + or _first_attr(stages, "duration_key", getattr(reader, "duration_key", "duration")) + ) + else: + ref_key = _list_or_derived(payload_cfg, "ref_keys", index, source_key, "ref") + waveform_key = _list_or_derived(payload_cfg, "waveform_keys", index, source_key, "waveform") + sample_rate_key = _list_or_derived(payload_cfg, "sample_rate_keys", index, source_key, "sample_rate") + num_samples_key = _list_or_derived(payload_cfg, "num_samples_keys", index, source_key, "num_samples") + duration_key = _list_or_default( + payload_cfg, "duration_keys", index, str(_first_attr(stages, "duration_key", "duration")) + ) + return PayloadBindingSpec( + source_key=source_key, + ref_key=ref_key, + waveform_key=waveform_key, + sample_rate_key=sample_rate_key, + num_samples_key=num_samples_key, + duration_key=duration_key, + materialize_stage_name=_materialize_stage_name(payload_cfg, index, source_key, total=total), + ) + + +def _validate_payload_consumers(consumers: list[ProcessingStage], payload_specs: list[PayloadBindingSpec]) -> None: + by_ref = {spec.ref_key: spec for spec in payload_specs} + for stage in consumers: + if not hasattr(stage, "resolve_payload_refs_for_batch"): + msg = ( + f"Payload consumer {stage.name!r} is not payload-aware. Stages that consume payload refs must " + "implement PayloadAwareStageMixin or an equivalent resolve_payload_refs_for_batch() contract." + ) + raise TypeError(msg) + bindings = _stage_payload_bindings(stage) + if not bindings: + msg = f"Payload consumer {stage.name!r} does not declare any payload ref bindings" + raise ValueError(msg) + for binding in bindings: + ref_key = str(binding["ref_key"]) + if ref_key not in by_ref: + msg = ( + f"Payload consumer {stage.name!r} declares ref_key={ref_key!r}, but payload_lifecycle " + f"materializes only {sorted(by_ref)}" + ) + raise ValueError(msg) + expected_waveform_key = by_ref[ref_key].waveform_key + waveform_key = str(binding.get("waveform_key") or "") + if waveform_key != expected_waveform_key: + msg = ( + f"Payload consumer {stage.name!r} must declare waveform_key={expected_waveform_key!r} " + f"for ref_key={ref_key!r}; got {waveform_key!r}" + ) + raise ValueError(msg) + + +def _stage_payload_bindings(stage: ProcessingStage) -> list[dict[str, str]]: + bindings = getattr(stage, "payload_bindings", None) + if callable(bindings): + bindings = bindings() + if bindings: + result = [] + for item in bindings: + if not isinstance(item, dict): + msg = f"{stage.name}.payload_bindings entries must be mappings" + raise TypeError(msg) + ref_key = str(item.get("ref_key") or "").strip() + waveform_key = str(item.get("waveform_key") or "").strip() + if ref_key and waveform_key: + result.append( + {**{str(k): str(v) for k, v in item.items()}, "ref_key": ref_key, "waveform_key": waveform_key} + ) + return result + ref_key = getattr(stage, "waveform_ref_key", None) + waveform_key = getattr(stage, "waveform_key", None) + if ref_key and waveform_key: + return [ + { + "ref_key": str(ref_key), + "waveform_key": str(waveform_key), + "sample_rate_key": str(getattr(stage, "sample_rate_key", "sample_rate")), + "num_samples_key": str(getattr(stage, "num_samples_key", "num_samples")), + } + ] + return [] + + +def _configure_planned_source_segment_inputs( + reader: ProcessingStage | None, + payload_cfg: dict[str, Any], + payload_specs: list[PayloadBindingSpec], + config: Any, +) -> None: + if reader is None or not bool(getattr(reader, "enable_global_bucketing", False)): + return + scheduler_cfg = _config_section(config, "global_audio_scheduler") + configured = scheduler_cfg.get("segment_input_keys", payload_cfg.get("segment_input_keys")) + segment_input_keys: list[str] = [] + if configured is not None: + segment_input_keys.extend(_normalise_string_list(configured, key="global_audio_scheduler.segment_input_keys")) + segment_input_keys.extend(spec.source_key for spec in payload_specs) + reader.segment_input_keys = _dedupe_strings(segment_input_keys) + reader.run_id = _pipeline_run_id(config) + if "parent_store_actor_name_prefix" in scheduler_cfg: + reader.parent_store_actor_name_prefix = str(scheduler_cfg["parent_store_actor_name_prefix"]) + + +def _validate_single_segment_planner_owner( + reader: ProcessingStage | None, + consumers: list[ProcessingStage], + *, + config: Any, +) -> None: + if reader is None or not bool(getattr(reader, "enable_global_bucketing", False)): + return + owner_stage = _single_selector(getattr(reader, "owner_stage", None), key="global_audio_scheduler.owner_stage") + matching_consumers = [stage for stage in consumers if owner_stage in _stage_match_idents(stage)] + if not matching_consumers: + available = sorted({ident for stage in consumers for ident in _stage_match_idents(stage)}) + msg = ( + "global_audio_scheduler.owner_stage must select exactly one stage listed in " + "payload_lifecycle.consumers. Global bucketing has a single planning owner; " + f"{owner_stage!r} was not found in payload consumers {available}." + ) + raise ValueError(msg) + if len(matching_consumers) > 1: + names = [stage.name for stage in matching_consumers] + msg = f"global_audio_scheduler.owner_stage must select exactly one payload consumer; matched {names}" + raise ValueError(msg) + _validate_planner_owner_has_largest_model_window(reader=reader, owner=matching_consumers[0], consumers=consumers) + reader.owner_stage = owner_stage + + +def _validate_planner_owner_has_largest_model_window( + *, + reader: ProcessingStage, + owner: ProcessingStage, + consumers: list[ProcessingStage], +) -> None: + owner_max_s = _required_positive_seconds(owner, "max_inference_duration_s") + consumer_max_s = [ + (stage.name, _required_positive_seconds(stage, "max_inference_duration_s")) for stage in consumers + ] + larger_consumers = [(name, max_s) for name, max_s in consumer_max_s if max_s > owner_max_s] + if larger_consumers: + details = ", ".join(f"{name}={value:g}s" for name, value in larger_consumers) + msg = ( + "global_audio_scheduler.owner_stage must select the payload consumer with the largest " + "max_inference_duration_s because the source planner emits one segment plan. " + f"Selected owner {owner.name!r} has max_inference_duration_s={owner_max_s:g}s, " + f"but larger consumer(s) exist: {details}." + ) + raise ValueError(msg) + + reader_max_s = _required_positive_seconds(reader, "max_inference_duration_s") + if abs(reader_max_s - owner_max_s) > 1e-6: + msg = ( + "ManifestReader(enable_global_bucketing=True).max_inference_duration_s must match the " + "selected owner stage's max_inference_duration_s. " + f"Reader has {reader_max_s:g}s, owner {owner.name!r} has {owner_max_s:g}s." + ) + raise ValueError(msg) + + +def _required_positive_seconds(stage: ProcessingStage, attr: str) -> float: + value = getattr(stage, attr, None) + if value is None: + msg = f"Global bucketing requires stage {stage.name!r} to define positive {attr}" + raise ValueError(msg) + return _positive_seconds(value, label=f"{stage.name}.{attr}") + + +def _optional_positive_seconds(stage: ProcessingStage, attr: str) -> float | None: + value = getattr(stage, attr, None) + if value is None: + return None + return _positive_seconds(value, label=f"{stage.name}.{attr}") + + +def _positive_seconds(value: Any, *, label: str) -> float: + try: + seconds = float(value) + except (TypeError, ValueError) as exc: + msg = f"{label} must be a positive number of seconds, got {value!r}" + raise TypeError(msg) from exc + if seconds <= 0: + msg = f"{label} must be > 0 seconds, got {seconds:g}" + raise ValueError(msg) + return seconds + + +def _build_payload_materializer( + reader: ProcessingStage | None, + spec: PayloadBindingSpec, + payload_cfg: dict[str, Any], + config: Any, + *, + run_id: str, +) -> ProcessingStage: + builder = getattr(reader, "build_payload_materialize_stage", None) + if not callable(builder): + reader_name = type(reader).__name__ if reader is not None else "" + msg = ( + "payload_lifecycle requires the source/reader stage to provide " + "build_payload_materialize_stage(). This keeps payload materialization " + f"modality-owned instead of hard-coding audio in the central planner; got {reader_name}." + ) + raise ValueError(msg) + return builder( + payload_spec=spec, + payload_config=payload_cfg, + pipeline_config=config, + run_id=run_id, + ) + + +def _post_release_payload_lifecycle_stage( + config: Any, + reader: ProcessingStage | None, + consumers: list[ProcessingStage], + primary_spec: PayloadBindingSpec, + *, + run_id: str, +) -> ProcessingStage | None: + if reader is None or not bool(getattr(reader, "enable_global_bucketing", False)): + return None + builder = getattr(reader, "build_payload_lifecycle_post_release_stage", None) + if not callable(builder): + msg = ( + "Global bucketing is enabled, but the source/reader stage does not provide " + "build_payload_lifecycle_post_release_stage(). The central payload lifecycle " + "planner only owns generic insertion order; modality-specific assembly must be " + f"provided by the planner stage, got {type(reader).__name__}." + ) + raise ValueError(msg) + return builder( + pipeline_config=config, + consumers=consumers, + primary_payload_spec=primary_spec, + run_id=run_id, + ) + + +def _pipeline_run_id(config: Any) -> str: + value = _config_get(config, "_curator_pipeline_run_id") + text = str(value or "").strip() + if not text: + msg = "Pipeline config is missing internal _curator_pipeline_run_id" + raise ValueError(msg) + return text + + +def _last_manifest_reader(stages: list[ProcessingStage]) -> ProcessingStage | None: + readers = [stage for stage in stages if _is_manifest_reader(stage)] + return readers[-1] if readers else None + + +def _is_manifest_reader(stage: ProcessingStage) -> bool: + return callable(getattr(stage, "build_payload_materialize_stage", None)) + + +def _payload_release_stage_class() -> type[ProcessingStage]: + from nemo_curator.stages.payload_lifecycle import PayloadReleaseStage + + return PayloadReleaseStage + + +def _config_section(config: Any, key: str) -> dict[str, Any]: + value = _config_get(config, key, {}) + if value is None: + return {} + value = _as_container(value) + if not isinstance(value, dict): + msg = f"{key} must be a mapping when configured, got {type(value).__name__}" + raise TypeError(msg) + return dict(value) + + +def _config_get(config: Any, key: str, default: Any = None) -> Any: + if config is None: + return default + if isinstance(config, dict): + return config.get(key, default) + get = getattr(config, "get", None) + if callable(get): + return get(key, default) + return default + + +def _as_container(value: Any) -> Any: + try: + from omegaconf import OmegaConf + + if OmegaConf.is_config(value): + return OmegaConf.to_container(value, resolve=True) + except Exception: + pass + return value + + +def _normalise_string_list(value: Any, *, key: str) -> list[str]: + value = _as_container(value) + if value is None: + return [] + items = [value] if isinstance(value, str) else list(value) + result = [] + for item in items: + if item is None: + continue + text = str(item).strip() + if text: + result.append(text) + if not result: + msg = f"{key} must contain at least one non-empty value" + raise ValueError(msg) + return result + + +def _dedupe_strings(values: list[str]) -> list[str]: + result: list[str] = [] + seen: set[str] = set() + for value in values: + text = str(value).strip() + if text and text not in seen: + seen.add(text) + result.append(text) + if not result: + msg = "At least one non-empty string is required" + raise ValueError(msg) + return result + + +def _single_selector(value: Any, *, key: str) -> str: + values = _normalise_string_list(value, key=key) + if len(values) != 1: + msg = f"{key} must contain exactly one stage selector, got {values}" + raise ValueError(msg) + return values[0] + + +def _find_stage_index(stages: list[ProcessingStage], selector: str, *, key: str) -> int: + matches = [idx for idx, stage in enumerate(stages) if selector in _stage_match_idents(stage)] + if not matches: + available = sorted({ident for stage in stages for ident in _stage_match_idents(stage)}) + msg = f"{key} selector {selector!r} did not match any stage. Available selectors: {available}" + raise ValueError(msg) + if len(matches) > 1: + names = [stages[idx].name for idx in matches] + msg = f"{key} selector {selector!r} matched multiple stages: {names}" + raise ValueError(msg) + return matches[0] + + +def _stage_match_idents(stage: ProcessingStage) -> set[str]: + stage_type = type(stage) + return { + ident + for ident in ( + getattr(stage, "_curator_stage_id", None), + getattr(stage, "name", None), + stage_type.__name__, + f"{stage_type.__module__}.{stage_type.__name__}", + ) + if ident + } + + +def _first_attr(stages: list[ProcessingStage], attr: str, default: Any) -> Any: + for stage in stages: + value = getattr(stage, attr, None) + if value not in (None, ""): + return value + return default + + +def _consumer_payload_key(stages: list[ProcessingStage], attr: str, default: str) -> str: + return str(_first_attr(stages, attr, default)) + + +def _list_or_derived(payload_cfg: dict[str, Any], key: str, index: int, source_key: str, suffix: str) -> str: + values = _as_container(payload_cfg.get(key)) + if values: + values = list(values) + if index >= len(values): + msg = f"payload_lifecycle.{key} must contain one value for each payload key" + raise ValueError(msg) + return str(values[index]) + return _derived_key(source_key, suffix) + + +def _list_or_default(payload_cfg: dict[str, Any], key: str, index: int, default: str) -> str: + values = _as_container(payload_cfg.get(key)) + if values: + values = list(values) + if index >= len(values): + msg = f"payload_lifecycle.{key} must contain one value for each payload key" + raise ValueError(msg) + return str(values[index]) + return default + + +def _derived_key(source_key: str, suffix: str) -> str: + stem = re.sub(r"(_filepath|_path|_file)$", "", source_key) + return f"{stem}_{suffix}" + + +def _materialize_stage_name( + payload_cfg: dict[str, Any], index: int, source_key: str, *, total: int | None = None +) -> str: + base = str(payload_cfg.get("materialize_stage_name", "audio_payload_materialize")) + if total in (None, 1): + return base + return f"{base}_{index}_{re.sub(r'[^A-Za-z0-9_]+', '_', source_key)}" + + +def _validate_unique(values: list[str], label: str) -> None: + seen: set[str] = set() + duplicates: set[str] = set() + for value in values: + if value in seen: + duplicates.add(value) + else: + seen.add(value) + if duplicates: + msg = f"Duplicate {label}(s): {sorted(duplicates)}" + raise ValueError(msg) diff --git a/nemo_curator/pipeline/payload_refs.py b/nemo_curator/pipeline/payload_refs.py new file mode 100644 index 0000000000..5ca762e269 --- /dev/null +++ b/nemo_curator/pipeline/payload_refs.py @@ -0,0 +1,263 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: ANN401, BLE001, EM102 + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from loguru import logger + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from nemo_curator.tasks import Task + +_DEFAULT_LEASE_TTL_S = 3600.0 + + +def _ray_get(obj: Any) -> Any: + import ray + + return ray.get(obj) + + +@dataclass(frozen=True) +class PayloadRef: + payload_id: str + owner_node_id: str + store_actor_name: str + admission_actor_name: str + amount_bytes: int + sample_rate: int + num_samples: int + dtype: str = "float32" + lease_ttl_s: float = _DEFAULT_LEASE_TTL_S + actor_namespace: str | None = None + + +def _get_named_actor(name: str, namespace: str | None = None) -> Any: + import ray + + if namespace: + return ray.get_actor(name, namespace=namespace) + return ray.get_actor(name) + + +def resolve_payload_ref(payload_ref: PayloadRef) -> Any: + heartbeat_payload_ref(payload_ref) + store = _get_named_actor(payload_ref.store_actor_name, payload_ref.actor_namespace) + return _ray_get(store.get.remote(payload_ref.payload_id, payload_ref.lease_ttl_s)) + + +def heartbeat_payload_ref(payload_ref: PayloadRef) -> None: + admission = _get_named_actor(payload_ref.admission_actor_name, payload_ref.actor_namespace) + if not _ray_get( + admission.heartbeat.remote( + payload_ref.owner_node_id, + payload_ref.payload_id, + payload_ref.lease_ttl_s, + ) + ): + raise KeyError( + f"Payload admission lease {payload_ref.payload_id} is no longer present in " + f"{payload_ref.admission_actor_name}" + ) + store = _get_named_actor(payload_ref.store_actor_name, payload_ref.actor_namespace) + if not _ray_get(store.pin.remote(payload_ref.payload_id, payload_ref.lease_ttl_s)): + raise KeyError(f"Payload {payload_ref.payload_id} is no longer present in {payload_ref.store_actor_name}") + + +def heartbeat_payload_refs_batched(payload_refs: Sequence[PayloadRef]) -> None: + """Refresh payload leases with one RPC per admission/store actor. + + The singular :func:`heartbeat_payload_ref` contract remains unchanged for + existing callers. This opt-in batched path is used by payload-aware stages + that know their actors provide ``heartbeat_many`` and ``pin_many``. + """ + refs = _unique_payload_refs(payload_refs) + if not refs: + return + + admission_groups = _group_payload_refs(refs, actor_name=lambda ref: ref.admission_actor_name) + admission_calls = [] + admission_group_refs = [] + for (actor_name, namespace), grouped_refs in admission_groups.items(): + actor = _get_named_actor(actor_name, namespace) + admission_calls.append( + actor.heartbeat_many.remote([(ref.owner_node_id, ref.payload_id, ref.lease_ttl_s) for ref in grouped_refs]) + ) + admission_group_refs.append(grouped_refs) + for grouped_refs, results in zip(admission_group_refs, _ray_get(admission_calls), strict=True): + for ref, present in zip(grouped_refs, results, strict=True): + if not present: + raise KeyError( + f"Payload admission lease {ref.payload_id} is no longer present in {ref.admission_actor_name}" + ) + + store_groups = _group_payload_refs(refs, actor_name=lambda ref: ref.store_actor_name) + store_calls = [] + store_group_refs = [] + for (actor_name, namespace), grouped_refs in store_groups.items(): + actor = _get_named_actor(actor_name, namespace) + store_calls.append(actor.pin_many.remote([(ref.payload_id, ref.lease_ttl_s) for ref in grouped_refs])) + store_group_refs.append(grouped_refs) + for grouped_refs, results in zip(store_group_refs, _ray_get(store_calls), strict=True): + for ref, present in zip(grouped_refs, results, strict=True): + if not present: + raise KeyError(f"Payload {ref.payload_id} is no longer present in {ref.store_actor_name}") + + +def resolve_payload_refs_batched( + payload_refs: Sequence[PayloadRef], + *, + max_batch_bytes: int | None = None, +) -> list[Any]: + """Resolve refs in input order using byte-bounded, actor-grouped RPCs.""" + refs = list(payload_refs) + if not refs: + return [] + if max_batch_bytes is not None and (isinstance(max_batch_bytes, bool) or max_batch_bytes <= 0): + msg = "max_batch_bytes must be positive when set" + raise ValueError(msg) + + resolved_by_key: dict[tuple[str | None, str, str], Any] = {} + for batch in _byte_bounded_payload_batches(_unique_payload_refs(refs), max_batch_bytes): + heartbeat_payload_refs_batched(batch) + store_groups = _group_payload_refs(batch, actor_name=lambda ref: ref.store_actor_name) + calls = [] + grouped = [] + for (actor_name, namespace), grouped_refs in store_groups.items(): + actor = _get_named_actor(actor_name, namespace) + calls.append(actor.get_many.remote([(ref.payload_id, ref.lease_ttl_s) for ref in grouped_refs])) + grouped.append(grouped_refs) + for grouped_refs, payloads in zip(grouped, _ray_get(calls), strict=True): + for ref, payload in zip(grouped_refs, payloads, strict=True): + resolved_by_key[_payload_ref_key(ref)] = payload + return [resolved_by_key[_payload_ref_key(ref)] for ref in refs] + + +def _payload_ref_key(payload_ref: PayloadRef) -> tuple[str | None, str, str]: + return payload_ref.actor_namespace, payload_ref.store_actor_name, payload_ref.payload_id + + +def _unique_payload_refs(payload_refs: Sequence[PayloadRef]) -> list[PayloadRef]: + unique: dict[tuple[str | None, str, str], PayloadRef] = {} + for payload_ref in payload_refs: + unique.setdefault(_payload_ref_key(payload_ref), payload_ref) + return list(unique.values()) + + +def _group_payload_refs( + payload_refs: Sequence[PayloadRef], + *, + actor_name: Callable[[PayloadRef], str], +) -> dict[tuple[str, str | None], list[PayloadRef]]: + groups: dict[tuple[str, str | None], list[PayloadRef]] = defaultdict(list) + for payload_ref in payload_refs: + groups[(actor_name(payload_ref), payload_ref.actor_namespace)].append(payload_ref) + return groups + + +def _byte_bounded_payload_batches( + payload_refs: Sequence[PayloadRef], + max_batch_bytes: int | None, +) -> list[list[PayloadRef]]: + if max_batch_bytes is None: + return [list(payload_refs)] + batches: list[list[PayloadRef]] = [] + current: list[PayloadRef] = [] + current_bytes = 0 + for payload_ref in payload_refs: + amount = max(0, int(payload_ref.amount_bytes)) + if current and current_bytes + amount > max_batch_bytes: + batches.append(current) + current = [] + current_bytes = 0 + current.append(payload_ref) + current_bytes += amount + if current: + batches.append(current) + return batches + + +def release_payload_ref(payload_ref: PayloadRef) -> None: + try: + store = _get_named_actor(payload_ref.store_actor_name, payload_ref.actor_namespace) + released_bytes = int(_ray_get(store.release.remote(payload_ref.payload_id))) + except Exception: + released_bytes = int(payload_ref.amount_bytes) + if released_bytes <= 0: + released_bytes = int(payload_ref.amount_bytes) + try: + admission = _get_named_actor(payload_ref.admission_actor_name, payload_ref.actor_namespace) + _ray_get( + admission.release.remote( + payload_ref.owner_node_id, + payload_ref.payload_id, + released_bytes, + ) + ) + except Exception: + logger.debug("Failed to release payload admission tokens for {}", payload_ref.payload_id) + + +_DROP_PAYLOAD_REF = object() + + +def strip_payload_refs(value: Any) -> Any: + stripped = _strip_payload_refs(value) + if stripped is _DROP_PAYLOAD_REF: + return None + return stripped + + +def _strip_payload_refs(value: Any) -> Any: + if isinstance(value, PayloadRef): + return _DROP_PAYLOAD_REF + if isinstance(value, dict): + return _strip_payload_ref_dict(value) + if isinstance(value, list): + return _strip_payload_ref_list(value) + if isinstance(value, tuple): + return tuple(_strip_payload_ref_list(value)) + if isinstance(value, set): + return set(_strip_payload_ref_list(value)) + return value + + +def _strip_payload_ref_dict(value: dict[Any, Any]) -> dict[Any, Any]: + result: dict[Any, Any] = {} + for key, item in value.items(): + stripped = _strip_payload_refs(item) + if stripped is not _DROP_PAYLOAD_REF: + result[key] = stripped + return result + + +def _strip_payload_ref_list(value: Any) -> list[Any]: + result: list[Any] = [] + for item in value: + stripped = _strip_payload_refs(item) + if stripped is not _DROP_PAYLOAD_REF: + result.append(stripped) + return result + + +def iter_payload_refs(value: Any) -> list[PayloadRef]: + refs: list[PayloadRef] = [] + if isinstance(value, PayloadRef): + refs.append(value) + elif isinstance(value, dict): + for item in value.values(): + refs.extend(iter_payload_refs(item)) + elif isinstance(value, (list, tuple, set)): + for item in value: + refs.extend(iter_payload_refs(item)) + return refs + + +def task_payload_refs(task: Task) -> list[PayloadRef]: + return iter_payload_refs(task.data) if isinstance(task.data, dict) else [] diff --git a/nemo_curator/pipeline/pipeline.py b/nemo_curator/pipeline/pipeline.py index 961ae33c6f..90e9059699 100644 --- a/nemo_curator/pipeline/pipeline.py +++ b/nemo_curator/pipeline/pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import uuid from typing import Any from loguru import logger @@ -24,11 +25,7 @@ def assign_root_task_ids(initial_tasks: list[Task]) -> list[Task]: """Assign root ``task_id``s to user-provided initial tasks. - Every task in a run descends from the implicit root ``"0"`` (the id of - :class:`EmptyTask`). User-provided initial tasks are its direct - children, so they get ``"0_0"``, ``"0_1"``, … ``EmptyTask`` instances - are skipped (already ``"0"``). All downstream ``task_id`` assignment - happens in ``BaseStageAdapter``. + Every non-sentinel task is rooted under ``"0"`` exactly as on main. NOTE: we deliberately use the positional index here, NOT ``get_deterministic_id()``, even for content-bearing tasks like @@ -65,8 +62,16 @@ def __init__( """ self.name = name self.description = description - self.stages: list[ProcessingStage] = stages or [] + self._logical_stages: list[ProcessingStage] = stages or [] + self.stages: list[ProcessingStage] = self._logical_stages + # Preserve main's public config identity and never inject framework + # state into the caller's mapping. Expansion receives an ephemeral copy. self.config = config or {} + self._built = False + self._default_source_stage: ProcessingStage | None = None + self._default_sink_stage: ProcessingStage | None = None + self._planned_stage_snapshot: list[ProcessingStage] = [] + self._curator_pipeline_run_id = uuid.uuid4().hex def add_stage(self, stage: ProcessingStage) -> "Pipeline": """Add a stage to the pipeline. @@ -81,7 +86,12 @@ def add_stage(self, stage: ProcessingStage) -> "Pipeline": msg = f"Stage must be a ProcessingStage, got {type(stage)}" raise TypeError(msg) - self.stages.append(stage) + self._sync_public_stage_mutations() + self._clear_default_source_sink_roles() + self._logical_stages.append(stage) + self.stages = self._logical_stages + self._built = False + self._planned_stage_snapshot = [] logger.info(f"Added stage '{stage.name}' to pipeline '{self.name}'") return self @@ -91,27 +101,102 @@ def build(self) -> None: Raises: ValueError: If the pipeline has no stages """ + self._sync_public_stage_mutations() + if self._built: + logger.info(f"Pipeline '{self.name}' is already planned; reusing execution graph") + return + logger.info(f"Planning pipeline: {self.name}") + self._clear_default_source_sink_roles() # 1. Validate pipeline has stages - if not self.stages: + if not self._logical_stages: msg = f"Pipeline '{self.name}' has no stages" raise ValueError(msg) - # 2. Decompose composite stages into execution stages - execution_stages, decomposition_info = self._decompose_stages(self.stages) + # 2. Expand pipeline-level graph rules before composite decomposition. + planned_stages = self._expand_pipeline_graph(list(self._logical_stages)) + + # 3. Decompose composite stages into execution stages + execution_stages, decomposition_info = self._decompose_stages(planned_stages) self.stages = execution_stages self.decomposition_info = decomposition_info - # 3. Source / sink defaults: at most one stage may be explicitly + # 4. Source / sink defaults: at most one stage may be explicitly # marked; if none, the first stage is the source and the last is # the sink. The source flag activates content-based ids in the # default ``process_batch``; the sink flag is used by the # resumability layer in a follow-up PR. self._assign_source_sink_roles() + self._built = True + self._planned_stage_snapshot = list(self.stages) + + def _expand_pipeline_graph(self, stages: list[ProcessingStage]) -> list[ProcessingStage]: + """Apply generic pipeline-level graph expansion rules.""" + for stage in stages: + stage_state = getattr(stage, "__dict__", None) + if stage_state is not None: + stage_state.pop("_curator_tracks_payload_refs", None) + stage_state.pop("_curator_preserves_terminal_tasks", None) + payload_cfg = self.config.get("payload_lifecycle") + payload_cfg_get = getattr(payload_cfg, "get", None) + if not callable(payload_cfg_get) or not bool(payload_cfg_get("enabled", False)): + return stages + from nemo_curator.pipeline.payload_lifecycle import expand_payload_lifecycle_stages + + expansion_config = dict(self.config) + expansion_config["_curator_pipeline_run_id"] = self._curator_pipeline_run_id + return expand_payload_lifecycle_stages(stages, expansion_config) + + def _sync_public_stage_mutations(self) -> None: + """Preserve the historical public ``stages`` list mutation behavior. + + ``_logical_stages`` is the canonical source for graph expansion, but + existing user code may still mutate ``pipeline.stages`` directly. Treat + those mutations as logical graph edits before planning instead of + silently ignoring them. + """ + if self._built: + if self.stages == self._planned_stage_snapshot: + return + logger.warning( + "Pipeline.stages was mutated after build(); treating the current public stages list " + "as the new logical graph. Prefer Pipeline.add_stage() for future code." + ) + self._clear_default_source_sink_roles() + self._logical_stages = list(self.stages) + self._built = False + self._planned_stage_snapshot = [] + return + + if self.stages != self._logical_stages: + logger.warning( + "Pipeline.stages was mutated directly; syncing it into the logical graph. " + "Prefer Pipeline.add_stage() for future code." + ) + self._clear_default_source_sink_roles() + self._logical_stages = list(self.stages) + self._planned_stage_snapshot = [] + + def _clear_default_source_sink_roles(self) -> None: + """Clear source/sink roles that were assigned by a previous build. + + Stage instances are reused when a pipeline is replanned. Without + clearing defaults, a role assigned to the previous execution graph can + look like an explicit user mark on the next build. + """ + if self._default_source_stage is not None: + self._default_source_stage.is_source_stage = False + self._default_source_stage = None + if self._default_sink_stage is not None: + self._default_sink_stage.is_sink_stage = False + self._default_sink_stage = None def _assign_source_sink_roles(self) -> None: + self._default_source_stage = None + self._default_sink_stage = None + explicit_sources = [s for s in self.stages if s.is_source_stage] if len(explicit_sources) > 1: names = [s.name for s in explicit_sources] @@ -119,6 +204,7 @@ def _assign_source_sink_roles(self) -> None: raise ValueError(msg) if not explicit_sources: self.stages[0].is_source_stage = True + self._default_source_stage = self.stages[0] explicit_sinks = [s for s in self.stages if s.is_sink_stage] if len(explicit_sinks) > 1: @@ -127,6 +213,7 @@ def _assign_source_sink_roles(self) -> None: raise ValueError(msg) if not explicit_sinks: self.stages[-1].is_sink_stage = True + self._default_sink_stage = self.stages[-1] def _decompose_stages( self, stages: list[ProcessingStage | CompositeStage] @@ -150,10 +237,8 @@ def _decompose_stages( sub_stages = stage.decompose_and_apply_with() if isinstance(stage, CompositeStage) else [stage] if len(sub_stages) > 1: - # This was a composite stage logger.info(f"Decomposing composite stage: {stage.name}") - # Validate that decomposed stages are not composite for sub_stage in sub_stages: if isinstance(sub_stage, CompositeStage) and len(sub_stage.decompose()) > 1: msg = ( diff --git a/nemo_curator/pipeline/prefetch.py b/nemo_curator/pipeline/prefetch.py new file mode 100644 index 0000000000..377bb5991f --- /dev/null +++ b/nemo_curator/pipeline/prefetch.py @@ -0,0 +1,88 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Backend-neutral, bounded one-item lookahead prefetching.""" + +from __future__ import annotations + +from concurrent.futures import Future, ThreadPoolExecutor +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator + +_WorkT = TypeVar("_WorkT") +_ValueT = TypeVar("_ValueT") + + +class BoundedOneAheadPrefetchIterator(Generic[_WorkT, _ValueT]): + """Load work in order while overlapping at most one next item. + + ``max_inflight_bytes`` bounds the estimated bytes retained by the current + value and its one prefetched successor. An individual item larger than the + bound is still loaded synchronously so callers can handle or reject it. + The helper does not know about Ray, payload refs, audio, or model adapters. + """ + + def __init__( + self, + work: Iterable[_WorkT], + *, + loader: Callable[[_WorkT], _ValueT], + size_bytes: Callable[[_WorkT], int], + max_inflight_bytes: int, + thread_name_prefix: str = "curator-prefetch", + ) -> None: + if int(max_inflight_bytes) <= 0: + msg = "max_inflight_bytes must be positive" + raise ValueError(msg) + self._work = work + self._loader = loader + self._size_bytes = size_bytes + self._max_inflight_bytes = int(max_inflight_bytes) + self._thread_name_prefix = thread_name_prefix + + def __iter__(self) -> Iterator[tuple[_WorkT, _ValueT]]: + work_iter = iter(self._work) + try: + current_work = next(work_iter) + except StopIteration: + return + + current_value = self._loader(current_work) + current_size = self._checked_size(current_work) + executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=self._thread_name_prefix) + pending: Future[_ValueT] | None = None + try: + while True: + try: + next_work = next(work_iter) + except StopIteration: + yield current_work, current_value + return + + next_size = self._checked_size(next_work) + if current_size + next_size <= self._max_inflight_bytes: + pending = executor.submit(self._loader, next_work) + + yield current_work, current_value + + if pending is None: + next_value = self._loader(next_work) + else: + next_value = pending.result() + pending = None + current_work = next_work + current_value = next_value + current_size = next_size + finally: + if pending is not None: + pending.cancel() + executor.shutdown(wait=True, cancel_futures=True) + + def _checked_size(self, work: _WorkT) -> int: + size = int(self._size_bytes(work)) + if size < 0: + msg = "size_bytes must return a non-negative value" + raise ValueError(msg) + return size diff --git a/nemo_curator/pipelines/__init__.py b/nemo_curator/pipelines/__init__.py new file mode 100644 index 0000000000..99d959496f --- /dev/null +++ b/nemo_curator/pipelines/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Packaged Curator pipeline entry points.""" diff --git a/nemo_curator/pipelines/audio/__init__.py b/nemo_curator/pipelines/audio/__init__.py new file mode 100644 index 0000000000..d3d83dfd77 --- /dev/null +++ b/nemo_curator/pipelines/audio/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Audio pipeline entry points.""" diff --git a/nemo_curator/pipelines/audio/qwen_omni_inprocess.py b/nemo_curator/pipelines/audio/qwen_omni_inprocess.py new file mode 100644 index 0000000000..7ff9c4de81 --- /dev/null +++ b/nemo_curator/pipelines/audio/qwen_omni_inprocess.py @@ -0,0 +1,351 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: ANN202, ANN401, C901 + +"""Hydra entry point for the Granary v2 Qwen-Omni in-process pipeline. + +The raw-audio Hydra config declares logical stages only and uses the +``payload_lifecycle`` section to insert waveform materialization/release helper +stages before executor planning. Compute stages stay visible to Ray Data/Xenna, +so heterogeneous GPU stages keep their own resource contracts. Stage entries +may carry ``stage_with``/``with_`` metadata applied after instantiation to set +resource, batch-size, and composite worker specs. +Secrets are redacted before logging. + +Hugging Face credentials are NOT handled here: weights download on remote Ray +workers (``ASRStage.setup_on_node`` -> ``prefetch_weights``), so a token in this +driver would not propagate. For gated models set ``HF_TOKEN``/``HF_HOME`` in the +worker environment (cluster env or executor ``runtime_env``). +""" + +import importlib +import os +import time +from typing import Any + +os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") +os.environ.setdefault("VLLM_LOGGING_LEVEL", "ERROR") + +import hydra +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.base import CompositeStage, ProcessingStage +from nemo_curator.stages.resources import Resources + +_EXECUTOR_FACTORIES = { + "xenna": "nemo_curator.backends.xenna:XennaExecutor", + "ray_data": "nemo_curator.backends.ray_data:RayDataExecutor", +} + +_SECRET_KEY_NAMES = { + "access_key", + "access_token", + "api_key", + "auth_token", + "bearer_token", + "credential", + "credentials", + "password", + "passwd", + "secret", + "secret_key", + "token", +} +_SECRET_KEY_PARTS = ( + "access_key", + "api_key", + "auth_token", + "bearer", + "credential", + "password", + "passwd", + "secret", +) + + +def _as_container(value: Any, *, resolve: bool = True) -> Any: + if OmegaConf.is_config(value): + return OmegaConf.to_container(value, resolve=resolve) + return value + + +def _normalise_name_set(value: Any) -> set[str] | None: + """Return selected stage names, or ``None`` for all stages.""" + value = _as_container(value) + if value is None: + return None + if isinstance(value, str): + if value.lower() == "all": + return None + items = [part.strip() for part in value.split(",")] + else: + items = [str(item).strip() for item in value] + names = {item for item in items if item} + return None if any(item.lower() == "all" for item in names) else names + + +_ALLOWED_RESOURCES_TARGETS = frozenset( + { + "nemo_curator.stages.resources.Resources", + "Resources", + } +) +_HYDRA_RESOURCE_META_KEYS = frozenset({"_target_", "_recursive_", "_convert_"}) + + +def _instantiate_resources(value: Any) -> Resources: + """Build a ``Resources`` from a Hydra ``stage_with`` override. + + Never call open-ended ``hydra.utils.instantiate`` here: a malicious or + mistyped ``_target_`` in YAML could execute arbitrary code. Only the + canonical ``Resources`` dataclass is accepted. + """ + if isinstance(value, Resources): + return value + if OmegaConf.is_config(value) or isinstance(value, dict): + cfg = value if OmegaConf.is_config(value) else OmegaConf.create(value) + raw = OmegaConf.to_container(cfg, resolve=True) + if not isinstance(raw, dict): + msg = f"Invalid resources override: {value!r}" + raise TypeError(msg) + target = raw.get("_target_") + if target is not None and str(target) not in _ALLOWED_RESOURCES_TARGETS: + msg = f"resources overrides may only target nemo_curator.stages.resources.Resources; got {target!r}" + raise ValueError(msg) + fields = {key: item for key, item in raw.items() if key not in _HYDRA_RESOURCE_META_KEYS} + return Resources(**fields) + msg = f"Invalid resources override: {value!r}" + raise TypeError(msg) + + +def _normalise_with_kwargs(value: Any) -> dict[str, Any]: + kwargs = _as_container(value) or {} + if not isinstance(kwargs, dict): + msg = f"stage_with entries must be mappings, got {type(kwargs).__name__}" + raise TypeError(msg) + kwargs = dict(kwargs) + if "resources" in kwargs: + kwargs["resources"] = _instantiate_resources(kwargs["resources"]) + return kwargs + + +def _apply_stage_with(stage: ProcessingStage, value: Any) -> ProcessingStage: + """Apply optional ``stage_with`` metadata after Hydra construction.""" + if value is None: + return stage + value = _as_container(value) + if isinstance(stage, CompositeStage): + return stage.with_({name: _normalise_with_kwargs(kwargs) for name, kwargs in value.items()}) + return stage.with_(**_normalise_with_kwargs(value)) + + +def _target_idents(stage_id: str | None, target: str | None) -> set[str]: + return { + ident + for ident in ( + stage_id, + target, + target.rsplit(".", 1)[-1] if target else None, + ) + if ident + } + + +def _configured_stage_entries(cfg: DictConfig) -> Any: + if "stages" in cfg and cfg.stages: + return cfg.stages + if "processors" in cfg and cfg.processors: + logger.warning("Using legacy 'processors:' config key; prefer 'stages:' for Qwen-Omni.") + return cfg.processors + return None + + +def _is_secret_key(key: str) -> bool: + key_lower = key.lower() + return ( + key_lower in _SECRET_KEY_NAMES + or key_lower.endswith(("_token", "_secret", "_password")) + or any(part in key_lower for part in _SECRET_KEY_PARTS) + ) + + +def _redact_secret_values(value: Any) -> Any: + value = _as_container(value) + if isinstance(value, dict): + return { + key: ("" if _is_secret_key(str(key)) and item not in (None, "") else _redact_secret_values(item)) + for key, item in value.items() + } + if isinstance(value, list): + return [_redact_secret_values(item) for item in value] + return value + + +def _safe_config_yaml(cfg: DictConfig) -> str: + redacted = _redact_secret_values(cfg) + return OmegaConf.to_yaml(OmegaConf.create(redacted)) + + +def _instantiate_configured_stages(cfg: DictConfig) -> list[ProcessingStage]: + stage_entries = _configured_stage_entries(cfg) + if not stage_entries: + msg = ( + "qwen_omni_inprocess requires a Hydra 'stages:' list. " + "There is no implicit Granary-v2 stage fallback; every stage must be listed explicitly." + ) + raise ValueError(msg) + + run_set = _normalise_name_set(cfg.get("stages_to_run", cfg.get("processors_to_run", "all"))) + skip_set = _normalise_name_set(cfg.get("stages_to_skip", cfg.get("processors_to_skip", []))) or set() + available: set[str] = set() + stages: list[ProcessingStage] = [] + + for stage_cfg in stage_entries: + raw_unresolved = _as_container(stage_cfg, resolve=False) + if not isinstance(raw_unresolved, dict): + msg = f"Each stage entry must be a mapping, got {type(raw_unresolved).__name__}" + raise TypeError(msg) + + stage_id = raw_unresolved.get("stage_id", raw_unresolved.get("processor_id", raw_unresolved.get("id"))) + enabled = bool(stage_cfg.get("enabled", True)) + target = str(raw_unresolved.get("_target_", "")) + idents = _target_idents(stage_id, target) + available.update(idents) + selected = enabled + if selected and run_set is not None: + selected = bool(idents & run_set) + if selected and idents & skip_set: + selected = False + if selected: + raw = _as_container(stage_cfg, resolve=True) + raw.pop("stage_id", raw.pop("processor_id", raw.pop("id", None))) + raw.pop("enabled", None) + stage_with = raw.pop("stage_with", raw.pop("with_", None)) + stage = hydra.utils.instantiate(OmegaConf.create(raw)) + stage = _apply_stage_with(stage, stage_with) + stage.extended_performance_metrics = bool(cfg.get("extended_performance_metrics", True)) + if stage_id: + stage._curator_stage_id = str(stage_id) + stages.append(stage) + logger.info("Enabled stage {} ({})", stage_id or stage.name, type(stage).__name__) + else: + logger.info("Skipped stage {} ({})", stage_id or target.rsplit(".", 1)[-1], target) + + unknown_run = run_set - available if run_set is not None else set() + unknown_skip = skip_set - available + if unknown_run or unknown_skip: + details = [] + if unknown_run: + details.append(f"stages_to_run={sorted(unknown_run)}") + if unknown_skip: + details.append(f"stages_to_skip={sorted(unknown_skip)}") + msg = f"Unknown stage selector(s): {', '.join(details)}. Available: {sorted(available)}" + raise ValueError(msg) + if not stages: + msg = "No stages selected; check enabled/stages_to_run/stages_to_skip" + raise ValueError(msg) + return stages + + +def build_granary_v2_pipeline(cfg: DictConfig) -> Pipeline: + """Construct the Granary v2 stage chain from the logical Hydra stage list.""" + return Pipeline( + name="qwen_omni_inference", + stages=_instantiate_configured_stages(cfg), + config=OmegaConf.to_container(cfg, resolve=True), + ) + + +def _xenna_executor_config(cfg: DictConfig) -> dict[str, Any]: + xenna_cfg = cfg.get("xenna", {}) + if not isinstance(xenna_cfg, (dict, DictConfig)): + xenna_cfg = {} + + def _get(key: str, default: Any) -> Any: + return xenna_cfg.get(key, default) + + return { + **_executor_observability_config(cfg), + "execution_mode": cfg.get("execution_mode", "streaming"), + "autoscale_interval_s": cfg.get("autoscale_interval_s", 180), + "logging_interval": _get("logging_interval", 60), + "actor_pool_verbosity_level": _get("actor_pool_verbosity_level", "NONE"), + "monitoring_verbosity_level": _get("monitoring_verbosity_level", "NONE"), + "autoscaler_verbosity_level": _get("autoscaler_verbosity_level", "NONE"), + "executor_verbosity_level": _get("executor_verbosity_level", "NONE"), + "log_worker_allocation_layout": _get("log_worker_allocation_layout", False), + } + + +def _executor_observability_config(cfg: DictConfig) -> dict[str, Any]: + """Opt the Qwen benchmark pipeline into run-level hardware telemetry.""" + return { + "pipeline_hardware_sampler_enabled": bool(cfg.get("pipeline_hardware_sampler_enabled", True)), + "pipeline_hardware_sampler_interval_s": float(cfg.get("pipeline_hardware_sampler_interval_s", 0.5)), + "pipeline_hardware_sampler_startup_timeout_s": float( + cfg.get("pipeline_hardware_sampler_startup_timeout_s", 5.0) + ), + "pipeline_hardware_sampler_stop_timeout_s": float(cfg.get("pipeline_hardware_sampler_stop_timeout_s", 10.0)), + } + + +def _create_executor(cfg: DictConfig): + backend = cfg.get("backend", "ray_data") + if backend not in _EXECUTOR_FACTORIES: + msg = f"Unknown backend '{backend}'. Choose from: {list(_EXECUTOR_FACTORIES)}" + raise ValueError(msg) + + module_path, class_name = _EXECUTOR_FACTORIES[backend].rsplit(":", 1) + executor_cls = getattr(importlib.import_module(module_path), class_name) + logger.info(f"Using backend: {backend}") + + if backend == "xenna": + return executor_cls(config=_xenna_executor_config(cfg)) + if cfg.get("execution_mode") not in (None, "streaming"): + logger.info("execution_mode={} is Xenna-only and is ignored by Ray Data", cfg.get("execution_mode")) + return executor_cls(config=_executor_observability_config(cfg)) + + +@hydra.main(version_base=None) +def main(cfg: DictConfig) -> None: + """Hydra entry point for the Granary v2 Qwen-Omni pipeline.""" + logger.info(f"Hydra config:\n{_safe_config_yaml(cfg)}") + + pipeline = build_granary_v2_pipeline(cfg) + logger.info(f"Pipeline: {pipeline.describe()}") + + executor = _create_executor(cfg) + + t0 = time.time() + pipeline.run(executor=executor) + elapsed = time.time() - t0 + output_dir = cfg.get("output_dir") or cfg.get("workspace_dir", "./output") + logger.info(f"Pipeline finished in {elapsed / 60:.1f} min. Output: {output_dir}") + + perf_summary_path = os.path.join(output_dir, "perf_summary.json") + if os.path.isfile(perf_summary_path): + import json as _json + + with open(perf_summary_path) as _f: + perf_data = _json.load(_f) + perf_data["pipeline_duration_s"] = elapsed + with open(perf_summary_path, "w") as _f: + _json.dump(perf_data, _f, indent=2, ensure_ascii=False) + logger.info(f"Performance summary ({perf_summary_path}):\n{_json.dumps(perf_data, indent=2)}") + + +if __name__ == "__main__": + main() diff --git a/nemo_curator/stages/audio/README.md b/nemo_curator/stages/audio/README.md index 9a73322991..e172740b77 100644 --- a/nemo_curator/stages/audio/README.md +++ b/nemo_curator/stages/audio/README.md @@ -129,8 +129,9 @@ Key differences from a CPU stage: dicts into a single multi-row `pd.DataFrame` in one `DocumentBatch`, avoiding N single-row DataFrame allocations. Not a GPU stage, but benefits from batched processing. -- `ManifestWriterStage` (`common.py`) — writes - entries to JSONL, returns `AudioTask`. +- `ManifestWriterStage` (`common.py`) — batch-writes entries to JSONL, + drops waveform/array-like values from serialized rows, optionally writes + `perf_summary.json`, and returns `AudioTask`. ### Setting `batch_size` for GPU inference @@ -197,6 +198,407 @@ Backend reads stage.batch_size Start with `16` and increase until you see OOM or throughput plateaus. For NeMo ASR FastConformer models, `16–64` is typical on a single GPU. +## Pluggable ASR adapters, payload lifecycle, and duration-aware batching + +Long-form ASR pipelines need three separate controls that should not be +confused: + +1. **Backend scheduling** decides how many parent `AudioTask` rows are delivered + to one worker `process_batch()` call. This remains owned by the backend + (Xenna streaming, Xenna batch, Ray Data) and the stage `batch_size` / + worker-resource settings. +2. **Payload lifecycle** decides how much decoded waveform memory can be live. + It is not a model batch policy and it does not wrap compute stages. A + materializer stage reserves byte tokens before decode, decodes a staged local + file, stores the decoded waveform in a Ray-backed payload store, and passes + only a lightweight `PayloadRef` through normal `Task.data`. +3. **Model-call batching** decides how many bounded model-input segments go + into one adapter call after long audio has been segmented and optionally + duration bucketed inside `ASRStage.process_batch()`. + +The raw in-memory Qwen-style ASR shape is expressed in config as logical +stages plus a `payload_lifecycle` rule. The rule inserts materialization after +the configured reader/producer, keeps any listed consumer stages +backend-visible, and inserts release after the final configured consumer: + +```text +ManifestReader + -> AudioPayloadMaterializeStage + -> ASRStage(...) or any payload-aware CPU/GPU consumer stage + -> ... additional payload-aware consumers ... + -> PayloadReleaseStage + -> ManifestWriterStage +``` + +For a heterogeneous pipeline such as four GPU stages with different +`Resources(gpus=...)`, do not fuse them into one payload manager. Declare the +four stages normally, set `payload_lifecycle.consumers` to their stage ids, and +set `payload_lifecycle.release_after` to the last consumer. Ray Data and Xenna +then still see four independent processing stages and can schedule each one +with its own CPU/GPU contract. + +Payload lifecycle is intentionally a mechanical expansion rule, not a second +pipeline scheduler. It is applied by `Pipeline.build()` before composite-stage +decomposition and executor planning. Simple pipelines can use +`payload_keys: [audio_filepath]`; multi-source pipelines should use +`payloads:` entries with explicit `source_key`, `ref_key`, `waveform_key`, +`sample_rate_key`, `num_samples_key`, and optional `duration_key`. Any number +of downstream payload-aware consumers can share those refs. A consumer must +implement `PayloadAwareStageMixin` or an equivalent +`resolve_payload_refs_for_batch()` plus `payload_bindings()` contract so it can +resolve and drop every declared `PayloadRef` handle inside its own +`process_batch()` call. This local/windowed branch does not perform +full-manifest global planning; duration-aware bucketing stays inside each +backend-visible consumer's `process_batch()` call. + +To add another CPU or GPU stage that needs decoded audio, list the new logical +stage after the reader and before the writer, add its stage id to +`payload_lifecycle.consumers`, and move `payload_lifecycle.release_after` to +that final audio consumer. The stage receives ordinary `AudioTask` rows in the +same backend-scheduled batches as any other visible stage. It should resolve +`waveform_ref` to `waveform` only inside `process_batch()` and drop the +temporary tensor before returning. Metadata-only stages do not need to be +listed as payload consumers; they only need to run before `PayloadReleaseStage` +if they depend on payload-ref metadata that release will strip. + +The payload lifecycle planner is still generic in this branch: it can +materialize one or more configured payload sources once, keep any number of +payload-aware CPU/GPU stages backend-visible, and release refs after the final +consumer. It does not decide global segment ownership and does not change how +Xenna streaming, Xenna batch, or Ray Data split work across visible stages. + +The important invariant is **one audio-byte decode per row**. Remote object +storage is staged outside Curator. `AudioFileReaderStage` accepts local paths +only and is used by `AudioPayloadMaterializeStage` as the only code path that +opens `audio_filepath`, decodes, resamples, and creates the in-memory +`waveform` payload. After materialization, normal task rows carry `waveform_ref` +rather than the tensor itself. `ASRStage` resolves that ref only inside its own +`process_batch()` call, drops the temporary waveform before returning, and +`PayloadReleaseStage` releases the payload store entry before the writer runs. +`ManifestWriterStage` also omits waveform/array-like values from JSON output by +default. + +Reusable pieces live alongside the audio inference stages: + +| Component | Location | Role | +|---|---|---| +| `AudioPayloadMaterializeStage` | `stages/payload_lifecycle.py` | Uses `AudioFileReaderStage` to decode one staged local audio row, reserves positive duration-derived byte tokens from a Ray-backed admission actor before decode, stores the waveform in a node-local payload store actor, and emits a `PayloadRef` instead of raw waveform bytes. Admission tracks per-node usage and can also enforce an explicit cluster-wide payload cap. The reservation uses `lease_ttl_s` during decode, switches to the longer finite `materialized_lease_ttl_s` after publication, and fails with an admission snapshot if `admission_wait_timeout_s` elapses. Explicit release is the normal path; finite expiry reclaims orphaned published rows. | +| `PayloadReleaseStage` | `stages/payload_lifecycle.py` | Releases payload-store entries and admission reservations once downstream stages no longer need the waveform. | +| `PayloadRef` | `stages/payload_lifecycle.py` | Lightweight serializable handle carried through `Task.data`; compute stages resolve it only when they need the waveform. | +| `PayloadAwareStageMixin` | `stages/payload_lifecycle.py` | Reusable helper for CPU/GPU stages that need payload bytes. It resolves `PayloadRef` values with actor-grouped, byte-bounded RPCs at `process_batch()` time and drops temporary waveform tensors before returning. | +| `BoundedOneAheadPrefetchIterator` | `pipeline/prefetch.py` | Backend-neutral opt-in helper that overlaps one next work-item load while bounding estimated current-plus-next bytes. The Qwen ASR config uses it per adapter call; existing consumers keep eager resolution by default. | +| `AudioFileReaderStage` | `stages/audio/io/audio_file_reader.py` | Opens a local audio file and decodes/resamples it into an in-memory waveform. In handle-based pipelines it is used by `AudioPayloadMaterializeStage`, not placed before every GPU stage. | +| `BatchPolicy` | `stages/audio/inference/batch_policy.py` | Duration/cost bucket config (`buckets_sec`, `max_items_per_batch_by_bucket`, `bucketed_inference_batch_size`, `max_audio_sec_per_batch`). | +| `run_bucketed` | `stages/audio/inference/batch_policy.py` | Helper that dispatches cost-bucketed sub-batches and realigns results to original order. | +| `BucketedInferenceStage` | `stages/audio/inference/bucketed_stage.py` | Abstract inference-stage base for item expansion, bucketed model dispatch, and parent reassembly. | + +Import these from `nemo_curator.stages.audio.io`, +`nemo_curator.stages.payload_lifecycle`, and +`nemo_curator.stages.audio.inference`. + +### Reviewer code chain for the raw in-memory Qwen path + +Line numbers below describe the current branch and are meant as a quick code +tour for reviewers. + +1. `Pipeline` separates logical stages from execution stages, then applies + graph expansion before composite decomposition: + `pipeline/pipeline.py:70`, `pipeline/pipeline.py:102`, + `pipeline/pipeline.py:139`. +2. `expand_payload_lifecycle_stages()` reads the lifecycle config, validates + `materialize_after` / `consumers` / `release_after`, asks the reader for + materializers, and inserts release: + `pipeline/payload_lifecycle.py:39`, + `pipeline/payload_lifecycle.py:64`, + `pipeline/payload_lifecycle.py:377`, + `pipeline/payload_lifecycle.py:106`. +3. `ManifestReader` stays ordinary in this branch: it decomposes to + `FilePartitioningStage -> ManifestReaderStage` and also provides the audio + materializer hook used by the generic lifecycle expander: + `stages/audio/common.py:201`, + `stages/audio/common.py:229`, + `stages/audio/common.py:249`. +4. `AudioPayloadMaterializeStage` estimates bytes, reserves memory, decodes one + local row, stores the waveform in the payload store, and emits `PayloadRef`: + `stages/payload_lifecycle.py:641`, + `stages/payload_lifecycle.py:708`. +5. `AudioFileReaderStage` is the only audio-byte I/O path in the raw pipeline. + It rejects remote paths and decodes local files with ffmpeg: + `stages/audio/io/audio_file_reader.py:43`, + `stages/audio/io/audio_file_reader.py:154`, + `stages/audio/io/audio_file_reader.py:161`. +6. `ASRStage.process_batch()` either bulk-resolves refs eagerly or, when + explicitly configured, plans calls from ref metadata and resolves only the + current plus one byte-bounded prefetched call. It then applies the + local/windowed `BatchPolicy`, caps adapter calls, and stitches results: + `stages/audio/inference/asr/stage.py:508`, + `stages/audio/inference/asr/stage.py:531`, + `stages/audio/inference/asr/stage.py:582`. +7. `QwenOmniASRAdapter` owns Qwen/vLLM prompt construction and inference only: + `models/asr/qwen_omni.py:84`, + `models/asr/qwen_omni.py:250`, + `models/asr/qwen_omni.py:330`. +8. `PayloadReleaseStage` releases refs before writing, while `BaseStageAdapter` + handles exception cleanup, dropped-row ref cleanup, terminal tombstones, and + task-id preservation: + `stages/payload_lifecycle.py:918`, + `backends/base.py:408`, + `backends/base.py:425`, + `backends/base.py:503`. + +### Local/windowed bucketing mode + +This branch keeps `ManifestReader` ordinary: it reads manifest rows in manifest +order and emits one `AudioTask` per row. There is no full-manifest planning pass +and no backend-specific window object. + +Duration-aware bucketing happens only after a backend has scheduled rows to an +`ASRStage.process_batch()` call. In other words, Xenna streaming, Xenna batch, +and Ray Data still decide how many parent rows reach a worker exactly as they do +for normal stages. The local branch only improves the work *inside* that one +call: it segments long parents, buckets model items by duration, caps final +adapter-call sizes per bucket, and stitches outputs back to the same parent +rows. + +### ASR duration-aware dispatch inside `process_batch` + +`ASRStage` is the concrete long-form ASR stage. It owns Curator-side glue: +input validation, language resolution, required model-input segmentation, +duration-aware model-item bucketing, adapter-call counting, stitch-back, and +metrics. Model logic lives behind the adapter. + +Backend executors call `ASRStage.process_batch(tasks)` with the parent rows they +scheduled. Inside that call: + +1. Each parent row is validated and sliced into contiguous model-input segments + no longer than `max_inference_duration_s`. Rows already below the ceiling + become one segment. +2. The flat segment/item list is bucketed by `BatchPolicy` when + `batch_policy.enabled=True`. +3. `max_items_per_batch_by_bucket` and `max_audio_sec_per_batch` determine which + same-bucket segments are considered together. +4. `bucketed_inference_batch_size` determines the final adapter-call item cap + per duration bucket. This is the knob for "short segments can run more + samples per model call; long segments stay single-item." +5. Results are realigned to original item order and stitched back to one output + row per input parent. + +When `batch_policy.enabled=False`, model-input segmentation still happens as +OOM/model-limit protection, but segments are split only by `adapter_batch_size` +when set, otherwise by the stage's `batch_size`. Backend parallelism and work +stealing remain exactly the normal backend behavior. + +There are three distinct batch-size knobs: + +- `batch_size` is the backend-visible candidate window: Ray Data and Xenna use + it to decide how many parent rows reach one `process_batch()` call. +- `adapter_batch_size` is the fallback cap for one adapter/model call when no + enabled `BatchPolicy` supplies a bucket-specific cap. +- `bucketed_inference_batch_size` is the per-duration-bucket adapter/model-call + cap used when `batch_policy.enabled=True`. + +The core `BucketedInferenceStage` hook contract is: + +```text +build_items(tasks) -> (items, parent_of) expand parent tasks into model items +item_cost(item) -> float duration/cost used for bucketing +run_inference(items) -> results one adapter call; results are 1:1 +assemble(tasks, items, parent_of, results) -> out_tasks +``` + +Some stages may expose optional scheduler hooks, but the Qwen raw in-memory +pipeline does not rely on reader or payload prebatching. Keep batching decisions +at the backend stage level and at the ASR model-call level. + +### The ASR adapter split (Tier-1 / Tier-2) + +For audio speech recognition the concrete stage is `ASRStage` +(`stages/audio/inference/asr/stage.py`), a `BucketedInferenceStage` +subclass that owns only Curator-side glue — input validation, ISO-code -> +language-name resolution, model-input segmentation for clips longer than +`max_inference_duration_s`, stitch-back, and metrics. The *model-specific* logic +(vLLM setup, prompt formatting, two-turn generation) lives behind a swappable +**adapter**: + +| Layer | Location | Responsibility | +|---|---|---| +| `ASRStage` | `stages/audio/inference/asr/stage.py` | Generic, model-independent stage glue. | +| `ASRAdapter` (Protocol) + `ASRResult` | `models/asr/base.py` | The contract a model adapter must satisfy (`setup`, `teardown`, `transcribe_batch`, `prefetch_weights`, `last_metrics`). | +| `QwenOmniASRAdapter` | `models/asr/qwen_omni.py` | Qwen3-Omni implementation (built on the shared `VLLMBase` in `models/vllm_model.py`). | + +The split is **Tier-1 / Tier-2**: + +Install the Qwen implementation with `uv sync --extra audio_qwen`. This +Qwen-only extra composes the unchanged `audio_cuda12` stack with vLLM and +`qwen-omni-utils`; existing non-Qwen audio environments are unaffected. + +- **Tier-1** fields are universal stage knobs set in YAML (`adapter_target`, + `model_id`, I/O keys, `max_inference_duration_s`, `keep_waveform`, `batch_size`, + `adapter_batch_size`, `batch_policy`, ...). +- **Tier-2** is the opaque `adapter_kwargs` dict forwarded verbatim to the + adapter constructor; the stage never reads inside it. + +`ASRStage` also builds the per-item adapter payload. Every adapter call receives +the waveform, sample rate, resolved language name, original language code, +task id, estimated audio seconds, and chunk position metadata. If +`reference_text_key` is configured, the stage forwards that row value as +`reference_text` so prompt-driven adapters can use transcript/reference context. +Adapters can ignore fields they do not need, but they should not require +Curator-specific `AudioTask` objects. + +Swapping the model is a one-line `adapter_target:` change in YAML; the +adapter class is resolved at `setup()` via `hydra.utils.get_class`. See +`tutorials/audio/qwen_omni_inprocess/` for the end-to-end config. + +> **Per-call accumulator note (multi-worker safety):** `ASRStage` keeps a +> couple of per-`process_batch` accumulators on `self` (model-metric sums, +> inference wall time), reset in `build_items` and consumed in `assemble`. +> This is safe because each worker runs one `process_batch` at a time +> (Ray Actor Pool / Ray Data / single-slot Xenna). Do not enable an +> executor that overlaps invocations on one stage instance without making +> those accumulators call-local. + +### When to use which base + +| Pattern | Base | Override | +|---|---|---| +| CPU, one task at a time | `ProcessingStage[AudioTask, AudioTask]` | `process` | +| GPU/IO, one batched call, no bucketing | `ProcessingStage` | `process_batch` (e.g. `InferenceAsrNemoStage`) | +| GPU inference needing cost/duration bucketing + a swappable model | `BucketedInferenceStage` + an adapter | the four hooks (e.g. `ASRStage` + `ASRAdapter`) | + +## Performance metrics (`perf_summary.json`) + +Audio manifest writer stages aggregate per-stage stats into `perf_summary.json` +(all math stays in Curator; downstream tooling should transport the file as-is). + +### Design principle: collect everywhere, write once + +| Layer | Who | What | +|-------|-----|------| +| **Collection** | Every stage, CPU and GPU | Backend adapter times each `process_batch` and appends `StagePerfStats` to `task._stage_perf` | +| **Serialization** | Single CPU writer (`num_workers=1`) | Maintains output JSONL and the aggregate `perf_summary.json` | +| **Upload** | External orchestrator (optional) | Verbatim copy/transport of one `perf_summary.json` | + +Do **not** add per-GPU file writers or a second metrics actor. Multiple +`perf_summary.json` writers produce incompatible summaries and require explicit +multi-writer handling downstream. + +Toggle perf file output with `write_perf_stats: false` on either +`ManifestWriterStage` or `ShardedManifestWriterStage` (manifest output still +written; sharded `.done` markers still written for the sharded writer). + +### Collection flow + +1. **Every backend adapter** (Xenna, Ray Data, Ray Actor Pool) subclasses + `BaseStageAdapter`. Its `process_batch` times the call, pulls + `stage._consume_custom_metrics()` (from `_log_metrics` / `_log_metric` / + `_time_metric` during the stage body), stamps identity, and calls + `task.add_stage_perf(stage_perf_stats)` on **each output task**. This applies + equally to CPU stages (tar reader, discovery, filters) and GPU stages + (inference): CPU stages get full `process_time` / `custom_metrics`; they + simply leave `gpu_id` empty. +2. **Stage identity** — `WorkerPerfIdentity` is resolved once per worker in + `build_xenna_perf_identity()` / `build_ray_perf_identity()` + (`backends/perf_identity.py`, stamped on `WorkerMetadata` at setup): + - **Scheduling:** `actor_id`, `node_id`, `gpu_id`. Under Xenna, `gpu_id` uses + `WorkerMetadata.allocation.gpus[0].index` only. Under Ray Data / Actor Pool: + `ray.get_gpu_ids()[0]` only. These strings are stripped from + `StagePerfStats.items()` so framework metric collectors never `float()` them. + - **Cluster location (additive):** `physical_address` + (`:`, the canonical backend-independent + GPU identifier), `pod_ip` (K8s `POD_IP` when set), `hostname`, `gpu_indices` + (full allocation, e.g. `[0, 1]` for `tp=2`), optional `gpu_uuids` from CUDA + device properties. +3. **Per-actor scheduling breakdown** — the writer’s `AudioPerformanceSummary` + builds per-stage `actor_count` and `per_actor` (keyed by `actor_id`: items + processed, audio hours, batch-size / queue-wait percentiles) for **every + actor-backed stage, GPU or CPU**. GPU actors additionally carry their + `physical_address` + `gpu_indices` / `gpu_uuids` and the NVML + `gpu_util_pct_p*` / `gpu_mem_used_pct_p*` percentiles inside their `per_actor` + block, and the stage gets `gpu_addresses` (per-actor physical addresses) + + `gpu_count` (true device count). Top-level `pipeline_throughput` rolls up the + GPU-stage `gpu_addresses` / `gpu_count`. +4. **Dedup** — `AudioPerformanceSummary.record_stage_perf` fingerprints each + `StagePerfStats` (including identity) so fan-out stages do not multiply-count + upstream invocations. + +### File writes (`ManifestWriterStage` / `ShardedManifestWriterStage`) + +When `write_perf_stats=true` (default): + +- **`perf_summary.json`** — aggregate summary from `AudioPerformanceSummary.build_summary()`. + `ManifestWriterStage` refreshes it next to the output manifest after each + successful batch write, with `teardown()` as a final backstop. This keeps the + summary available even when a backend runs the writer as an actor whose + teardown is not called on the driver. `ShardedManifestWriterStage` refreshes + it when a shard hits its `_shard_total` (`.done` written) and again in + `teardown()`. Includes writer’s own I/O timings under + `stages[manifest_writer]` or `stages[sharded_manifest_writer]`. + Per-task `StagePerfStats` are aggregated in memory only (no per-shard sidecar file). +- **Manifest rows** — both writers omit `waveform` by default and also skip + accidental array-like values (`shape` + `dtype`) so in-memory tensors are not + serialized to JSONL. + +`main.py` (tutorial entry points) may add `pipeline_duration_s` after +`pipeline.run()` returns. + +### Adding custom metrics (stage authors) + +Inside `process` or `process_batch`: + +```python +self._log_metrics({"bytes_loaded": float(n_bytes), "audio_duration_s": dur}) +``` + +Optional timing helper: + +```python +with self._time_metric("decode_wall_s"): + ... +``` + +Metrics roll up to `stages[].custom_metrics_sum` in +`perf_summary.json`. For a new cross-stage scalar in the +published summary schema, add a field to `AudioStageMetrics` in +`metrics/performance.py` and emit it from the producing stage. + +### CPU vs GPU in published JSON + +| Field | CPU stage | GPU stage | +|-------|-----------|-----------| +| `process_time`, idle, invocations | yes | yes | +| `custom_metrics_sum` | yes, if stage calls `_log_metrics` | yes | +| `actor_id`, `node_id` | best-effort | best-effort | +| `actor_count`, `per_actor` | present (keyed by `actor_id`) | present (keyed by `actor_id`) | +| `gpu_id` | empty | legacy node label (e.g. `node-0:0`), additive | +| `gpu_addresses`, `gpu_count` | absent | present | +| `physical_address`, `gpu_indices` | absent | in each GPU actor's `per_actor` block | +| `pod_ip`, `hostname` | in `per_actor` when resolved | in `per_actor` when resolved | +| `gpu_util_pct_p*`, `gpu_mem_used_pct_p*`, `gpu_uuids` | absent | in `per_actor` when CUDA/NVML up | + +**Throughput denominator (`writer_wall_time_s`)** + +The writer is a single CPU actor (`num_workers=1`). Its timer starts at the end +of its own `setup_on_node` and runs until summary serialization. Under **Xenna +streaming** or **Ray Data** (pipelined execution), that interval spans the +end-to-end processing window (the writer blocks on upstream GPU stages). Under +**Xenna batch** (sequential stage materialization), the timer covers only the +writer phase — use whole-run pipeline wall clock from the entry point for throughput there. + +**Validation (recommended for pipeline changes)** + +- **Perf** — compare `perf_summary.json` across runs on shared throughput fields; + work-done identical (`total_utterances`, shard counts). +- **Output** — compare `manifest_*.jsonl` rows keyed on `audio_filepath`; gate + on key alignment and prediction-field stability (vLLM nondeterminism expected + on a small fraction of rows). + +Hardware telemetry is sampled opportunistically with NVML on GPU workers. +Sampler failures are fail-open so inference is not interrupted; use +`gpu_sampler_active` / `gpu_sampler_error_count` in perf summaries to tell +whether missing GPU util or VRAM fields mean "not sampled" rather than +"hardware was idle." + ## What you must always declare Every stage (CPU or GPU) should declare: @@ -363,6 +765,26 @@ under the hood). with its own model copy. - **Autoscaling**: Xenna can adjust worker counts based on measured throughput (`autoscale_interval_s` in executor config). +- **Pin expensive GPU stages; autoscale only cheap stages**: autoscale + optimizes *steady-state throughput, not cold-start latency*. An + unpinned stage starts at **1 worker** and only scales out after it has + produced enough speed measurements to be judged the bottleneck. That + ramp is instant for cheap CPU stages but expensive for GPU stages — they + idle most GPUs during warm-up and then pay a model-load tax on every + late-spawned worker. **Pin the worker count of any expensive GPU stage** + with the stage's `num_workers()` method for a cluster-wide cap, or with + `xenna_stage_spec()["num_workers_per_node"]` for a per-node cap. Stage + fields such as `ASRStage.xenna_num_workers` and + `ASRStage.xenna_num_workers_per_node` feed those two contracts. Do not + put cluster-wide `num_workers` in `xenna_stage_spec`; Xenna rejects it so + worker sizing stays aligned with the rest of Curator. A manual pin is a + *hard* constraint — Xenna panics if the cluster cannot satisfy it, so keep it within capacity. The pin is + model-dependent: `workers_per_node = floor(gpus_per_node / + resources.gpus)`, and `resources.gpus` is the per-actor GPU footprint set + by the model/adapter you run. Swapping to a **smaller model needs fewer + GPUs per actor**, which lets *more* actors fit per node (raise the pin); a + larger / higher-tensor-parallel model needs more GPUs per actor (lower the + pin). Re-tune the pin whenever you change the model. - **Call chain**: `Xenna scheduler → XennaStageAdapter.process_data(tasks)` `→ BaseStageAdapter.process_batch(tasks)` (timing + metrics) @@ -978,8 +1400,10 @@ The two surviving windows: ### Stage 3: `ManifestWriterStage` -Appends the entry as a single JSON line to -`./alm_output/alm_output.jsonl` (351 KB for this entry). +Appends the entry as a JSON line to `./alm_output/alm_output.jsonl` (351 KB for +this entry). The writer omits waveform/array-like values from serialized JSONL +and can refresh `perf_summary.json` during batch writes when +`write_perf_stats=true`. ### ALM summary table diff --git a/nemo_curator/stages/audio/common.py b/nemo_curator/stages/audio/common.py index a27c1f54bd..bcc2f83d7f 100644 --- a/nemo_curator/stages/audio/common.py +++ b/nemo_curator/stages/audio/common.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ruff: noqa: ANN401 import json import os @@ -25,11 +26,23 @@ from loguru import logger from nemo_curator.backends.base import NodeInfo, WorkerMetadata +from nemo_curator.stages.audio.io.manifest_writer_utils import AudioManifestWriterMetrics, manifest_lines from nemo_curator.stages.base import CompositeStage, ProcessingStage from nemo_curator.stages.file_partitioning import FilePartitioningStage from nemo_curator.tasks import AudioTask, EmptyTask, FileGroupTask +def _config_get(config: Any, key: str, default: Any = None) -> Any: + if config is None: + return default + if isinstance(config, dict): + return config.get(key, default) + get = getattr(config, "get", None) + if callable(get): + return get(key, default) + return default + + def get_audio_duration(audio_filepath: str) -> float: """Get the duration of the audio file in seconds.""" try: @@ -141,6 +154,7 @@ class ManifestReaderStage(ProcessingStage[FileGroupTask, AudioTask]): """ name: str = "manifest_reader_stage" + storage_options: dict[str, Any] | None = None def process(self, task: FileGroupTask) -> list[AudioTask]: t0 = time.perf_counter() @@ -148,20 +162,21 @@ def process(self, task: FileGroupTask) -> list[AudioTask]: results: list[AudioTask] = [] count = 0 for manifest in paths: - fs, resolved = url_to_fs(manifest) + fs, resolved = url_to_fs(manifest, **(self.storage_options or {})) + manifest_count = 0 with fs.open(resolved, "r", encoding="utf-8") as f: for line in f: if line.strip(): - results.append( - AudioTask( - dataset_name=task.dataset_name, - data=json.loads(line.strip()), - _metadata=task._metadata, - _stage_perf=list(task._stage_perf), - ) + audio_task = AudioTask( + dataset_name=task.dataset_name, + data=json.loads(line.strip()), + _metadata=task._metadata, + _stage_perf=list(task._stage_perf), ) + results.append(audio_task) count += 1 - logger.info(f"ManifestReaderStage: loaded {count} entries from {manifest}") + manifest_count += 1 + logger.info(f"ManifestReaderStage: loaded {manifest_count} entries from {manifest}") self._log_metrics( { "process_time": time.perf_counter() - t0, @@ -216,7 +231,7 @@ def decompose(self) -> list[ProcessingStage]: file_extensions=self.file_extensions, storage_options=self.storage_options, ), - ManifestReaderStage(), + ManifestReaderStage(storage_options=self.storage_options), ] def get_description(self) -> str: @@ -227,34 +242,96 @@ def get_description(self) -> str: parts.append(f"with target blocksize {self.blocksize}") return ", ".join(parts) + def build_payload_materialize_stage( + self, + *, + payload_spec: Any, + payload_config: dict[str, Any], + pipeline_config: Any, + run_id: str, + ) -> ProcessingStage: + """Build the audio payload materializer for the generic lifecycle planner. + + ``nemo_curator.pipeline.payload_lifecycle`` owns graph insertion order. + The reader owns modality-specific materialization, so central pipeline + code does not need to import audio reader/materializer internals. + """ + + from nemo_curator.stages.payload_lifecycle import AudioPayloadMaterializeStage + + return AudioPayloadMaterializeStage( + name=payload_spec.materialize_stage_name, + target_sample_rate=int(payload_config.get("target_sample_rate", 16000)), + target_nchannels=int(payload_config.get("target_nchannels", 1)), + audio_filepath_key=payload_spec.source_key, + duration_key=payload_spec.duration_key, + segment_start_key=str(payload_config.get("segment_start_key", "segment_start_s")), + segment_duration_key=str(payload_config.get("segment_duration_key", "segment_duration_s")), + waveform_key=payload_spec.waveform_key, + waveform_ref_key=payload_spec.ref_key, + sample_rate_key=payload_spec.sample_rate_key, + num_samples_key=payload_spec.num_samples_key, + skip_on_read_error=bool( + payload_config.get( + "skip_on_read_error", + _config_get(pipeline_config, "audio_reader_skip_on_read_error", False), + ) + ), + node_memory_fraction=float(payload_config.get("node_memory_fraction", 0.80)), + max_node_payload_bytes=payload_config.get("max_node_payload_bytes"), + max_cluster_payload_bytes=payload_config.get("max_cluster_payload_bytes"), + lease_ttl_s=float(payload_config.get("lease_ttl_s", 3600)), + materialized_lease_ttl_s=float(payload_config.get("materialized_lease_ttl_s", 4 * 60 * 60)), + admission_actor_name=str(payload_config.get("admission_actor_name", "curator_payload_admission")), + admission_poll_interval_s=float(payload_config.get("admission_poll_interval_s", 0.25)), + admission_wait_timeout_s=float(payload_config.get("admission_wait_timeout_s", 4 * 60 * 60)), + run_id=run_id, + ) + @dataclass class ManifestWriterStage(ProcessingStage[AudioTask, AudioTask]): - """Append a single AudioTask to a JSONL manifest file. + """Append AudioTasks to a JSONL manifest file. The output file is truncated once in ``setup()`` (called on the driver) so repeated pipeline runs produce a clean output. ``setup_on_node()`` only creates the parent directory -- it never truncates, so multi-node deployments do not erase each other's data. - .. note:: - Because all nodes append to the same path, callers in multi-node - setups should either use a shared filesystem or provide a - node-unique ``output_path``. + The stage is pinned to one worker/actor for all supported backends so + append writes to ``output_path`` are serialized. In-memory waveform tensors + can be omitted through explicit serialization policy. Supports local and cloud paths via fsspec. Args: output_path: Destination JSONL path (local or cloud). + write_perf_stats: If True, aggregate attached stage perf and refresh + ``perf_summary.json`` next to the output manifest after each batch + write, with teardown as a final backstop. + drop_manifest_keys: Explicit task data keys to omit from JSONL output. + drop_array_like_values: If True, omit tensor/array-like task data. + perf_summary_path: Optional override for perf summary output path. """ output_path: str name: str = "manifest_writer" + write_perf_stats: bool = False + duration_key: str = "duration" + drop_manifest_keys: tuple[str, ...] = () + drop_array_like_values: bool = False + perf_summary_path: str | None = None + _writer_metrics: AudioManifestWriterMetrics = field(init=False, repr=False) def __post_init__(self) -> None: if not self.output_path: msg = "output_path is required for ManifestWriterStage" raise ValueError(msg) + self._writer_metrics = AudioManifestWriterMetrics( + stage_name=self.name, + duration_key=self.duration_key, + write_perf_stats=self.write_perf_stats, + ) def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: """Truncate the output file once on the driver before processing starts.""" @@ -264,6 +341,8 @@ def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: self._fs.makedirs(parent_dir, exist_ok=True) with self._fs.open(self._path, "w", encoding="utf-8"): pass + if self.write_perf_stats: + self._writer_metrics.reset_wall_timer() logger.info(f"ManifestWriterStage: writing to {self.output_path}") def setup_on_node( @@ -276,16 +355,104 @@ def setup_on_node( parent_dir = "/".join(self._path.split("/")[:-1]) if parent_dir: self._fs.makedirs(parent_dir, exist_ok=True) + if self.write_perf_stats: + self._writer_metrics.reset_wall_timer() def process(self, task: AudioTask) -> AudioTask: - with self._fs.open(self._path, "a", encoding="utf-8") as f: - f.write(json.dumps(task.data, ensure_ascii=False) + "\n") - return AudioTask( - dataset_name=task.dataset_name, - data=task.data, - _metadata=task._metadata, - _stage_perf=list(task._stage_perf), + return self.process_batch([task])[0] + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + if len(tasks) == 0: + return [] + for task in tasks: + if not self.validate_input(task): + msg = f"Task {task.task_id} missing required columns for {type(self).__name__}: {self.inputs()}" + raise ValueError(msg) + lines = manifest_lines( + tasks, + self.drop_manifest_keys, + drop_array_like_values=self.drop_array_like_values, ) + if self.write_perf_stats: + self._writer_metrics.record_invocation(len(tasks)) + write_t0 = time.perf_counter() + with self._fs.open(self._path, "a", encoding="utf-8") as f: + f.writelines(lines) + if self.write_perf_stats: + self._writer_metrics.add_manifest_write_time(time.perf_counter() - write_t0) + for task in tasks: + self._writer_metrics.record_task(task) + self._write_perf_summary() + copied_tasks = [] + for task in tasks: + copied_task = AudioTask( + dataset_name=task.dataset_name, + data=task.data, + _metadata=task._metadata, + _stage_perf=list(task._stage_perf), + ) + copied_tasks.append(copied_task) + return copied_tasks + + def _resolved_perf_summary_path(self) -> str: + if self.perf_summary_path: + return self.perf_summary_path + parent = self.output_path.rsplit("/", 1)[0] if "/" in self.output_path else "" + return f"{parent}/perf_summary.json" if parent else "perf_summary.json" + + def _write_perf_summary(self) -> None: + summary_path = self._resolved_perf_summary_path() + fs, resolved = url_to_fs(summary_path) + parent_dir = "/".join(resolved.split("/")[:-1]) + if parent_dir: + fs.makedirs(parent_dir, exist_ok=True) + summary = self._writer_metrics.build_perf_summary() + write_t0 = time.perf_counter() + with fs.open(resolved, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, ensure_ascii=False) + self._writer_metrics.add_perf_write_time(time.perf_counter() - write_t0) + logger.info(f"Wrote perf_summary.json: {summary_path}") + + def record_external_stage_perf(self, perf_stats: Any) -> None: + """Merge an externally collected stage summary into the persisted perf JSON.""" + if not self.write_perf_stats: + return + stage_summary = self._writer_metrics.build_external_stage_summary(perf_stats) + if not stage_summary: + return + summary_path = self._resolved_perf_summary_path() + fs, resolved = url_to_fs(summary_path) + parent_dir = "/".join(resolved.split("/")[:-1]) + if parent_dir: + fs.makedirs(parent_dir, exist_ok=True) + summary: dict[str, Any] + if fs.exists(resolved): + try: + with fs.open(resolved, "r", encoding="utf-8") as f: + summary = json.load(f) + except Exception as exc: # noqa: BLE001 + logger.warning("Could not read existing perf_summary.json at {}: {}", summary_path, exc) + summary = {} + else: + summary = {} + stages = summary.setdefault("stages", {}) + if isinstance(stages, dict): + stages[perf_stats.stage_name] = stage_summary + else: + summary["stages"] = {perf_stats.stage_name: stage_summary} + write_t0 = time.perf_counter() + with fs.open(resolved, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, ensure_ascii=False) + self._writer_metrics.add_perf_write_time(time.perf_counter() - write_t0) + logger.info("Merged external perf stage {} into {}", perf_stats.stage_name, summary_path) + + def teardown(self) -> None: + if self.write_perf_stats and ( + self._writer_metrics.items_processed > 0 or self._writer_metrics.total_utterances > 0 + ): + self._write_perf_summary() + elif self.write_perf_stats: + logger.info("Skipping perf_summary.json write because no tasks were processed") def num_workers(self) -> int | None: return 1 @@ -301,7 +468,7 @@ def load_audio_file(audio_path: str, mono: bool = True) -> tuple[torch.Tensor, i return waveform, sample_rate -def ensure_waveform_2d(waveform: Any) -> torch.Tensor: # noqa: ANN401 +def ensure_waveform_2d(waveform: Any) -> torch.Tensor: """Ensure waveform is a torch.Tensor in 2D (channels, samples) format.""" if not torch.is_tensor(waveform): waveform = torch.as_tensor(waveform, dtype=torch.float32) diff --git a/nemo_curator/stages/audio/inference/__init__.py b/nemo_curator/stages/audio/inference/__init__.py index e69de29bb2..aed03e1465 100644 --- a/nemo_curator/stages/audio/inference/__init__.py +++ b/nemo_curator/stages/audio/inference/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.stages.audio.inference.asr import ASRStage + +__all__ = ["ASRStage"] diff --git a/nemo_curator/stages/audio/inference/asr/__init__.py b/nemo_curator/stages/audio/inference/asr/__init__.py index e69de29bb2..8817d54be1 100644 --- a/nemo_curator/stages/audio/inference/asr/__init__.py +++ b/nemo_curator/stages/audio/inference/asr/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Audio speech-recognition Curator stages. + +The generic stage-adapter split (``ASRStage`` + pluggable ASR adapter) +lives in ``stage.py``; the pre-existing NeMo-specific ASR stage stays in +``asr_nemo.py``. +""" + +from nemo_curator.stages.audio.inference.asr.stage import ASRStage + +__all__ = ["ASRStage"] diff --git a/nemo_curator/stages/audio/inference/asr/stage.py b/nemo_curator/stages/audio/inference/asr/stage.py new file mode 100644 index 0000000000..19dad3e709 --- /dev/null +++ b/nemo_curator/stages/audio/inference/asr/stage.py @@ -0,0 +1,1092 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic audio ASR Curator stage with a pluggable adapter. + +Curator-side glue: validates I/O, resolves per-task language, segments long +audio into model-sized work units, applies duration-aware bucketing inside the +backend-provided ``process_batch`` call, stitches results per parent task, and +writes predictions/metrics. The concrete adapter is +resolved at runtime from ``adapter_target`` via ``hydra.utils.get_class``. +""" + +from __future__ import annotations + +import time +from collections import OrderedDict, defaultdict +from dataclasses import dataclass, field +from numbers import Real +from threading import Lock +from typing import TYPE_CHECKING, Any + +import hydra.utils +from loguru import logger + +from nemo_curator.models.asr.base import ASRAdapter, ASRResult +from nemo_curator.pipeline.payload_refs import PayloadRef, resolve_payload_refs_batched +from nemo_curator.pipeline.prefetch import BoundedOneAheadPrefetchIterator +from nemo_curator.stages.audio.inference.bucketed_stage import BucketedInferenceStage +from nemo_curator.stages.audio.model_input_segmentation import plan_audio_segments, resolve_max_model_input_duration +from nemo_curator.stages.payload_lifecycle import PayloadAwareStageMixin +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask + +if TYPE_CHECKING: + from collections.abc import Callable + + from nemo_curator.backends.base import NodeInfo, WorkerMetadata + from nemo_curator.stages.audio.inference.batch_policy import BatchPolicy + + +# ISO code -> human-readable name; the adapter receives the resolved name. +_LANG_CODE_TO_NAME: dict[str, str] = { + "ar": "Arabic", + "bg": "Bulgarian", + "bn": "Bengali", + "cs": "Czech", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "et": "Estonian", + "fa": "Persian", + "fi": "Finnish", + "fil": "Filipino", + "fr": "French", + "gu": "Gujarati", + "he": "Hebrew", + "hi": "Hindi", + "hr": "Croatian", + "hu": "Hungarian", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "ko": "Korean", + "lt": "Lithuanian", + "lv": "Latvian", + "mk": "Macedonian", + "ml": "Malayalam", + "mr": "Marathi", + "mt": "Maltese", + "nl": "Dutch", + "no": "Norwegian", + "pa": "Punjabi", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sk": "Slovak", + "sl": "Slovenian", + "sr": "Serbian", + "sv": "Swedish", + "ta": "Tamil", + "te": "Telugu", + "th": "Thai", + "tl": "Tagalog", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "zh": "Chinese", +} + + +@dataclass(frozen=True) +class _ChunkSpec: + parent_task: AudioTask + parent_idx: int + chunk_idx: int + chunk_count: int + waveform: object | None + sample_rate: object + language: str | None + language_code: str | None + reference_text: str | None + cost: float + payload_ref: PayloadRef | None = None + start_sample: int = 0 + stop_sample: int = 0 + + +@dataclass(frozen=True) +class _InferenceCall: + indices: list[int] + items: list[dict[str, Any]] + + +_PAYLOAD_REF_ITEM_KEY = "_curator_payload_ref" +_PAYLOAD_START_ITEM_KEY = "_curator_payload_start_sample" +_PAYLOAD_STOP_ITEM_KEY = "_curator_payload_stop_sample" +_WAVEFORM_BYTES_ITEM_KEY = "_curator_waveform_bytes" + + +def _payload_cache_key(payload_ref: PayloadRef) -> tuple[str | None, str, str]: + return payload_ref.actor_namespace, payload_ref.store_actor_name, payload_ref.payload_id + + +class _PayloadCallMaterializer: + """Resolve and retain only payloads needed by active adapter calls.""" + + def __init__( + self, + *, + resolve_max_batch_bytes: int | None, + cache_max_bytes: int, + consumer_node_id: str, + slice_waveform: Callable[[object, int, int], object], + ) -> None: + self._resolve_max_batch_bytes = resolve_max_batch_bytes + self._cache_max_bytes = int(cache_max_bytes) + self._consumer_node_id = consumer_node_id + self._slice_waveform = slice_waveform + self._cache: OrderedDict[tuple[str | None, str, str], tuple[PayloadRef, object]] = OrderedDict() + self._active: dict[tuple[str | None, str, str], int] = defaultdict(int) + self._cache_bytes = 0 + self._lock = Lock() + self.resolution_count = 0 + self.resolution_bytes = 0 + self.same_node_count = 0 + self.cross_node_count = 0 + self.resolution_time_s = 0.0 + + def materialize(self, call: _InferenceCall) -> list[dict[str, Any]]: + refs = self._unique_call_refs(call) + with self._lock: + missing = [ref for ref in refs if _payload_cache_key(ref) not in self._cache] + + if missing: + started = time.perf_counter() + payloads = resolve_payload_refs_batched( + missing, + max_batch_bytes=self._resolve_max_batch_bytes, + ) + elapsed = time.perf_counter() - started + with self._lock: + self.resolution_time_s += elapsed + for ref, payload in zip(missing, payloads, strict=True): + key = _payload_cache_key(ref) + if key not in self._cache: + self._cache[key] = (ref, payload) + self._cache_bytes += max(0, int(ref.amount_bytes)) + self.resolution_count += 1 + self.resolution_bytes += max(0, int(ref.amount_bytes)) + if ref.owner_node_id and ref.owner_node_id == self._consumer_node_id: + self.same_node_count += 1 + else: + self.cross_node_count += 1 + + with self._lock: + for ref in refs: + key = _payload_cache_key(ref) + self._active[key] += 1 + self._cache.move_to_end(key) + materialized = [self._materialize_item(item) for item in call.items] + self._evict_inactive() + return materialized + + def complete(self, call: _InferenceCall) -> None: + with self._lock: + for ref in self._unique_call_refs(call): + key = _payload_cache_key(ref) + self._active[key] = max(0, self._active.get(key, 0) - 1) + self._evict_inactive() + + def close(self) -> None: + with self._lock: + self._cache.clear() + self._active.clear() + self._cache_bytes = 0 + + def _materialize_item(self, item: dict[str, Any]) -> dict[str, Any]: + payload_ref = item.get(_PAYLOAD_REF_ITEM_KEY) + if not isinstance(payload_ref, PayloadRef): + return item + _ref, waveform = self._cache[_payload_cache_key(payload_ref)] + start = int(item.get(_PAYLOAD_START_ITEM_KEY, 0)) + stop = int(item.get(_PAYLOAD_STOP_ITEM_KEY, payload_ref.num_samples)) + materialized = dict(item) + materialized["waveform"] = self._slice_waveform(waveform, start, stop) + materialized.pop(_PAYLOAD_REF_ITEM_KEY, None) + materialized.pop(_PAYLOAD_START_ITEM_KEY, None) + materialized.pop(_PAYLOAD_STOP_ITEM_KEY, None) + return materialized + + def _evict_inactive(self) -> None: + while self._cache_bytes > self._cache_max_bytes: + if len(self._cache) == 1: + # One source payload may itself exceed the lookahead budget + # (for example a multi-hour local parent). Keep that one + # payload across its contiguous model calls, but never prefetch + # another call beside it because the iterator's combined-byte + # check will fail. + return + evictable = next((key for key in self._cache if self._active.get(key, 0) == 0), None) + if evictable is None: + return + ref, _payload = self._cache.pop(evictable) + self._cache_bytes -= max(0, int(ref.amount_bytes)) + self._active.pop(evictable, None) + + @staticmethod + def _unique_call_refs(call: _InferenceCall) -> list[PayloadRef]: + refs: dict[tuple[str | None, str, str], PayloadRef] = {} + for item in call.items: + payload_ref = item.get(_PAYLOAD_REF_ITEM_KEY) + if isinstance(payload_ref, PayloadRef): + refs.setdefault(_payload_cache_key(payload_ref), payload_ref) + return list(refs.values()) + + +@dataclass +class ASRStage(PayloadAwareStageMixin, BucketedInferenceStage[AudioTask, AudioTask, "dict[str, Any]", ASRResult]): + """Audio speech-recognition Curator stage with pluggable adapter. + + Resolves an ``ASRAdapter`` from ``adapter_target``, slices long audio into + model-safe chunks, and stitches chunk outputs back to one result per input + task. Duration-aware bucketing is controlled independently by + ``batch_policy`` and packs already-created chunks. + """ + + # Adapter selection. + adapter_target: str + model_id: str + name: str = "ASR_inference" + revision: str | None = None + + # Task I/O keys. + waveform_key: str = "waveform" + waveform_ref_key: str | None = "waveform_ref" + sample_rate_key: str = "sample_rate" + source_lang_key: str = "source_lang" + reference_text_key: str | None = None + default_language: str | None = None + supported_language_codes: list[str] | None = None + pred_text_key: str = "pred_text" + disfluency_text_key: str | None = None + skip_me_key: str = "_skip_me" + + # Model-input segmentation and output retention. Long-row model safety is + # derived from max_inference_duration_s. + max_inference_duration_s: float = 2400.0 + keep_waveform: bool = True + + prefetch_fail_on_error: bool = True + + # Optional payload-resolution optimization. Defaults preserve the eager + # behavior used by existing pipelines; benchmark configs opt into bounded + # one-call lookahead explicitly. + payload_resolve_max_batch_bytes: int | None = None + payload_prefetch_enabled: bool = False + payload_prefetch_max_bytes: int | None = None + + # Worker placement. + xenna_num_workers: int | None = None + xenna_num_workers_per_node: int | None = None + + batch_policy: BatchPolicy | None = None + + adapter_kwargs: dict[str, Any] = field(default_factory=dict) + + resources: Resources = field(default_factory=lambda: Resources(gpus=1.0)) + # Backend-visible candidate window. Ray Data/Xenna use this to decide how + # many rows reach one process_batch() call; the final adapter call size is + # controlled separately by adapter_batch_size / batch_policy. + batch_size: int = 32 + adapter_batch_size: int | None = None + + def __post_init__(self) -> None: + self.max_inference_duration_s = resolve_max_model_input_duration( + max_duration_s=self.max_inference_duration_s, + owner="ASRStage", + ) + if int(self.batch_size) <= 0: + msg = f"ASRStage.batch_size must be > 0, got {self.batch_size}" + raise ValueError(msg) + self.batch_size = int(self.batch_size) + if self.adapter_batch_size is not None: + if int(self.adapter_batch_size) <= 0: + msg = f"ASRStage.adapter_batch_size must be > 0, got {self.adapter_batch_size}" + raise ValueError(msg) + self.adapter_batch_size = int(self.adapter_batch_size) + self._validate_payload_resolution_options() + if self.xenna_num_workers is not None and self.xenna_num_workers_per_node is not None: + msg = ( + "ASRStage: set at most one of xenna_num_workers " + "(cluster-wide) or xenna_num_workers_per_node (per-node); " + "they are mutually exclusive." + ) + raise ValueError(msg) + self._adapter: ASRAdapter | None = None + self._acc_model_metrics: dict[str, float] = defaultdict(float) + self._inference_elapsed_s: float = 0.0 + self._adapter_inference_calls: int = 0 + + def _validate_payload_resolution_options(self) -> None: + if not isinstance(self.payload_prefetch_enabled, bool): + msg = "ASRStage.payload_prefetch_enabled must be a bool" + raise TypeError(msg) + if self.payload_resolve_max_batch_bytes is not None: + if ( + isinstance(self.payload_resolve_max_batch_bytes, bool) + or int(self.payload_resolve_max_batch_bytes) <= 0 + ): + msg = "ASRStage.payload_resolve_max_batch_bytes must be > 0 when set" + raise ValueError(msg) + self.payload_resolve_max_batch_bytes = int(self.payload_resolve_max_batch_bytes) + if self.payload_prefetch_max_bytes is not None: + if isinstance(self.payload_prefetch_max_bytes, bool) or int(self.payload_prefetch_max_bytes) <= 0: + msg = "ASRStage.payload_prefetch_max_bytes must be > 0 when set" + raise ValueError(msg) + self.payload_prefetch_max_bytes = int(self.payload_prefetch_max_bytes) + if self.payload_prefetch_enabled and self.payload_prefetch_max_bytes is None: + msg = "ASRStage.payload_prefetch_max_bytes is required when payload_prefetch_enabled=True" + raise ValueError(msg) + self._adapter_inference_items: int = 0 + self._warned_ray_per_node_pin = False + self._supported_language_codes = self._normalise_supported_language_codes(self.supported_language_codes) + + @staticmethod + def _normalise_supported_language_codes(value: object) -> set[str] | None: + """Normalize an optional adapter-specific supported-language allowlist.""" + if value is None: + return None + raw_codes = value.split(",") if isinstance(value, str) else list(value) # type: ignore[arg-type] + codes = {str(code).strip().lower() for code in raw_codes if str(code).strip()} + return codes or None + + def _adapter_class(self) -> type: + return hydra.utils.get_class(self.adapter_target) + + def setup_on_node( + self, + _node_info: NodeInfo | None = None, + _worker_metadata: WorkerMetadata | None = None, + ) -> None: + """Cache model weights once per node (no GPU allocation).""" + try: + prefetch_t0 = time.perf_counter() + self._adapter_class().prefetch_weights(self.model_id, self.revision) + logger.info( + "ASR weights cached on node for {} ({}) in {:.3f}s", + self.model_id, + self.adapter_target, + time.perf_counter() - prefetch_t0, + ) + except Exception as exc: + msg = f"ASRStage: prefetch_weights failed for {self.model_id}" + if self.prefetch_fail_on_error: + raise RuntimeError(msg) from exc + logger.warning("{}; setup() will retry: {}", msg, exc) + + def setup_on_node_resources(self) -> Resources: + return Resources(cpus=1.0, gpus=0.0) + + def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: + if self._adapter is None: + cls = self._adapter_class() + self._adapter = cls( + model_id=self.model_id, + revision=self.revision, + **self.adapter_kwargs, + ) + self._adapter.setup() + logger.info("ASR adapter ready on worker ({})", self.adapter_target) + + def teardown(self) -> None: + if self._adapter is not None: + self._adapter.teardown() + self._adapter = None + + def num_workers(self) -> int | None: + if ( + self.xenna_num_workers is None + and self.xenna_num_workers_per_node is not None + and not self._warned_ray_per_node_pin + ): + logger.warning( + "ASRStage: xenna_num_workers_per_node={} is set but xenna_num_workers " + "is None; Ray Data has no per-node pin and will AUTOSCALE this GPU " + "stage. Set xenna_num_workers for a cluster-wide Ray Data pin.", + self.xenna_num_workers_per_node, + ) + self._warned_ray_per_node_pin = True + return self.xenna_num_workers + + def xenna_stage_spec(self) -> dict[str, Any]: + spec: dict[str, Any] = {} + if self.xenna_num_workers_per_node is not None: + spec["num_workers_per_node"] = self.xenna_num_workers_per_node + return spec + + def inputs(self) -> tuple[list[str], list[str]]: + waveform_input = self.waveform_ref_key or self.waveform_key + optional_inputs = [waveform_input, self.sample_rate_key] + if self.reference_text_key: + optional_inputs.append(self.reference_text_key) + return [], optional_inputs + + def outputs(self) -> tuple[list[str], list[str]]: + keys = [self.pred_text_key, self.skip_me_key] + if self.disfluency_text_key: + keys.append(self.disfluency_text_key) + return [], keys + + def _validate_asr_task_input(self, task: AudioTask) -> bool: + has_waveform = self.waveform_key in task.data + has_ref = bool(self.waveform_ref_key and self.waveform_ref_key in task.data) + if not has_waveform and not has_ref: + logger.error( + "Task {} missing ASR waveform input: expected '{}' or '{}'", + task.task_id, + self.waveform_key, + self.waveform_ref_key, + ) + return False + if self.sample_rate_key not in task.data: + logger.error("Task {} missing ASR sample-rate input '{}'", task.task_id, self.sample_rate_key) + return False + return True + + def _resolve_payload_refs(self, tasks: list[AudioTask]) -> list[AudioTask]: + return self.resolve_payload_refs_for_batch(tasks) + + def _drop_resolved_payload_waveforms(self, tasks: list[AudioTask]) -> None: + self.drop_resolved_payloads(tasks) + + def _resolve_language(self, task: AudioTask) -> str | None: + code = self._resolve_language_code(task) + if code: + return _LANG_CODE_TO_NAME.get(code, code) + return None + + def _resolve_language_code(self, task: AudioTask) -> str | None: + code = task.data.get(self.source_lang_key) if self.source_lang_key else None + if code: + return str(code).strip().lower() + if self.default_language: + return str(self.default_language).strip().lower() + return None + + def _is_language_supported(self, item: dict[str, Any]) -> bool: + if self._supported_language_codes is None: + return True + code = str(item.get("language_code", "") or "").strip().lower() + return bool(code) and code in self._supported_language_codes + + def _resolve_reference_text(self, task: AudioTask) -> str | None: + if not self.reference_text_key: + return None + value = task.data.get(self.reference_text_key) + if value is None: + return None + text = str(value).strip() + return text or None + + @staticmethod + def _waveform_num_samples(waveform: object) -> int: + shape = getattr(waveform, "shape", None) + if shape: + return int(shape[-1]) + try: + return len(waveform) # type: ignore[arg-type] + except TypeError: + return 0 + + @classmethod + def _waveform_is_empty(cls, waveform: object) -> bool: + return waveform is None or cls._waveform_num_samples(waveform) <= 0 + + @staticmethod + def _slice_waveform(waveform: object, start: int, stop: int) -> object: + try: + return waveform[..., start:stop] # type: ignore[index] + except TypeError: + return waveform[start:stop] # type: ignore[index] + + @staticmethod + def _waveform_nbytes(waveform: object) -> float: + nbytes = getattr(waveform, "nbytes", None) + if isinstance(nbytes, int): + return float(nbytes) + element_size = getattr(waveform, "element_size", None) + nelement = getattr(waveform, "nelement", None) + if callable(element_size) and callable(nelement): + return float(element_size() * nelement()) + return 0.0 + + @classmethod + def _chunk_waveform( + cls, + waveform: object, + sample_rate: int, + max_seconds: float, + ) -> list[object]: + """Return contiguous ``<= max_seconds`` sub-chunks of ``waveform``. + + Last chunk may be shorter (no padding/overlap). Returns ``[waveform]`` + unchanged when it already fits (the common case for ``data_config_s3_8``). + """ + if cls._waveform_is_empty(waveform) or not sample_rate or sample_rate <= 0: + return [waveform] + n = cls._waveform_num_samples(waveform) + segments = plan_audio_segments( + num_samples=n, + sample_rate=sample_rate, + max_duration_s=max_seconds, + owner="ASRStage", + ) + return [cls._slice_waveform(waveform, segment.start_sample, segment.stop_sample) for segment in segments] + + def build_items( + self, + tasks: list[AudioTask], + ) -> tuple[list[dict[str, Any]], list[int]]: + """Expand tasks into the flat adapter item list + parent map. + + ``BucketedInferenceStage`` hook (runs first each call): validates inputs, + resets per-call metric accumulators, then expands long clips into + model-input chunks using ``max_inference_duration_s``. + + Returns: + ``(items, parent_of)`` where ``parent_of[i]`` is the originating task + index. Each item carries ``waveform`` (chunk or full), + ``sample_rate`` (unchanged), ``language`` (resolved name), ``task_id``, + ``audio_seconds``, and ``chunk_idx`` / ``chunk_count``. + """ + for task in tasks: + if not self._validate_asr_task_input(task): + msg = f"Task {task.task_id} missing required columns for {type(self).__name__}: {self.inputs()}" + raise ValueError(msg) + if self._adapter is None: + msg = "Adapter not initialized - setup() was not called" + raise RuntimeError(msg) + + self._acc_model_metrics = defaultdict(float) + self._inference_elapsed_s = 0.0 + self._adapter_inference_calls = 0 + self._adapter_inference_items = 0 + + chunk_specs = self._build_chunk_specs(tasks) + items = [self._chunk_spec_to_item(spec) for spec in chunk_specs] + parent_of = [spec.parent_idx for spec in chunk_specs] + return items, parent_of + + def _stitch( + self, + results: list[ASRResult], + parent_of: list[int], + num_parents: int, + ) -> list[ASRResult]: + """Join per-chunk text outputs per parent task with single spaces. + + Parent is marked skipped only if EVERY chunk was skipped; if any chunk + succeeded, its non-empty texts are joined and the parent is not skipped. + """ + per_parent_texts: list[list[str]] = [[] for _ in range(num_parents)] + per_parent_secondary: list[list[str]] = [[] for _ in range(num_parents)] + per_parent_skip_count: list[int] = [0] * num_parents + per_parent_chunk_count: list[int] = [0] * num_parents + per_parent_model_id: list[str] = [""] * num_parents + per_parent_skip_reason: list[str | None] = [None] * num_parents + + for r, parent in zip(results, parent_of, strict=True): + per_parent_chunk_count[parent] += 1 + if r.skipped: + per_parent_skip_count[parent] += 1 + if per_parent_skip_reason[parent] is None: + reason = r.extras.get("skip_reason") + if reason: + per_parent_skip_reason[parent] = str(reason) + text = (r.text or "").strip() + if text: + per_parent_texts[parent].append(text) + sec = (r.secondary_text or "").strip() + if sec: + per_parent_secondary[parent].append(sec) + if r.model_id and not per_parent_model_id[parent]: + per_parent_model_id[parent] = r.model_id + + stitched: list[ASRResult] = [] + for p in range(num_parents): + all_skipped = per_parent_chunk_count[p] > 0 and per_parent_skip_count[p] == per_parent_chunk_count[p] + stitched.append( + ASRResult( + text=" ".join(per_parent_texts[p]), + secondary_text=" ".join(per_parent_secondary[p]) if per_parent_secondary[p] else None, + skipped=all_skipped, + model_id=per_parent_model_id[p], + extras={"skip_reason": per_parent_skip_reason[p]} if per_parent_skip_reason[p] else {}, + ) + ) + return stitched + + def process(self, task: AudioTask) -> AudioTask: + msg = f"{type(self).__name__} only supports process_batch" + raise NotImplementedError(msg) + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + """Run one backend-provided ASR batch. + + Backend executors own how many parent rows reach this call. Model-input + segmentation and duration-aware bucketing stay inside the stage: parent + rows are expanded into bounded model-input items here, bucketed when + ``batch_policy`` is enabled, and stitched back into the original parent + order before returning. + """ + if len(tasks) == 0: + return [] + if self.payload_prefetch_enabled and self._has_unresolved_payload_refs(tasks): + return self._process_plain_batch(tasks) + inserted_waveforms: list[AudioTask] = [] + try: + inserted_waveforms = self._resolve_payload_refs(tasks) + return self._process_plain_batch(tasks) + finally: + self._drop_resolved_payload_waveforms(inserted_waveforms) + + def _has_unresolved_payload_refs(self, tasks: list[AudioTask]) -> bool: + if not self.waveform_ref_key: + return False + return any( + self.waveform_key not in task.data and isinstance(task.data.get(self.waveform_ref_key), PayloadRef) + for task in tasks + ) + + def _process_plain_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + """Dispatch one unbucketed backend batch. + + The backend's normal batch stays intact, but each long parent is sliced + into bounded model-input chunks before adapter inference and stitched + back afterward. If a parent already fits the model-input window, it + remains one adapter item. + """ + for task in tasks: + if not self._validate_asr_task_input(task): + msg = f"Task {task.task_id} missing required columns for {type(self).__name__}: {self.inputs()}" + raise ValueError(msg) + if self._adapter is None: + msg = "Adapter not initialized - setup() was not called" + raise RuntimeError(msg) + + self._acc_model_metrics = defaultdict(float) + self._inference_elapsed_s = 0.0 + self._adapter_inference_calls = 0 + self._adapter_inference_items = 0 + + chunk_specs = self._build_chunk_specs(tasks) + items = [self._chunk_spec_to_item(spec) for spec in chunk_specs] + parent_of = [spec.parent_idx for spec in chunk_specs] + + if any(isinstance(item.get(_PAYLOAD_REF_ITEM_KEY), PayloadRef) for item in items): + results = self._run_payload_inference_capped(items) + else: + results = self._run_inference_capped(items) + if len(results) != len(items): + msg = f"run_fn returned {len(results)} results for {len(items)} items (must match 1:1)" + raise RuntimeError(msg) + return self.assemble(tasks, items, parent_of, results) + + def _run_inference_capped(self, items: list[dict[str, Any]]) -> list[ASRResult]: + """Run adapter calls with bucket-aware per-call item caps. + + When segmentation fans one backend batch out into many model work units, + cap each direct adapter call to avoid turning one long parent row into + an oversized vLLM request list. When ``BatchPolicy`` supplies + ``bucketed_inference_batch_size``, same-bucket short chunks can use a + larger model-call batch than long chunks; otherwise + ``adapter_batch_size`` is the fallback. Results are realigned to the + original item order. + """ + aligned_results, calls = self._plan_inference_calls(items) + for call in calls: + self._store_call_results(call, self.run_inference(call.items), aligned_results) + return self._finalize_aligned_results(aligned_results) + + def _run_payload_inference_capped(self, items: list[dict[str, Any]]) -> list[ASRResult]: + """Resolve only the current adapter call and prefetch one successor.""" + if self.payload_prefetch_max_bytes is None: + msg = "payload_prefetch_max_bytes is required for payload-prefetched inference" + raise RuntimeError(msg) + aligned_results, calls = self._plan_inference_calls(items) + payload_refs = [ + payload_ref for item in items if isinstance((payload_ref := item.get(_PAYLOAD_REF_ITEM_KEY)), PayloadRef) + ] + materializer = _PayloadCallMaterializer( + resolve_max_batch_bytes=self.payload_resolve_max_batch_bytes, + cache_max_bytes=self.payload_prefetch_max_bytes, + consumer_node_id=self.payload_consumer_node_id(), + slice_waveform=self._slice_waveform, + ) + self._start_payload_lease_keeper(payload_refs) + try: + prefetched_calls = BoundedOneAheadPrefetchIterator( + calls, + loader=materializer.materialize, + size_bytes=self._inference_call_payload_bytes, + max_inflight_bytes=self.payload_prefetch_max_bytes, + thread_name_prefix="curator-asr-payload-prefetch", + ) + for call, materialized_items in prefetched_calls: + try: + self._store_call_results(call, self.run_inference(materialized_items), aligned_results) + finally: + materializer.complete(call) + finally: + self._stop_payload_lease_keeper() + self._log_payload_resolution_metrics(materializer) + materializer.close() + return self._finalize_aligned_results(aligned_results) + + def _plan_inference_calls( + self, + items: list[dict[str, Any]], + ) -> tuple[list[ASRResult | None], list[_InferenceCall]]: + """Build exact adapter-call boundaries without resolving payload bytes.""" + aligned_results: list[ASRResult | None] = [None] * len(items) + eligible_items: list[dict[str, Any]] = [] + eligible_indices: list[int] = [] + for idx, item in enumerate(items): + if self._is_language_supported(item): + eligible_items.append(item) + eligible_indices.append(idx) + continue + code = str(item.get("language_code", "") or "").strip().lower() + aligned_results[idx] = ASRResult( + text="", + skipped=True, + extras={"skip_reason": f"lang_not_supported:{code or 'unknown'}"}, + ) + + calls: list[_InferenceCall] = [] + policy = self.batch_policy + if policy is None or not policy.enabled: + cursor = 0 + for sub_items in self._split_items_for_inference_calls(eligible_items): + sub_indices = eligible_indices[cursor : cursor + len(sub_items)] + calls.append(_InferenceCall(indices=sub_indices, items=sub_items)) + cursor += len(sub_items) + return aligned_results, calls + + for bucket_indices, bucket_items, _total_cost in policy.bucketize_with_costs( + eligible_items, + cost_fn=self.item_cost, + ): + cursor = 0 + for sub_items in self._split_items_for_inference_calls(bucket_items): + local_indices = bucket_indices[cursor : cursor + len(sub_items)] + calls.append( + _InferenceCall( + indices=[eligible_indices[index] for index in local_indices], + items=sub_items, + ) + ) + cursor += len(sub_items) + return aligned_results, calls + + @staticmethod + def _store_call_results( + call: _InferenceCall, + results: list[ASRResult], + aligned_results: list[ASRResult | None], + ) -> None: + if len(results) != len(call.items): + msg = f"run_fn returned {len(results)} results for {len(call.items)} items (must match 1:1)" + raise RuntimeError(msg) + for index, result in zip(call.indices, results, strict=True): + aligned_results[index] = result + + @staticmethod + def _finalize_aligned_results(aligned_results: list[ASRResult | None]) -> list[ASRResult]: + if any(result is None for result in aligned_results): + msg = "ASR call planning did not produce an inference result for every item" + raise RuntimeError(msg) + return [result for result in aligned_results if result is not None] + + @staticmethod + def _inference_call_payload_bytes(call: _InferenceCall) -> int: + refs: dict[tuple[str | None, str, str], PayloadRef] = {} + for item in call.items: + payload_ref = item.get(_PAYLOAD_REF_ITEM_KEY) + if isinstance(payload_ref, PayloadRef): + refs.setdefault(_payload_cache_key(payload_ref), payload_ref) + return sum(max(0, int(ref.amount_bytes)) for ref in refs.values()) + + def _log_payload_resolution_metrics(self, materializer: _PayloadCallMaterializer) -> None: + if materializer.resolution_count <= 0: + return + self._log_metrics( + { + "payload_resolution_count": float(materializer.resolution_count), + "payload_resolution_same_node_count": float(materializer.same_node_count), + "payload_resolution_cross_node_count": float(materializer.cross_node_count), + "payload_resolution_bytes": float(materializer.resolution_bytes), + "payload_resolution_time_s": float(materializer.resolution_time_s), + } + ) + + def _split_items_for_inference_calls(self, items: list[dict[str, Any]]) -> list[list[dict[str, Any]]]: + """Split contiguous items by bucket-aware adapter-call batch size.""" + batches: list[list[dict[str, Any]]] = [] + current: list[dict[str, Any]] = [] + current_bucket: int | None = None + current_cap = 1 + + for item in items: + bucket = self._bucket_for_inference_item(item) + cap = self._inference_batch_size_for_item(item) + if current and (bucket != current_bucket or len(current) >= current_cap): + batches.append(current) + current = [] + current.append(item) + current_bucket = bucket + current_cap = cap + + if current: + batches.append(current) + return batches + + def _bucket_for_inference_item(self, item: dict[str, Any]) -> int | None: + policy = self.batch_policy + if policy is None or not policy.enabled or policy.bucketed_inference_batch_size is None: + return None + return policy.bucket_for(self.item_cost(item)) + + def _inference_batch_size_for_item(self, item: dict[str, Any]) -> int: + fallback_source = self.adapter_batch_size if self.adapter_batch_size is not None else self.batch_size + fallback = max(1, int(fallback_source or 1)) + policy = self.batch_policy + if policy is None: + return fallback + return policy.inference_batch_size_for_cost(self.item_cost(item), fallback) + + def _build_chunk_specs(self, tasks: list[AudioTask]) -> list[_ChunkSpec]: + """Build model-input descriptors from waveforms or payload metadata.""" + specs: list[_ChunkSpec] = [] + slice_ceiling = float(self.max_inference_duration_s) + for parent_idx, task in enumerate(tasks): + waveform = task.data.get(self.waveform_key) + sample_rate = task.data.get(self.sample_rate_key) + payload_ref = task.data.get(self.waveform_ref_key) if self.waveform_ref_key else None + language = self._resolve_language(task) + language_code = self._resolve_language_code(task) + reference_text = self._resolve_reference_text(task) + if self._waveform_is_empty(waveform) and isinstance(payload_ref, PayloadRef) and sample_rate: + sr = int(sample_rate) + segments = plan_audio_segments( + num_samples=int(payload_ref.num_samples), + sample_rate=sr, + max_duration_s=slice_ceiling, + owner="ASRStage", + ) + for segment in segments: + specs.append( + _ChunkSpec( + parent_task=task, + parent_idx=parent_idx, + chunk_idx=segment.index, + chunk_count=segment.count, + waveform=None, + sample_rate=sr, + language=language, + language_code=language_code, + reference_text=reference_text, + cost=segment.duration_s, + payload_ref=payload_ref, + start_sample=segment.start_sample, + stop_sample=segment.stop_sample, + ) + ) + continue + if self._waveform_is_empty(waveform) or not sample_rate: + specs.append( + _ChunkSpec( + parent_task=task, + parent_idx=parent_idx, + chunk_idx=0, + chunk_count=1, + waveform=waveform, + sample_rate=sample_rate, + language=language, + language_code=language_code, + reference_text=reference_text, + cost=0.0, + stop_sample=0, + ) + ) + continue + + sr = int(sample_rate) + chunks = self._chunk_waveform(waveform, sr, slice_ceiling) + chunk_count = len(chunks) + for chunk_idx, chunk in enumerate(chunks): + specs.append( + _ChunkSpec( + parent_task=task, + parent_idx=parent_idx, + chunk_idx=chunk_idx, + chunk_count=chunk_count, + waveform=chunk, + sample_rate=sr, + language=language, + language_code=language_code, + reference_text=reference_text, + cost=0.0 if sr <= 0 else float(self._waveform_num_samples(chunk)) / float(sr), + start_sample=0, + stop_sample=self._waveform_num_samples(chunk), + ) + ) + return specs + + def _chunk_spec_to_item(self, spec: _ChunkSpec) -> dict[str, Any]: + """Convert a virtual chunk descriptor into one adapter input item.""" + item = { + "waveform": spec.waveform, + "sample_rate": spec.sample_rate, + "language": spec.language, + "language_code": spec.language_code, + "reference_text": spec.reference_text, + "task_id": spec.parent_task.task_id, + "audio_seconds": spec.cost, + "chunk_idx": spec.chunk_idx, + "chunk_count": spec.chunk_count, + _WAVEFORM_BYTES_ITEM_KEY: self._chunk_spec_waveform_bytes(spec), + } + if spec.payload_ref is not None: + item[_PAYLOAD_REF_ITEM_KEY] = spec.payload_ref + item[_PAYLOAD_START_ITEM_KEY] = spec.start_sample + item[_PAYLOAD_STOP_ITEM_KEY] = spec.stop_sample + return item + + @classmethod + def _chunk_spec_waveform_bytes(cls, spec: _ChunkSpec) -> float: + if spec.payload_ref is None: + return cls._waveform_nbytes(spec.waveform) + total_samples = max(1, int(spec.payload_ref.num_samples)) + segment_samples = max(0, int(spec.stop_sample) - int(spec.start_sample)) + return float(int(spec.payload_ref.amount_bytes) * segment_samples) / float(total_samples) + + def item_cost(self, item: dict[str, Any]) -> float: + """Bucketing cost of one sub-chunk. + + Duration remains the default cost unit, but adapters may provide a + better estimator for scheduler pressure (for example encoder tokens or + approximate VRAM units) without changing executor autoscaling. + """ + estimator = getattr(self._adapter, "estimate_item_cost", None) + if callable(estimator): + try: + estimated = estimator(item) + if isinstance(estimated, Real): + return max(0.0, float(estimated)) + except Exception as exc: # noqa: BLE001 + logger.debug("ASR adapter cost estimator failed; falling back to duration cost: {}", exc) + for key in ("estimated_vram_units", "estimated_encoder_tokens"): + value = item.get(key) + if value is not None: + return max(0.0, float(value)) + return float(item.get("audio_seconds", 0.0)) + + def run_inference(self, items: list[dict[str, Any]]) -> list[ASRResult]: + """Transcribe ONE bucket-respecting sub-batch via the adapter. + + Also folds the adapter's ``last_metrics`` and wall-clock time into the + per-``process_batch`` accumulators that :meth:`assemble` reports. + """ + inference_t0 = time.perf_counter() + sub_results = self._adapter.transcribe_batch(items) + self._inference_elapsed_s += time.perf_counter() - inference_t0 + self._adapter_inference_calls += 1 + self._adapter_inference_items += len(items) + last_m = dict(getattr(self._adapter, "last_metrics", {}) or {}) + for k, v in last_m.items(): + if isinstance(v, (int, float)): + self._acc_model_metrics[k] += float(v) + return sub_results + + def assemble( + self, + tasks: list[AudioTask], + items: list[dict[str, Any]], + parent_of: list[int], + results: list[ASRResult], + ) -> list[AudioTask]: + """Stitch sub-chunk results per parent task, write outputs, emit metrics.""" + accumulated_model_metrics = self._acc_model_metrics + inference_elapsed = self._inference_elapsed_s + + # Defensive: turn any None slots into skipped placeholders, guarding + # against silent data loss if a future adapter forgets a slot. + chunk_results: list[ASRResult] = [r if r is not None else ASRResult(text="", skipped=True) for r in results] + + per_parent_results = self._stitch(chunk_results, parent_of, num_parents=len(tasks)) + + skipped_count = 0 + for task, parent_result in zip(tasks, per_parent_results, strict=True): + task.data[self.pred_text_key] = parent_result.text + if self.disfluency_text_key: + task.data[self.disfluency_text_key] = parent_result.secondary_text or "" + if parent_result.skipped: + task.data[self.skip_me_key] = str(parent_result.extras.get("skip_reason") or "empty_audio") + skipped_count += 1 + if not self.keep_waveform: + task.data.pop(self.waveform_key, None) + + # ``utterances_*`` count PARENT tasks (the 1-row-per-input semantic + # downstream consumers rely on); ``sub_chunks_generated`` surfaces the + # pre-slicer fan-out. + metrics: dict[str, float] = { + "utterances_input": float(len(tasks)), + "utterances_processed": float(max(0, len(tasks) - skipped_count)), + "utterances_skipped": float(skipped_count), + "sub_chunks_generated": float(len(items)), + "audio_duration_s": sum(float(item.get("audio_seconds", 0.0)) for item in items), + "waveform_bytes": sum( + float(item.get(_WAVEFORM_BYTES_ITEM_KEY, self._waveform_nbytes(item.get("waveform")))) + for item in items + ), + "output_chars": float( + sum(len(r.text) for r in per_parent_results) + + sum(len(r.secondary_text or "") for r in per_parent_results) + ), + "output_tokens": float(accumulated_model_metrics.get("output_tokens", 0.0)), + "turn1_output_tokens": float(accumulated_model_metrics.get("turn1_output_tokens", 0.0)), + "turn2_output_tokens": float(accumulated_model_metrics.get("turn2_output_tokens", 0.0)), + "inference_time_s": inference_elapsed, + "adapter_inference_calls": float(self._adapter_inference_calls), + "adapter_inference_items": float(self._adapter_inference_items), + } + # Pass through adapter scalar metrics under a "model_" alias, + # skipping any that would restate a key the stage already emits. + metrics.update( + { + f"model_{name}": value + for name, value in accumulated_model_metrics.items() + if isinstance(value, (int, float)) and name not in metrics + } + ) + self._log_metrics(metrics) + + if skipped_count: + logger.info( + f"ASRStage ({self.adapter_target}): marked {skipped_count}/{len(tasks)} " + f"tasks as empty_audio ({self.skip_me_key})", + ) + logger.debug( + f"ASRStage ({self.adapter_target}): generated {len(per_parent_results)} parent predictions " + f"from {len(items)} sub-chunk(s) " + f"(disfluency_text={'on' if self.disfluency_text_key else 'off'})", + ) + return tasks diff --git a/nemo_curator/stages/audio/inference/batch_policy.py b/nemo_curator/stages/audio/inference/batch_policy.py new file mode 100644 index 0000000000..4979859c7d --- /dev/null +++ b/nemo_curator/stages/audio/inference/batch_policy.py @@ -0,0 +1,463 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cost-aware batch policy for GPU inference stages. + +Hydra-instantiable policy for heterogeneous model-input items. Stages use +``bucketize`` / ``bucketize_with_costs`` inside ``process_batch`` to split a +backend-provided parent-row batch into duration/cost-coherent adapter calls. +Cost is supplied via ``cost_fn``; bucket edges and cost budgets are in the same +units (audio seconds for ASR, the default consumer). + +``BatchPolicy`` does not replace backend scheduling. Xenna and Ray Data still +decide how many parent rows reach a worker based on stage ``batch_size``, +resources, and worker count. Once inside the stage, the policy controls +same-bucket grouping and final adapter-call caps. ``prebatching_window_size`` is +an optional advisory window for callers that maintain their own finite +candidate list. Payload lifecycle may inherit it as the backend candidate-row +window, but payload lifecycle still splits that candidate list by node byte +budget before materializing payloads. +""" + +from __future__ import annotations + +import time +from collections import deque +from dataclasses import dataclass, field +from numbers import Real +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable + + +@dataclass(frozen=True) +class ReadyBatch: + """A scheduler-emitted batch that is ready to dispatch to a worker.""" + + indices: list[int] + items: list[object] + total_cost: float + bucket_index: int + flush_reason: str + + +@dataclass(frozen=True) +class _QueuedItem: + index: int + item: object + cost: float + enqueued_ms: float + + +class BucketQueueScheduler: + """Persistent per-bucket queue scheduler for cost-aware GPU dispatch. + + The scheduler accepts already-costed work units one at a time. Each unit + enters exactly one bucket queue. A queue flushes when adding work would + overflow its item/cost budget, when it reaches a budget exactly, when its + timer expires, or when the caller drains it. This is the shared primitive + used by finite ``BatchPolicy.bucketize_with_costs`` plans and by executor + paths that can hold a persistent scheduling window. + """ + + def __init__(self, policy: BatchPolicy, *, enable_timer: bool = True) -> None: + self.policy = policy + self._enable_timer = enable_timer + self._queues: list[deque[_QueuedItem]] = [deque() for _ in range(policy.num_buckets)] + self._costs: list[float] = [0.0 for _ in range(policy.num_buckets)] + self._first_enqueued_ms: list[float | None] = [None for _ in range(policy.num_buckets)] + + def enqueue(self, index: int, item: object, cost: float, now_ms: float | None = None) -> list[ReadyBatch]: + """Queue one item and return any batches that became ready.""" + now = self._now_ms(now_ms) if self._enable_timer else self._static_ms(now_ms) + ready = self.flush_due(now) if self._enable_timer else [] + + bucket_index = self.policy.bucket_for(float(cost)) + if self._would_overflow(bucket_index, float(cost)): + flushed = self._flush_bucket(bucket_index, "capacity") + if flushed is not None: + ready.append(flushed) + + queue = self._queues[bucket_index] + if not queue: + self._first_enqueued_ms[bucket_index] = now + queue.append(_QueuedItem(index=index, item=item, cost=float(cost), enqueued_ms=now)) + self._costs[bucket_index] += float(cost) + + ready_reason = self._ready_reason(bucket_index) + if ready_reason is not None: + flushed = self._flush_bucket(bucket_index, ready_reason) + if flushed is not None: + ready.append(flushed) + return ready + + def flush_due(self, now_ms: float | None = None) -> list[ReadyBatch]: + """Flush queues whose first item has exceeded ``flush_interval_ms``.""" + if not self._enable_timer: + return [] + interval_ms = float(getattr(self.policy, "flush_interval_ms", 0) or 0) + if interval_ms <= 0: + return [] + + now = self._now_ms(now_ms) + ready: list[ReadyBatch] = [] + for bucket_index, first_ms in enumerate(self._first_enqueued_ms): + if first_ms is not None and now - first_ms >= interval_ms: + flushed = self._flush_bucket(bucket_index, "timer") + if flushed is not None: + ready.append(flushed) + return ready + + def flush_all(self, reason: str = "drain") -> list[ReadyBatch]: + """Drain all non-empty bucket queues.""" + ready: list[ReadyBatch] = [] + for bucket_index in range(self.policy.num_buckets): + flushed = self._flush_bucket(bucket_index, reason) + if flushed is not None: + ready.append(flushed) + return ready + + def _would_overflow(self, bucket_index: int, cost: float) -> bool: + queue = self._queues[bucket_index] + if not queue: + return False + item_cap = int(self.policy.max_items_per_batch_by_bucket[bucket_index]) + if len(queue) >= item_cap: + return True + cost_cap = self.policy.max_audio_sec_per_batch + return cost_cap is not None and self._costs[bucket_index] + cost > float(cost_cap) + + def _ready_reason(self, bucket_index: int) -> str | None: + queue = self._queues[bucket_index] + item_cap = int(self.policy.max_items_per_batch_by_bucket[bucket_index]) + if len(queue) >= item_cap: + return "item_cap" + cost_cap = self.policy.max_audio_sec_per_batch + if cost_cap is not None and self._costs[bucket_index] >= float(cost_cap): + return "cost_cap" + return None + + def _flush_bucket(self, bucket_index: int, reason: str) -> ReadyBatch | None: + queue = self._queues[bucket_index] + if not queue: + return None + queued = list(queue) + queue.clear() + total_cost = self._costs[bucket_index] + self._costs[bucket_index] = 0.0 + self._first_enqueued_ms[bucket_index] = None + return ReadyBatch( + indices=[queued_item.index for queued_item in queued], + items=[queued_item.item for queued_item in queued], + total_cost=total_cost, + bucket_index=bucket_index, + flush_reason=reason, + ) + + @staticmethod + def _now_ms(now_ms: float | None) -> float: + if now_ms is not None: + return float(now_ms) + return time.monotonic() * 1000.0 + + @staticmethod + def _static_ms(now_ms: float | None) -> float: + if now_ms is not None: + return float(now_ms) + return 0.0 + + +@dataclass +class BatchPolicy: + """Cost-bucketed batching policy. + + Defaults match the Qwen-Omni tutorial layout (``buckets_sec=[0, 600, 1200, + 2400]`` when ``max_inference_duration_s=2400``). + + Args: + enabled: When ``False``, the policy is carried in config but backend + adapters and ``run_bucketed`` dispatch one normal batch, matching + ``policy=None``. + strategy: Only ``"duration_bucketed"`` is implemented; other values are + reserved for future use. + buckets_sec: Strictly-increasing left edges starting at ``0`` (cost + units). Bucket ``i`` covers ``[buckets_sec[i], buckets_sec[i+1])``; + the last covers ``[buckets_sec[-1], +inf)``. + max_items_per_batch_by_bucket: Per-bucket item cap; length must equal + ``len(buckets_sec)``. + bucketed_inference_batch_size: Optional per-bucket cap for the final + model adapter call. This is separate from scheduler batch shape: + scheduler caps decide which items reach the worker together; this + cap decides how many same-bucket items are sent into one + ``model.inference``/adapter call. ``None`` keeps the stage's normal + ``batch_size`` fallback. + max_audio_sec_per_batch: Optional per-sub-batch total-cost cap (``None`` + = only item caps apply). + prebatching_window_size: Optional advisory candidate-window size for + callers that maintain a finite scheduling window. ``None`` + preserves the derived default ``sum(max_items_per_batch_by_bucket)``. + flush_interval_ms: Cross-call queue flush timer (ms). Persistent + schedulers use it directly; finite ``bucketize`` calls drain at the + end of the supplied item window. + """ + + enabled: bool = True + strategy: str = "duration_bucketed" + buckets_sec: list[float] = field(default_factory=lambda: [0.0, 600.0, 1200.0, 2400.0]) + max_items_per_batch_by_bucket: list[int] = field(default_factory=lambda: [32, 16, 8, 4]) + bucketed_inference_batch_size: list[int] | None = None + max_audio_sec_per_batch: float | None = 2400.0 + prebatching_window_size: int | None = None + flush_interval_ms: int = 250 + + def __post_init__(self) -> None: + self._validate_flags() + if not self.enabled: + return + self._validate_strategy() + self._validate_bucket_edges() + self._validate_batch_caps() + self._validate_prebatching_window() + + def _validate_flags(self) -> None: + if not isinstance(self.enabled, bool): + msg = f"BatchPolicy: enabled must be a bool, got {type(self.enabled).__name__}" + raise TypeError(msg) + if self.prebatching_window_size is not None and ( + isinstance(self.prebatching_window_size, bool) or not isinstance(self.prebatching_window_size, int) + ): + msg = ( + f"BatchPolicy: prebatching_window_size must be an int or None, " + f"got {type(self.prebatching_window_size).__name__}" + ) + raise TypeError(msg) + if isinstance(self.flush_interval_ms, bool) or not isinstance(self.flush_interval_ms, int): + msg = f"BatchPolicy: flush_interval_ms must be an int, got {type(self.flush_interval_ms).__name__}" + raise TypeError(msg) + + def _validate_strategy(self) -> None: + if self.strategy != "duration_bucketed": + msg = ( + f"BatchPolicy: strategy={self.strategy!r} not yet implemented; only 'duration_bucketed' is supported." + ) + raise ValueError(msg) + + def _validate_bucket_edges(self) -> None: + if not self.buckets_sec: + msg = "BatchPolicy: buckets_sec must contain at least one edge" + raise ValueError(msg) + for edge in self.buckets_sec: + if isinstance(edge, bool) or not isinstance(edge, Real): + msg = f"BatchPolicy: every buckets_sec entry must be numeric, got {type(edge).__name__}" + raise TypeError(msg) + if self.buckets_sec[0] != 0.0: + msg = f"BatchPolicy: buckets_sec must start at 0.0, got {self.buckets_sec[0]}" + raise ValueError(msg) + for i in range(len(self.buckets_sec) - 1): + if self.buckets_sec[i + 1] <= self.buckets_sec[i]: + msg = ( + f"BatchPolicy: buckets_sec must be strictly increasing; " + f"got {self.buckets_sec[i]} -> {self.buckets_sec[i + 1]}" + ) + raise ValueError(msg) + + def _validate_batch_caps(self) -> None: # noqa: C901 + if len(self.max_items_per_batch_by_bucket) != len(self.buckets_sec): + msg = ( + f"BatchPolicy: max_items_per_batch_by_bucket has " + f"{len(self.max_items_per_batch_by_bucket)} entries but buckets_sec has " + f"{len(self.buckets_sec)}; lengths must match" + ) + raise ValueError(msg) + for cap in self.max_items_per_batch_by_bucket: + if isinstance(cap, bool) or not isinstance(cap, int): + msg = ( + f"BatchPolicy: every max_items_per_batch_by_bucket entry must be an int, got {type(cap).__name__}" + ) + raise TypeError(msg) + if cap <= 0: + msg = f"BatchPolicy: every max_items_per_batch_by_bucket entry must be > 0, got {cap}" + raise ValueError(msg) + if self.bucketed_inference_batch_size is not None: + if len(self.bucketed_inference_batch_size) != len(self.buckets_sec): + msg = ( + f"BatchPolicy: bucketed_inference_batch_size has " + f"{len(self.bucketed_inference_batch_size)} entries but buckets_sec has " + f"{len(self.buckets_sec)}; lengths must match" + ) + raise ValueError(msg) + for cap in self.bucketed_inference_batch_size: + if isinstance(cap, bool) or not isinstance(cap, int): + msg = ( + f"BatchPolicy: every bucketed_inference_batch_size entry must be an int, " + f"got {type(cap).__name__}" + ) + raise TypeError(msg) + if cap <= 0: + msg = f"BatchPolicy: every bucketed_inference_batch_size entry must be > 0, got {cap}" + raise ValueError(msg) + if self.max_audio_sec_per_batch is not None: + if isinstance(self.max_audio_sec_per_batch, bool) or not isinstance(self.max_audio_sec_per_batch, Real): + msg = ( + f"BatchPolicy: max_audio_sec_per_batch must be numeric or None, " + f"got {type(self.max_audio_sec_per_batch).__name__}" + ) + raise TypeError(msg) + if self.max_audio_sec_per_batch <= 0: + msg = f"BatchPolicy: max_audio_sec_per_batch must be > 0 (or None), got {self.max_audio_sec_per_batch}" + raise ValueError(msg) + + def _validate_prebatching_window(self) -> None: + if self.prebatching_window_size is not None and self.prebatching_window_size <= 0: + msg = f"BatchPolicy: prebatching_window_size must be > 0 (or None), got {self.prebatching_window_size}" + raise ValueError(msg) + if self.flush_interval_ms < 0: + msg = f"BatchPolicy: flush_interval_ms must be >= 0, got {self.flush_interval_ms}" + raise ValueError(msg) + + @property + def num_buckets(self) -> int: + return len(self.buckets_sec) + + def bucket_for(self, cost: float) -> int: + """Return the bucket index for an item with the given cost. + + Left-edge semantics: cost 600 with ``[0, 600, 1200, 2400]`` lands in + bucket 1 (``[600, 1200)``). Items at/above the top edge clamp into the + last bucket (the pre-slicer should prevent this, but the clamp keeps the + helper robust). + """ + for i in range(self.num_buckets - 1, -1, -1): + if cost >= self.buckets_sec[i]: + return i + return 0 + + def inference_batch_size_for_cost(self, cost: float, fallback: int) -> int: + """Return adapter-call batch size for one item cost. + + This is intentionally independent from ``max_items_per_batch_by_bucket``: + the scheduler can form a coherent worker batch, then the stage can tune + the final model-call granularity for short versus long items. + """ + if not self.enabled or self.bucketed_inference_batch_size is None: + return max(1, int(fallback)) + return int(self.bucketed_inference_batch_size[self.bucket_for(float(cost))]) + + def bucketize( + self, + items: list[Any], + cost_fn: Callable[[Any], float], + ) -> list[tuple[list[int], list[Any]]]: + """Re-partition ``items`` into bucket-respecting sub-batches. + + Args: + items: Flat list of tasks or model-input items to partition. + cost_fn: Returns the per-item cost (audio seconds by default). + + Returns: + ``(orig_indices, sub_items)`` tuples whose indices union to + ``range(len(items))``. The finite planner dispatches heavier + sub-batches first to reduce multi-worker tail time; results should + always be realigned by the caller. + + Per-sub-batch invariants: + * all items share one bucket; + * size <= ``max_items_per_batch_by_bucket[bucket]``; + * total cost <= ``max_audio_sec_per_batch`` if set, except a single + over-cost item is its own sub-batch so it always fires. + """ + return [ + (orig_indices, sub_items) + for orig_indices, sub_items, _total_cost in self.bucketize_with_costs(items, cost_fn) + ] + + def bucketize_with_costs( + self, + items: list[Any], + cost_fn: Callable[[Any], float], + ) -> list[tuple[list[int], list[Any], float]]: + """Re-partition ``items`` and return each sub-batch's total cost. + + This is the planning form of :meth:`bucketize`: it computes + ``cost_fn(item)`` once per item, then carries the accumulated sub-batch + cost forward so callers can sort or account without re-inspecting + expensive payloads. + """ + if not items: + return [] + if not self.enabled: + return [(list(range(len(items))), list(items), 0.0)] + + scheduler = BucketQueueScheduler(self, enable_timer=False) + ready_batches: list[ReadyBatch] = [] + for i, it in enumerate(items): + ready_batches.extend(scheduler.enqueue(i, it, float(cost_fn(it)))) + ready_batches.extend(scheduler.flush_all()) + + return [ + (batch.indices, batch.items, batch.total_cost) + for batch in sorted(ready_batches, key=lambda batch: batch.total_cost, reverse=True) + ] + + +def run_bucketed( + items: list[Any], + run_fn: Callable[[list[Any]], list[Any]], + *, + cost_fn: Callable[[Any], float], + policy: BatchPolicy | None = None, +) -> list[Any]: + """Dispatch ``run_fn`` over cost-bucketed sub-batches, preserving order. + + The importable direct-call helper for GPU inference stages, so stages + don't re-implement the bucketize -> dispatch -> reassemble loop. + ``policy=None`` / ``policy.enabled=False`` (or empty ``items``) runs a + single ``run_fn`` call; otherwise each sub-batch is dispatched and results + are realigned to ``items`` order so callers never see the internal bucket + ordering. Scheduler-backed backend execution can pre-bucket stage-specific + work units before calling the stage. + + Args: + items: Flat list of per-item payloads the stage assembled this call. + run_fn: Runs one sub-batch, returning one result per item (1:1, in order). + cost_fn: Returns the per-item cost (audio seconds by default). + policy: Optional bucketing policy; ``None`` or disabled runs a single + batch. + + Returns: + Results aligned 1:1 with ``items``. + + Raises: + RuntimeError: If ``run_fn`` returns a count that mismatches its sub-batch. + """ + if not items: + return [] + + if policy is not None and policy.enabled: + sub_batches = policy.bucketize(items, cost_fn=cost_fn) + else: + sub_batches = [(list(range(len(items))), list(items))] + + results: list[Any] = [None] * len(items) + for sub_indices, sub_items in sub_batches: + if not sub_items: + continue + sub_results = run_fn(sub_items) + if len(sub_results) != len(sub_items): + msg = f"run_fn returned {len(sub_results)} results for {len(sub_items)} items (must match 1:1)" + raise RuntimeError(msg) + for i, r in zip(sub_indices, sub_results, strict=True): + results[i] = r + return results diff --git a/nemo_curator/stages/audio/inference/bucketed_stage.py b/nemo_curator/stages/audio/inference/bucketed_stage.py new file mode 100644 index 0000000000..6e6e8ce072 --- /dev/null +++ b/nemo_curator/stages/audio/inference/bucketed_stage.py @@ -0,0 +1,106 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic cost-bucketed GPU inference stage. + +``BucketedInferenceStage`` factors the model-item dispatch -> reassemble loop +out of individual stages. Backends still own parent-row scheduling and call +``process_batch`` normally. Inside that call, the stage expands parent tasks into +model-input items, optionally applies ``BatchPolicy`` at item level, calls the +adapter, and stitches results back to parent rows. + +Any GPU inference processor implements four hooks: + +* :meth:`build_items` - expand input tasks into flat model-input items; +* :meth:`item_cost` - per-item bucketing cost (audio sec, tokens, ...); +* :meth:`run_inference` - run the model on ONE sub-batch (1:1 results); +* :meth:`assemble` - stitch per-item results back onto the tasks. + +The base :meth:`process_batch` wires these through ``run_bucketed``, which +honors ``batch_policy`` and realigns results to the original item order. +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Generic, TypeVar + +from nemo_curator.stages.audio.inference.batch_policy import run_bucketed +from nemo_curator.stages.base import ProcessingStage, X, Y + +if TYPE_CHECKING: + from nemo_curator.stages.audio.inference.batch_policy import BatchPolicy + +# Model-input item and per-item result types; intentionally unbounded. +ItemT = TypeVar("ItemT") +ResultT = TypeVar("ResultT") + + +class BucketedInferenceStage(ProcessingStage[X, Y], Generic[X, Y, ItemT, ResultT]): + """Abstract cost-bucketed inference stage. + + Subclasses set a ``batch_policy`` (``None`` or ``enabled=False`` = one + sub-batch per call) and implement the four hooks. The base owns the 1:1 + ``process_batch`` contract: exactly one output per input task, in input + order. Optional scheduler hooks may exist for specialized callers, but the + default contract is ordinary backend ``process_batch`` dispatch. + """ + + _is_abstract_root = True # never registered / instantiated directly + batch_policy: BatchPolicy | None = None + + @abstractmethod + def build_items(self, tasks: list[X]) -> tuple[list[ItemT], list[int]]: + """Expand ``tasks`` into flat model-input items. + + Returns ``(items, parent_of)`` where ``parent_of[i]`` is the index of the + task that produced ``items[i]`` (a task may fan out to several items or + none). Reset any per-call accumulators here, as this hook runs first. + """ + + @abstractmethod + def item_cost(self, item: ItemT) -> float: + """Per-item bucketing cost (audio seconds, tokens, pixels, ...).""" + + @abstractmethod + def run_inference(self, items: list[ItemT]) -> list[ResultT]: + """Run the model on ONE sub-batch; return one result per item (1:1).""" + + @abstractmethod + def assemble( + self, + tasks: list[X], + items: list[ItemT], + parent_of: list[int], + results: list[ResultT], + ) -> list[Y]: + """Stitch per-item ``results`` back onto ``tasks`` and write outputs. + + ``items``/``parent_of``/``results`` are index-aligned. Must return + exactly one output task per input task, in input order. + """ + + def process_batch(self, tasks: list[X]) -> list[Y]: + # Ray Data passes columns as numpy ndarrays, so use len() not truthiness + # (``if not tasks`` would raise ValueError). + if len(tasks) == 0: + return [] + items, parent_of = self.build_items(tasks) + results = run_bucketed( + items, + self.run_inference, + cost_fn=self.item_cost, + policy=self.batch_policy, + ) + return self.assemble(tasks, items, parent_of, results) diff --git a/nemo_curator/stages/audio/io/audio_file_reader.py b/nemo_curator/stages/audio/io/audio_file_reader.py new file mode 100644 index 0000000000..488c399ff1 --- /dev/null +++ b/nemo_curator/stages/audio/io/audio_file_reader.py @@ -0,0 +1,324 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Local audio-file reader stage. + +``AudioFileReaderStage`` is the raw audio-byte I/O stage for pipelines that +want to hand an in-memory waveform to downstream processors. Remote object +staging is intentionally outside Curator; launchers such as NvLLMOps should +download Swift/S3 objects and rewrite manifests to node-local paths before +Curator starts. +""" + +import os +import shutil +import subprocess +import time +from dataclasses import dataclass +from typing import Any + +import torch +from loguru import logger + +from nemo_curator.backends.base import NodeInfo, WorkerMetadata +from nemo_curator.backends.utils import RayStageSpecKeys +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import AudioTask + +from .waveform_utils import audio_item_id_from_path + + +@dataclass +class AudioFileReaderStage(ProcessingStage[AudioTask, AudioTask]): + """Read a local audio file or local audio segment and emit a waveform. + + Segment mode is driven by ``segment_start_s`` and ``segment_duration_s`` in + task data. The stage uses ffmpeg input seeking so a globally planned segment + can be decoded without loading the full parent audio into memory. + """ + + target_sample_rate: int = 16000 + target_nchannels: int = 1 + audio_filepath_key: str = "audio_filepath" + duration_key: str = "duration" + segment_start_key: str = "segment_start_s" + segment_duration_key: str = "segment_duration_s" + audio_item_id_key: str = "audio_item_id" + waveform_key: str = "waveform" + sample_rate_key: str = "sample_rate" + num_samples_key: str = "num_samples" + skip_me_key: str = "_skip_me" + read_error_key: str = "audio_read_error" + skip_on_read_error: bool = True + ray_num_workers: int | None = None + xenna_num_workers: int | None = None + xenna_num_workers_per_node: int | None = None + verbose: bool = False + name: str = "AudioFileReader" + + def __post_init__(self) -> None: + if self.target_sample_rate <= 0: + msg = f"target_sample_rate must be > 0, got {self.target_sample_rate}" + raise ValueError(msg) + if self.target_nchannels <= 0: + msg = f"target_nchannels must be > 0, got {self.target_nchannels}" + raise ValueError(msg) + self._validate_optional_positive_int("ray_num_workers", self.ray_num_workers) + self._validate_optional_positive_int("xenna_num_workers", self.xenna_num_workers) + self._validate_optional_positive_int("xenna_num_workers_per_node", self.xenna_num_workers_per_node) + if self.xenna_num_workers is not None and self.xenna_num_workers_per_node is not None: + msg = ( + "AudioFileReaderStage: set at most one of xenna_num_workers " + "(cluster-wide) or xenna_num_workers_per_node (per-node)." + ) + raise ValueError(msg) + if not isinstance(self.skip_on_read_error, bool): + msg = f"skip_on_read_error must be bool, got {type(self.skip_on_read_error).__name__}" + raise TypeError(msg) + + @staticmethod + def _validate_optional_positive_int(name: str, value: int | None) -> None: + if value is not None and value <= 0: + msg = f"{name} must be > 0 when set, got {value}" + raise ValueError(msg) + + def setup_on_node( + self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None + ) -> None: + if not shutil.which("ffmpeg"): + msg = "AudioFileReaderStage requires 'ffmpeg'. Install with: sudo apt-get install -y ffmpeg" + raise RuntimeError(msg) + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.audio_filepath_key] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [ + self.audio_filepath_key, + self.audio_item_id_key, + self.duration_key, + self.segment_start_key, + self.segment_duration_key, + self.waveform_key, + self.sample_rate_key, + "is_mono", + self.num_samples_key, + self.skip_me_key, + self.read_error_key, + ] + + def num_workers(self) -> int | None: + if self.xenna_num_workers_per_node is not None: + return self.xenna_num_workers + if self.xenna_num_workers is not None: + return self.xenna_num_workers + return self.ray_num_workers + + def ray_stage_spec(self) -> dict[str, Any]: + if self.num_workers() is None: + return {} + return {RayStageSpecKeys.IS_ACTOR_STAGE: True} + + def xenna_stage_spec(self) -> dict[str, Any]: + spec: dict[str, Any] = {} + if self.xenna_num_workers_per_node is not None: + spec["num_workers_per_node"] = self.xenna_num_workers_per_node + return spec + + def process(self, task: AudioTask) -> AudioTask: + t0 = time.perf_counter() + data_entry = task.data + if self.audio_filepath_key not in data_entry: + msg = "Absolute audio filepath is required" + raise ValueError(msg) + + audio_path = str(data_entry[self.audio_filepath_key]) + segment_start_s = self._optional_seconds(data_entry.get(self.segment_start_key), self.segment_start_key) + segment_duration_s = self._optional_seconds( + data_entry.get(self.segment_duration_key), + self.segment_duration_key, + strictly_positive=True, + ) + data_entry.setdefault(self.audio_item_id_key, audio_item_id_from_path(audio_path)) + if self._is_remote_path(audio_path): + msg = ( + "AudioFileReaderStage only accepts local audio paths. " + "Stage remote Swift/S3 audio with the launcher before Curator starts." + ) + raise ValueError(msg) + + try: + waveform, sample_rate = self._load_waveform( + audio_path, + segment_start_s=segment_start_s, + segment_duration_s=segment_duration_s, + ) + except Exception as exc: + if not self.skip_on_read_error: + raise + logger.warning("Skipping audio row after read failure for {}: {}", audio_path, exc) + return self._mark_read_error(task, audio_path, exc, time.perf_counter() - t0) + + num_samples = int(waveform.shape[-1]) + duration = num_samples / float(sample_rate) + + if segment_start_s is not None: + data_entry[self.segment_start_key] = segment_start_s + if segment_duration_s is not None: + data_entry[self.segment_duration_key] = duration + data_entry[self.waveform_key] = waveform + data_entry[self.sample_rate_key] = sample_rate + data_entry["is_mono"] = waveform.shape[0] == 1 + data_entry[self.num_samples_key] = num_samples + data_entry[self.duration_key] = duration + + metrics = { + "process_time": time.perf_counter() - t0, + "duration": duration, + "waveform_bytes": float(waveform.element_size() * waveform.nelement()), + "audio_file_read": 1.0, + } + if segment_start_s is not None: + metrics["segment_start_s"] = float(segment_start_s) + if segment_duration_s is not None: + metrics["segment_duration_s"] = float(duration) + self._log_metrics(metrics) + return task + + @staticmethod + def _optional_seconds(value: object, key: str, *, strictly_positive: bool = False) -> float | None: + if value is None: + return None + if isinstance(value, bool): + msg = f"{key} must be numeric, got bool" + raise TypeError(msg) + try: + seconds = float(value) + except (TypeError, ValueError) as exc: + msg = f"{key} must be numeric, got {value!r}" + raise TypeError(msg) from exc + if strictly_positive and seconds <= 0: + msg = f"{key} must be > 0 when present, got {seconds}" + raise ValueError(msg) + if not strictly_positive and seconds < 0: + msg = f"{key} must be >= 0 when present, got {seconds}" + raise ValueError(msg) + return seconds + + def _mark_read_error( + self, + task: AudioTask, + audio_path: str, + exc: BaseException, + elapsed_s: float, + ) -> AudioTask: + data_entry = task.data + waveform = torch.empty((self.target_nchannels, 0), dtype=torch.float32) + data_entry[self.waveform_key] = waveform + data_entry[self.sample_rate_key] = self.target_sample_rate + data_entry["is_mono"] = self.target_nchannels == 1 + data_entry[self.num_samples_key] = 0 + data_entry[self.duration_key] = 0.0 + data_entry[self.skip_me_key] = "audio_read_error" + data_entry[self.read_error_key] = f"{type(exc).__name__}: {exc}" + self._log_metrics( + { + "process_time": elapsed_s, + "duration": 0.0, + "waveform_bytes": 0.0, + "audio_file_read": 0.0, + "audio_file_read_errors": 1.0, + "audio_file_skipped": 1.0, + } + ) + logger.debug("Marked {} as skipped due to audio read error", audio_path) + return task + + @staticmethod + def _is_remote_path(path: str) -> bool: + return "://" in str(path) + + def _run_ffmpeg(self, cmd: list[str]) -> subprocess.CompletedProcess[bytes]: + completed = subprocess.run( # noqa: S603 + cmd, + stdin=subprocess.DEVNULL, + capture_output=True, + check=False, + ) + if completed.returncode: + raise subprocess.CalledProcessError( + completed.returncode, + cmd, + output=completed.stdout, + stderr=completed.stderr, + ) + return completed + + def _load_waveform( + self, + input_audio_path: str, + *, + segment_start_s: float | None = None, + segment_duration_s: float | None = None, + ) -> tuple[torch.Tensor, int]: + if self._is_remote_path(input_audio_path): + msg = ( + "AudioFileReaderStage only accepts local audio paths. " + "Stage remote Swift/S3 audio with the launcher before Curator starts." + ) + raise ValueError(msg) + if not os.path.exists(input_audio_path): + raise FileNotFoundError(input_audio_path) + + cmd = ["ffmpeg", "-v", "error"] + if segment_start_s is not None and segment_start_s > 0: + cmd.extend(["-ss", self._format_seconds(segment_start_s)]) + cmd.extend(["-i", input_audio_path]) + if segment_duration_s is not None: + cmd.extend(["-t", self._format_seconds(segment_duration_s)]) + cmd.extend( + [ + "-ar", + str(self.target_sample_rate), + "-ac", + str(self.target_nchannels), + "-f", + "f32le", + "-acodec", + "pcm_f32le", + "pipe:1", + ] + ) + try: + completed = self._run_ffmpeg(cmd) + except subprocess.CalledProcessError as e: + stderr = e.stderr.decode("utf-8", errors="replace") if isinstance(e.stderr, bytes) else e.stderr + msg = f"Error loading waveform from {input_audio_path}: {stderr or e}" + raise RuntimeError(msg) from e + + if not completed.stdout: + msg = f"ffmpeg produced an empty waveform for {input_audio_path}" + raise RuntimeError(msg) + + samples = torch.frombuffer(completed.stdout, dtype=torch.float32) + channels = int(self.target_nchannels) + usable = (samples.numel() // channels) * channels + if usable != samples.numel(): + samples = samples[:usable] + waveform = samples.reshape(-1, channels).transpose(0, 1).contiguous() + return waveform, self.target_sample_rate + + @staticmethod + def _format_seconds(seconds: float) -> str: + return f"{seconds:.6f}".rstrip("0").rstrip(".") diff --git a/nemo_curator/stages/audio/io/manifest_writer_utils.py b/nemo_curator/stages/audio/io/manifest_writer_utils.py new file mode 100644 index 0000000000..deb6e622aa --- /dev/null +++ b/nemo_curator/stages/audio/io/manifest_writer_utils.py @@ -0,0 +1,155 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for audio JSONL manifest writers.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from loguru import logger + +from nemo_curator.stages.audio.metrics.performance import AudioPerformanceSummary + +if TYPE_CHECKING: + from nemo_curator.tasks import AudioTask + from nemo_curator.utils.performance_utils import StagePerfStats + + +def manifest_data( + task: AudioTask, + drop_manifest_keys: tuple[str, ...] = (), + *, + drop_array_like_values: bool = False, +) -> dict[str, Any]: + """Return the manifest row after applying explicit serialization policy.""" + if not drop_manifest_keys and not drop_array_like_values: + return task.data + + data: dict[str, Any] = {} + drop_keys = set(drop_manifest_keys) + for key, value in task.data.items(): + if key in drop_keys: + continue + if drop_array_like_values and hasattr(value, "shape") and hasattr(value, "dtype"): + logger.debug("Dropping array-like manifest key {} from writer output", key) + continue + try: + json.dumps(value, ensure_ascii=False) + except TypeError as exc: + msg = f"Task {task.task_id} contains non-JSON-serializable manifest key {key!r}" + raise TypeError(msg) from exc + data[key] = value + return data + + +def manifest_lines( + tasks: list[AudioTask], + drop_manifest_keys: tuple[str, ...] = (), + *, + drop_array_like_values: bool = False, +) -> list[str]: + """Serialize ``tasks`` to JSONL lines using the shared audio writer rules.""" + return [ + json.dumps( + manifest_data(task, drop_manifest_keys, drop_array_like_values=drop_array_like_values), + ensure_ascii=False, + ) + + "\n" + for task in tasks + ] + + +@dataclass +class AudioManifestWriterMetrics: + """Writer-local metrics and terminal perf-summary accumulator.""" + + stage_name: str + duration_key: str = "duration" + write_perf_stats: bool = False + _perf_summary: AudioPerformanceSummary = field(init=False, repr=False) + _writer_manifest_write_time_s: float = field(default=0.0, repr=False) + _writer_done_write_time_s: float = field(default=0.0, repr=False) + _writer_perf_write_time_s: float = field(default=0.0, repr=False) + _writer_invocation_count: int = field(default=0, repr=False) + _writer_items_processed: int = field(default=0, repr=False) + + def __post_init__(self) -> None: + self._perf_summary = AudioPerformanceSummary(duration_key=self.duration_key) + + @property + def total_utterances(self) -> int: + return self._perf_summary.total_utterances + + @property + def shard_keys(self) -> list[str]: + return self._perf_summary.shard_keys + + @property + def items_processed(self) -> int: + return self._writer_items_processed + + def reset_wall_timer(self) -> None: + self._perf_summary.reset_wall_timer() + + def record_invocation(self, item_count: int) -> None: + self._writer_invocation_count += 1 + self._writer_items_processed += item_count + + def add_manifest_write_time(self, elapsed_s: float) -> None: + self._writer_manifest_write_time_s += elapsed_s + + def add_done_write_time(self, elapsed_s: float) -> None: + self._writer_done_write_time_s += elapsed_s + + def add_perf_write_time(self, elapsed_s: float) -> None: + self._writer_perf_write_time_s += elapsed_s + + def record_task(self, task: AudioTask, shard_key: str | None = None) -> None: + self._perf_summary.record_task(task, shard_key=shard_key, include_stage_perf=self.write_perf_stats) + + def shard_count(self, shard_key: str) -> int: + return self._perf_summary.shard_count(shard_key) + + def build_writer_summary(self) -> dict[str, Any]: + writer_total_time = ( + self._writer_manifest_write_time_s + self._writer_done_write_time_s + self._writer_perf_write_time_s + ) + return { + "total_process_time_s": writer_total_time, + "total_items_processed": float(self._writer_items_processed), + "invocation_count": float(self._writer_invocation_count), + "throughput_items_per_s": ( + float(self._writer_items_processed) / writer_total_time if writer_total_time > 0 else 0.0 + ), + "custom_metrics_sum": { + "manifest_write_time_s": self._writer_manifest_write_time_s, + "done_marker_write_time_s": self._writer_done_write_time_s, + "perf_write_time_s": self._writer_perf_write_time_s, + "writer_process_calls": float(self._writer_invocation_count), + "writer_invocation_count": float(self._writer_invocation_count), + "writer_items_processed": float(self._writer_items_processed), + }, + } + + def build_perf_summary(self) -> dict[str, Any]: + return self._perf_summary.build_summary(extra_stage_summaries={self.stage_name: self.build_writer_summary()}) + + def build_external_stage_summary(self, perf_stats: StagePerfStats) -> dict[str, Any] | None: + """Render one externally collected perf record in the normal stage-summary shape.""" + perf_summary = AudioPerformanceSummary(duration_key=self.duration_key) + perf_summary.record_stage_perf([perf_stats]) + return perf_summary.build_stage_summaries().get(perf_stats.stage_name) diff --git a/nemo_curator/stages/audio/io/nemo_tarred_reader.py b/nemo_curator/stages/audio/io/nemo_tarred_reader.py new file mode 100644 index 0000000000..e0ead8d75c --- /dev/null +++ b/nemo_curator/stages/audio/io/nemo_tarred_reader.py @@ -0,0 +1,669 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reader for NeMo-style tarred audio datasets (e.g. Granary YAML configs). + +Decomposes into a shard-discovery stage (parses the YAML) and a shard-reader +stage (streams each local tar, decodes audio in memory via lhotse/soundfile, +emits one ``AudioTask`` per utterance; nothing is written to disk). +""" + +from __future__ import annotations + +import json +import os +import re +import tarfile +import time +from dataclasses import dataclass +from io import BytesIO +from typing import Any, BinaryIO + +import soundfile as sf +import yaml +from loguru import logger + +from nemo_curator.backends.utils import RayStageSpecKeys +from nemo_curator.stages.base import CompositeStage, ProcessingStage +from nemo_curator.tasks import AudioTask, EmptyTask, FileGroupTask + + +def _expand_nemo_path(pattern: str) -> list[str]: + """Expand NeMo brace patterns like ``__OP_0..N_CL_``.""" + match = re.search(r"_OP_(\d+)\.\.(\d+)_CL_", pattern) + if not match: + return [pattern] + start, end = int(match.group(1)), int(match.group(2)) + if end < start: + msg = f"NeMo brace range must be ascending, got {start}..{end} in {pattern!r}" + raise ValueError(msg) + prefix = pattern[: match.start()] + suffix = pattern[match.end() :] + return [f"{prefix}{i}{suffix}" for i in range(start, end + 1)] + + +def _open_tar(tar_path: str) -> tarfile.TarFile: + """Open a local tar via lhotse's ``open_best`` in streaming mode (``r|*``).""" + from lhotse.serialization import open_best + + if not os.path.exists(tar_path): + msg = f"Tar file not found: {tar_path}" + raise FileNotFoundError(msg) + fileobj = open_best(tar_path, mode="rb") + return tarfile.open(fileobj=fileobj, mode="r|*") + + +def _open_text_stream(path: str) -> BinaryIO: + """Open a local text file as a binary stream.""" + from lhotse.serialization import open_best + + if not os.path.exists(path): + msg = f"Text file not found: {path}" + raise FileNotFoundError(msg) + return open_best(path, mode="rb") + + +def _normalize_audio_path(path: str) -> str: + """Strip leading ``./`` so manifest and tar-member paths compare consistently.""" + return path.lstrip("./") + + +def _path_suffix_overlap(a: list[str], b: list[str]) -> int: + """Count shared trailing path components between two split paths.""" + n = 0 + for x, y in zip(reversed(a), reversed(b), strict=False): + if x != y: + break + n += 1 + return n + + +class _ManifestIndex: + """Resolve a tar member name to its manifest entry, collision-safe. + + Entries are keyed by normalized full path. Lookup prefers an exact + full-path match, then falls back to basename. When several entries share a + basename, the candidate with the longest trailing path-component overlap + wins; genuine ties resolve to no match rather than an arbitrary one. + """ + + def __init__(self) -> None: + self._by_path: dict[str, dict] = {} + self._dup_paths: set[str] = set() + self._basename_to_paths: dict[str, list[str]] = {} + + def add(self, audio_path: str, entry: dict) -> None: + norm = _normalize_audio_path(audio_path) + if norm in self._by_path and self._by_path[norm] != entry: + self._dup_paths.add(norm) # same path, different content -> ambiguous + else: + self._by_path[norm] = entry + + def finalize(self) -> None: + for path in self._dup_paths: + self._by_path.pop(path, None) + self._basename_to_paths = {} + for norm in self._by_path: + self._basename_to_paths.setdefault(os.path.basename(norm), []).append(norm) + + def match(self, member_name: str) -> dict | None: + norm = _normalize_audio_path(member_name) + entry = self._by_path.get(norm) + if entry is not None: + return entry + candidates = self._basename_to_paths.get(os.path.basename(norm)) + if not candidates: + return None + if len(candidates) == 1: + return self._by_path.get(candidates[0]) + member_parts = norm.split("/") + best, best_overlap, tied = None, 0, False + for cand in candidates: + overlap = _path_suffix_overlap(member_parts, cand.split("/")) + if overlap > best_overlap: + best, best_overlap, tied = cand, overlap, False + elif overlap == best_overlap: + tied = True + if best is None or tied: + return None + return self._by_path.get(best) + + def __getitem__(self, member_name: str) -> dict: + entry = self.match(member_name) + if entry is None: + raise KeyError(member_name) + return entry + + def __contains__(self, member_name: str) -> bool: + return self.match(member_name) is not None + + +def _iter_discovery_groups(config: object, yaml_path: str) -> list[dict[str, Any]]: + """Validate the Granary discovery YAML root and return corpus-group dicts. + + Require a list of mappings at the top level (safe_load can return None / + scalar / string, which would crash or silently mis-parse on iteration); + skip non-mapping entries with a warning. + """ + if config is None: + msg = f"Granary YAML at {yaml_path} is empty (safe_load returned None)" + raise ValueError(msg) + if not isinstance(config, list): + msg = f"Granary YAML at {yaml_path} must be a list of corpus-group mappings, got {type(config).__name__}" + raise TypeError(msg) + + groups: list[dict[str, Any]] = [] + for idx, group in enumerate(config): + if not isinstance(group, dict): + logger.warning( + "Skipping non-mapping entry at index {} in {} (got {})", + idx, + yaml_path, + type(group).__name__, + ) + continue + groups.append(group) + return groups + + +def _iter_input_cfg_entries(group: dict[str, Any], yaml_path: str) -> list[dict[str, Any]]: + """Return validated ``input_cfg`` corpus entries from one top-level group.""" + raw = group.get("input_cfg", []) + if raw is None: + return [] + if not isinstance(raw, list): + logger.warning( + "Skipping corpus group in {} with non-list input_cfg (got {})", + yaml_path, + type(raw).__name__, + ) + return [] + + entries: list[dict[str, Any]] = [] + for idx, cfg in enumerate(raw): + if not isinstance(cfg, dict): + logger.warning( + "Skipping non-mapping input_cfg entry at index {} in {}", + idx, + yaml_path, + ) + continue + entries.append(cfg) + return entries + + +@dataclass +class NemoTarShardDiscoveryStage(ProcessingStage[EmptyTask, FileGroupTask]): + """Parse a Granary YAML config and emit one ``FileGroupTask`` per shard. + + Each emitted task has ``data = [manifest_path, tar_path]``. + + Args: + yaml_path: Path to the Granary YAML data config. + corpus_filter: Include only these corpora (``None`` = all). + """ + + yaml_path: str + name: str = "nemo_tar_shard_discovery" + corpus_filter: list[str] | None = None + output_dir: str | None = None + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], [] + + def xenna_stage_spec(self) -> dict[str, Any]: + return {"num_workers_per_node": 1} + + def ray_stage_spec(self) -> dict[str, Any]: + return {RayStageSpecKeys.IS_FANOUT_STAGE: True} + + def _scan_completed_shards(self) -> set[str]: + """Return shard keys with a ``.done`` marker under output_dir (resume skip set). + + Keys are relative paths like ``yodas/0_from_captions/en/sharded_manifests/manifest_42``. + """ + if not self.output_dir: + return set() + completed: set[str] = set() + if not os.path.isdir(self.output_dir): + return completed + for root, _dirs, files in os.walk(self.output_dir): + for fname in files: + if fname.endswith(".jsonl.done"): + rel = os.path.relpath(os.path.join(root, fname), self.output_dir) + shard_key = rel[: -len(".jsonl.done")] + completed.add(shard_key) + return completed + + @staticmethod + def _manifest_to_rel_path(manifest_path: str, corpus: str) -> str: + """Extract the shard-key path from a manifest path, starting at the corpus name. + + Example: ``/data/yodas/.../manifest_42.jsonl`` + corpus ``yodas`` -> + ``yodas/.../manifest_42`` (the ``.jsonl`` extension is stripped). + The corpus name must appear exactly once for unambiguous extraction. + """ + parts = manifest_path.replace("\\", "/").split("/") + parts_lower = [p.lower() for p in parts] + corpus_lower = corpus.lower() + matches = [i for i, p in enumerate(parts_lower) if p == corpus_lower] + if len(matches) == 0: + msg = ( + f"Corpus name '{corpus}' not found in manifest path: {manifest_path}. " + f"The YAML 'corpus' field must match a directory component in the manifest path (case-insensitive)." + ) + raise ValueError(msg) + if len(matches) > 1: + msg = ( + f"Corpus name '{corpus}' appears {len(matches)} times in manifest path: {manifest_path}. " + f"It must appear exactly once for unambiguous path extraction." + ) + raise ValueError(msg) + idx = matches[0] + rel = "/".join(parts[idx:]) + if rel.endswith(".jsonl"): + rel = rel[: -len(".jsonl")] + elif rel.endswith(".json"): + rel = rel[: -len(".json")] + return rel + + def process(self, _task: EmptyTask) -> list[FileGroupTask]: # noqa: C901, PLR0915 + t0 = time.perf_counter() + completed = self._scan_completed_shards() + if completed: + logger.info(f"Checkpoint: {len(completed)} shards already completed, will skip them") + logger.info(f"Completed shard keys (first 10): {sorted(completed)[:10]}") + + with open(self.yaml_path) as f: + config = yaml.safe_load(f) + + tasks: list[FileGroupTask] = [] + corpora_seen = 0 + shards_seen = 0 + skipped = 0 + invalid_corpora = 0 + invalid_shards = 0 + for group in _iter_discovery_groups(config, self.yaml_path): + for cfg in _iter_input_cfg_entries(group, self.yaml_path): + corpus = cfg.get("corpus", "unknown") + corpora_seen += 1 + if self.corpus_filter and corpus not in self.corpus_filter: + continue + if cfg.get("type", "nemo_tarred") != "nemo_tarred": + logger.warning(f"Skipping non-nemo_tarred corpus {corpus} (type={cfg.get('type')})") + continue + manifest_pattern = cfg.get("manifest_filepath") + tar_pattern = cfg.get("tarred_audio_filepaths") + if not isinstance(manifest_pattern, str) or not isinstance(tar_pattern, str): + invalid_corpora += 1 + logger.warning( + "Skipping corpus {} in {}: manifest_filepath and tarred_audio_filepaths are required strings", + corpus, + self.yaml_path, + ) + continue + manifest_paths = _expand_nemo_path(manifest_pattern) + tar_paths = _expand_nemo_path(tar_pattern) + if len(manifest_paths) != len(tar_paths): + msg = ( + f"Manifest/tar count mismatch for corpus={corpus}: " + f"{len(manifest_paths)} manifests vs {len(tar_paths)} tars" + ) + raise ValueError(msg) + for mp, tp in zip(manifest_paths, tar_paths, strict=False): + shards_seen += 1 + try: + shard_key = self._manifest_to_rel_path(mp, corpus) + except ValueError as exc: + invalid_shards += 1 + logger.warning("Skipping manifest {} for corpus {}: {}", mp, corpus, exc) + continue + if shard_key in completed: + skipped += 1 + continue + if self.output_dir: + partial = os.path.join(self.output_dir, f"{shard_key}.jsonl") + if os.path.exists(partial): + os.remove(partial) + logger.info(f"Removed partial output for {shard_key}") + shard_task = FileGroupTask( + dataset_name=corpus, + data=[mp, tp], + reader_config={"corpus": corpus, "shard_key": shard_key}, + ) + shard_task.task_id = shard_key + tasks.append(shard_task) + + logger.info( + f"NemoTarShardDiscoveryStage: found {len(tasks)} shards, skipped {skipped} completed (corpus_filter={self.corpus_filter})" + ) + # ``total_items_emitted``: framework ``num_items_processed`` counts 0 for + # stages that synthesise work from config, so the summary builder falls + # back to this when the framework count is 0. + self._log_metrics( + { + "input_tasks": 1.0, + "output_tasks": float(len(tasks)), + "total_items_emitted": float(len(tasks)), + "corpora_seen": float(corpora_seen), + "shards_seen": float(shards_seen), + "shards_emitted": float(len(tasks)), + "shards_skipped_completed": float(skipped), + "corpora_skipped_invalid": float(invalid_corpora), + "shards_skipped_invalid": float(invalid_shards), + "discovery_time_s": time.perf_counter() - t0, + } + ) + return tasks + + +@dataclass +class NemoTarShardReaderStage(ProcessingStage[FileGroupTask, AudioTask]): + """Read a single NeMo tar shard and emit one ``AudioTask`` per utterance. + + Expects ``task.data = [manifest_path, tar_path]`` from + ``NemoTarShardDiscoveryStage``. Audio is decoded in memory (lhotse + ``open_best`` for tar streaming, ``soundfile`` for decode); nothing is + written to disk. Each ``AudioTask`` carries a 1-D mono float32 waveform in + ``task.data["waveform"]`` and the native rate in ``sample_rate`` (plus a + ``sampling_rate`` alias for Granary v2 reference-script compatibility). + + Args: + filepath_key: Manifest key for the audio filename inside the tar. + duration_key: Manifest key for utterance duration (seconds). + max_duration_s: Optional upper bound on emitted utterance duration. + """ + + name: str = "nemo_tar_shard_reader" + filepath_key: str = "audio_filepath" + duration_key: str = "duration" + max_duration_s: float | None = None + max_utterances_per_shard: int | None = None + # Used ONLY to write a ``.done`` marker for zero-utterance shards, which + # never reach the writer and would otherwise be re-queued forever on resume. + # ``None`` -> no checkpoint dir (single-rank tutorial); resume not a concern. + output_dir: str | None = None + + def __post_init__(self) -> None: + if self.max_duration_s is not None and self.max_duration_s <= 0: + msg = "max_duration_s must be positive when set" + raise ValueError(msg) + if self.max_utterances_per_shard is not None and self.max_utterances_per_shard <= 0: + msg = "max_utterances_per_shard must be positive when set" + raise ValueError(msg) + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], ["waveform", "sample_rate", "sampling_rate", "corpus", "num_channels"] + + def num_workers(self) -> int | None: + # Pin to 1 to bound memory (a shard's waveforms are held in memory). + return 1 + + def ray_stage_spec(self) -> dict[str, Any]: + # IS_ACTOR_STAGE makes Ray Data honor num_workers()=1 (cluster-wide + # analog of the per-node Xenna pin); IS_FANOUT_STAGE keeps 1-row/block. + return {RayStageSpecKeys.IS_FANOUT_STAGE: True, RayStageSpecKeys.IS_ACTOR_STAGE: True} + + def xenna_stage_spec(self) -> dict[str, Any]: + # One reader actor per node: node-local decode, ~1 shard of memory/node. + return {"num_workers_per_node": 1} + + def _read_manifest(self, path: str) -> tuple[_ManifestIndex, int]: + index = _ManifestIndex() + entry_count = 0 + skipped_lines = 0 + with _open_text_stream(path) as f: + for raw_line in f: + line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line + stripped = line.strip() + if not stripped: + continue + entry_count += 1 + try: + entry = json.loads(stripped) + except json.JSONDecodeError: + skipped_lines += 1 + logger.warning("Skipping invalid JSON line in manifest {}", path) + continue + audio_path = entry.get(self.filepath_key) + if audio_path is None: + skipped_lines += 1 + logger.warning( + "Skipping manifest line missing {!r} in {}", + self.filepath_key, + path, + ) + continue + index.add(str(audio_path), entry) + if skipped_lines: + logger.warning( + "Manifest {}: skipped {} line(s) (invalid JSON or missing {})", + path, + skipped_lines, + self.filepath_key, + ) + index.finalize() + return index, entry_count + + def _mark_empty_shard_done(self, shard_key: str) -> None: + """Write a ``.jsonl.done`` (count 0) for a zero-utterance shard. + + Mirrors the writer's marker so discovery skips the shard on resume; no + ``.jsonl`` is written. No-op when ``output_dir`` is unset. + """ + if not self.output_dir: + return + done_path = os.path.join(self.output_dir, f"{shard_key}.jsonl.done") + try: + os.makedirs(os.path.dirname(done_path), exist_ok=True) + with open(done_path, "w") as f: + f.write("0\n") + logger.info(f"Shard {shard_key}: 0 utterances, wrote empty-shard marker {done_path}") + except OSError as exc: + logger.warning("Failed to write empty-shard marker for {}: {}", shard_key, exc) + + def process(self, task: FileGroupTask) -> list[AudioTask]: # noqa: C901, PLR0912, PLR0915 + t0 = time.perf_counter() + manifest_path, tar_path = task.data[0], task.data[1] + corpus = task.reader_config.get("corpus", "unknown") + shard_key = task.reader_config.get("shard_key", task.task_id) + + manifest_t0 = time.perf_counter() + manifest, manifest_entry_count = self._read_manifest(manifest_path) + manifest_elapsed = time.perf_counter() - manifest_t0 + + logger.info(f"Reading shard {shard_key}: {tar_path} ({manifest_entry_count} manifest entries)") + + open_t0 = time.perf_counter() + tar = _open_tar(tar_path) + tar_open_elapsed = time.perf_counter() - open_t0 + results: list[AudioTask] = [] + tar_members_seen = 0 + audio_members_matched = 0 + corrupt_audio_count = 0 + duration_filtered_count = 0 + decoded_audio_seconds = 0.0 + decoded_waveform_bytes = 0.0 + decode_elapsed = 0.0 + + try: + for tar_info in tar: + tar_members_seen += 1 + if not tar_info.isfile(): + continue + manifest_entry = manifest.match(tar_info.name) + if manifest_entry is None: + continue + + entry = dict(manifest_entry) + # Cheap pre-decode skip when the manifest has a usable duration. + # Missing / non-numeric durations are re-checked post-decode below + # so they cannot bypass the cap. + if self.max_duration_s is not None and self.duration_key in entry: + try: + duration_s = float(entry[self.duration_key]) + except (TypeError, ValueError): + duration_s = None + if duration_s is not None and duration_s > self.max_duration_s: + duration_filtered_count += 1 + continue + + fobj = tar.extractfile(tar_info) + if fobj is None: + corrupt_audio_count += 1 + logger.warning(f"Skipping non-regular tar member {tar_info.name} in {tar_path}") + continue + raw_audio = fobj.read() + try: + decode_t0 = time.perf_counter() + audio, sample_rate = sf.read(BytesIO(raw_audio), dtype="float32") + decode_elapsed += time.perf_counter() - decode_t0 + # Only genuine decode/format failures count as "corrupt audio"; + # resource/dependency errors are excluded so they propagate + # instead of being mislabeled and skipped. + except (sf.LibsndfileError, ValueError, EOFError) as exc: + corrupt_audio_count += 1 + logger.warning(f"Skipping corrupt audio {tar_info.name} in {tar_path}: {exc}") + continue + + num_channels = audio.shape[1] if audio.ndim > 1 else 1 + if audio.ndim > 1: + audio = audio.mean(axis=1) + utt_seconds = float(audio.shape[0]) / float(sample_rate) if sample_rate else 0.0 + # Post-decode duration enforcement using the authoritative decoded + # length (covers rows the pre-decode skip could not bound). + if self.max_duration_s is not None and utt_seconds > self.max_duration_s: + duration_filtered_count += 1 + continue + + audio_members_matched += 1 + decoded_waveform_bytes += float(getattr(audio, "nbytes", 0)) + decoded_audio_seconds += utt_seconds + + entry["waveform"] = audio + entry["sample_rate"] = sample_rate + entry["sampling_rate"] = sample_rate + entry["num_channels"] = num_channels + entry["corpus"] = corpus + + audio_task = AudioTask( + dataset_name=corpus, + data=entry, + _metadata={**task._metadata, "_shard_key": shard_key}, + _stage_perf=list(task._stage_perf), + ) + audio_task.task_id = f"{shard_key}_{tar_info.name}" + results.append(audio_task) + if self.max_utterances_per_shard and len(results) >= self.max_utterances_per_shard: + logger.info( + "Shard {}: reached max_utterances_per_shard={} and stopped early", + shard_key, + self.max_utterances_per_shard, + ) + break + finally: + tar.close() + + shard_total = len(results) + for result_task in results: + result_task._metadata["_shard_total"] = shard_total + + # Empty / fully-filtered shard never reaches the writer, so mark it done + # here so resume skips it instead of re-queueing it indefinitely. + if shard_total == 0: + self._mark_empty_shard_done(shard_key) + + logger.info(f"Shard {shard_key}: emitted {shard_total} AudioTasks") + self._log_metrics( + { + "input_shards": 1.0, + "output_tasks": float(shard_total), + "output_utterances": float(shard_total), + "manifest_entries": float(manifest_entry_count), + "tar_members_seen": float(tar_members_seen), + "audio_members_decoded": float(audio_members_matched), + "corrupt_audio_count": float(corrupt_audio_count), + "duration_filtered_count": float(duration_filtered_count), + "utterances_emitted": float(shard_total), + "utterance_limit_hit": float( + bool(self.max_utterances_per_shard and shard_total >= self.max_utterances_per_shard) + ), + "audio_duration_s": decoded_audio_seconds, + "waveform_bytes": decoded_waveform_bytes, + "manifest_read_time_s": manifest_elapsed, + "tar_open_time_s": tar_open_elapsed, + "audio_decode_time_s": decode_elapsed, + "reader_total_time_s": time.perf_counter() - t0, + } + ) + return results + + +@dataclass +class NemoTarredAudioReader(CompositeStage[EmptyTask, AudioTask]): + """Read NeMo-style tarred audio datasets from a Granary YAML config. + + Decomposes into ``NemoTarShardDiscoveryStage`` (parse YAML -> one + ``FileGroupTask`` per shard) then ``NemoTarShardReaderStage`` (stream each + tar, decode in memory -> ``AudioTask`` with waveform arrays). + + Args: + yaml_path: Path to the Granary YAML data config. + corpus_filter: Process only these corpora (``None`` = all). + filepath_key: Manifest key for audio filenames inside tar archives. + duration_key: Manifest key for utterance duration (seconds). + max_duration_s: Optional upper bound on emitted utterance duration. + """ + + yaml_path: str + name: str = "nemo_tarred_audio_reader" + corpus_filter: list[str] | None = None + filepath_key: str = "audio_filepath" + duration_key: str = "duration" + max_duration_s: float | None = None + output_dir: str | None = None + max_utterances_per_shard: int | None = None + + def __post_init__(self) -> None: + super().__init__() + + self._stages: list[ProcessingStage] = [ + NemoTarShardDiscoveryStage( + yaml_path=self.yaml_path, + corpus_filter=self.corpus_filter, + output_dir=self.output_dir, + ), + NemoTarShardReaderStage( + filepath_key=self.filepath_key, + duration_key=self.duration_key, + max_duration_s=self.max_duration_s, + max_utterances_per_shard=self.max_utterances_per_shard, + output_dir=self.output_dir, + ), + ] + + def inputs(self) -> tuple[list[str], list[str]]: + return self._stages[0].inputs() + + def outputs(self) -> tuple[list[str], list[str]]: + return self._stages[-1].outputs() + + def decompose(self) -> list[ProcessingStage]: + return self._stages diff --git a/nemo_curator/stages/audio/io/sharded_manifest_writer.py b/nemo_curator/stages/audio/io/sharded_manifest_writer.py new file mode 100644 index 0000000000..64c5313ba7 --- /dev/null +++ b/nemo_curator/stages/audio/io/sharded_manifest_writer.py @@ -0,0 +1,292 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sharded Manifest Writer -- writes per-shard JSONL files mirroring input paths with .done markers.""" + +import json +import os +import time +from dataclasses import dataclass, field +from typing import Any + +from loguru import logger + +from nemo_curator.backends.base import NodeInfo, WorkerMetadata +from nemo_curator.backends.utils import RayStageSpecKeys +from nemo_curator.stages.audio.io.manifest_writer_utils import ( + AudioManifestWriterMetrics, + manifest_lines, +) +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import AudioTask, FileGroupTask + + +@dataclass +class ShardedManifestWriterStage(ProcessingStage[AudioTask, FileGroupTask]): + """Write AudioTasks to per-shard JSONL files mirroring the input manifest paths. + + Output mirrors the input manifest paths, e.g. + ``output_dir/yodas/.../manifest_42.jsonl`` plus a ``.jsonl.done`` marker, with + an aggregate ``perf_summary.json`` at the root. The shard key comes from + ``task._metadata["_shard_key"]`` (set by ``NemoTarShardReaderStage`` as a + relative path). + + Args: + output_dir: Root directory for output manifests. + final_manifest_path: Optional aggregate JSONL rebuilt from completed + shard outputs at teardown; sharded files stay primary. + write_perf_stats: If True, record per-task stage perf into the aggregate + and refresh ``perf_summary.json`` on each shard completion. + """ + + output_dir: str + name: str = "sharded_manifest_writer" + final_manifest_path: str | None = None + write_perf_stats: bool = True + duration_key: str = "duration" + drop_manifest_keys: tuple[str, ...] = ("waveform",) + _writer_metrics: AudioManifestWriterMetrics = field(init=False, repr=False) + _final_shards_materialized: set[str] = field(default_factory=set, init=False, repr=False) + + def __post_init__(self) -> None: + self._writer_metrics = AudioManifestWriterMetrics( + stage_name=self.name, + duration_key=self.duration_key, + write_perf_stats=self.write_perf_stats, + ) + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def _has_completed_shards(self) -> bool: + if not os.path.isdir(self.output_dir): + return False + for _root, _dirs, files in os.walk(self.output_dir): + if any(name.endswith(".jsonl.done") for name in files): + return True + return False + + def setup_on_node( + self, + _node_info: NodeInfo | None = None, + _worker_metadata: WorkerMetadata | None = None, + ) -> None: + os.makedirs(self.output_dir, exist_ok=True) + self._final_shards_materialized = set() + if self.final_manifest_path: + final_parent = os.path.dirname(self.final_manifest_path) + if final_parent: + os.makedirs(final_parent, exist_ok=True) + if os.path.exists(self.final_manifest_path): + if self._has_completed_shards(): + self._final_shards_materialized.update(self._completed_shard_keys()) + logger.info( + "Preserving final manifest until teardown rebuild: {}", + self.final_manifest_path, + ) + else: + os.remove(self.final_manifest_path) + self._writer_metrics.reset_wall_timer() + logger.info(f"ShardedManifestWriterStage: output_dir={self.output_dir}") + + @staticmethod + def _shard_key_of(task: AudioTask) -> str: + return task._metadata.get("_shard_key", "unknown/shard_0") + + def _write_shard_group(self, shard_key: str, group: list[AudioTask]) -> str: + """Persist all utterances of one shard with one open/close per file. + + Rows are serialized in memory and written with a single ``writelines`` + (one open per shard manifest, not one per utterance). + """ + out_path = os.path.join(self.output_dir, f"{shard_key}.jsonl") + os.makedirs(os.path.dirname(out_path), exist_ok=True) + + lines = manifest_lines(group, self.drop_manifest_keys) + write_t0 = time.perf_counter() + with open(out_path, "a", encoding="utf-8") as f: + f.writelines(lines) + self._writer_metrics.add_manifest_write_time(time.perf_counter() - write_t0) + + for task in group: + self._writer_metrics.record_task(task, shard_key=shard_key) + + # Completion: the reader stamps every utterance with the shard's total. + shard_total = group[-1]._metadata.get("_shard_total", 0) + if shard_total > 0 and self._writer_metrics.shard_count(shard_key) >= shard_total: + done_path = os.path.join(self.output_dir, f"{shard_key}.jsonl.done") + done_t0 = time.perf_counter() + with open(done_path, "w") as f: + f.write(f"{self._writer_metrics.shard_count(shard_key)}\n") + self._writer_metrics.add_done_write_time(time.perf_counter() - done_t0) + logger.info( + f"Shard {shard_key} complete: " + f"{self._writer_metrics.shard_count(shard_key)} utterances, wrote {done_path}" + ) + self._append_completed_shard_to_final(shard_key, out_path) + if self.write_perf_stats: + self._write_perf_summary() + return out_path + + def process(self, task: AudioTask) -> FileGroupTask: + return self.process_batch([task])[0] + + def process_batch(self, tasks: list[AudioTask]) -> list[FileGroupTask]: + if len(tasks) == 0: + return [] + for task in tasks: + if not self.validate_input(task): + msg = f"Task {task.task_id} missing required columns for {type(self).__name__}: {self.inputs()}" + raise ValueError(msg) + self._writer_metrics.record_invocation(len(tasks)) + + # Group by shard (dict preserves first-seen order) so each shard writes + # in a single open/append rather than once per utterance. + groups: dict[str, list[AudioTask]] = {} + for task in tasks: + groups.setdefault(self._shard_key_of(task), []).append(task) + out_path_by_shard = { + shard_key: self._write_shard_group(shard_key, group) for shard_key, group in groups.items() + } + + output_tasks = [] + for task in tasks: + output_task = FileGroupTask( + dataset_name=task.dataset_name, + data=[out_path_by_shard[self._shard_key_of(task)]], + _metadata=task._metadata, + _stage_perf=task._stage_perf, + ) + output_task.task_id = task.task_id + output_tasks.append(output_task) + return output_tasks + + def _write_perf_summary(self) -> None: + """Write aggregate perf_summary.json at the output root.""" + summary = self._writer_metrics.build_perf_summary() + summary_path = os.path.join(self.output_dir, "perf_summary.json") + write_t0 = time.perf_counter() + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, ensure_ascii=False) + self._writer_metrics.add_perf_write_time(time.perf_counter() - write_t0) + logger.info(f"Wrote perf_summary.json: {summary_path}") + + def _completed_shard_manifest_paths(self) -> list[str]: + """Return completed shard JSONL paths, excluding the aggregate final manifest.""" + if not os.path.isdir(self.output_dir): + return [] + final_abs = os.path.abspath(self.final_manifest_path) if self.final_manifest_path else "" + paths: list[str] = [] + for root, _dirs, files in os.walk(self.output_dir): + for fname in files: + if not fname.endswith(".jsonl.done"): + continue + manifest_path = os.path.join(root, fname[: -len(".done")]) + if not os.path.isfile(manifest_path): + continue + if final_abs and os.path.abspath(manifest_path) == final_abs: + continue + paths.append(manifest_path) + return sorted(paths) + + def _completed_shard_keys(self) -> set[str]: + """Return shard keys whose done markers are present.""" + output_abs = os.path.abspath(self.output_dir) + keys: set[str] = set() + for manifest_path in self._completed_shard_manifest_paths(): + rel_path = os.path.relpath(os.path.abspath(manifest_path), output_abs) + keys.add(rel_path.removesuffix(".jsonl")) + return keys + + def _append_completed_shard_to_final(self, shard_key: str, shard_path: str) -> None: + """Append one completed shard into the aggregate manifest for eager consumers.""" + if not self.final_manifest_path or shard_key in self._final_shards_materialized: + return + final_parent = os.path.dirname(self.final_manifest_path) + if final_parent: + os.makedirs(final_parent, exist_ok=True) + + write_t0 = time.perf_counter() + with ( + open(self.final_manifest_path, "a", encoding="utf-8") as out_f, + open( + shard_path, + encoding="utf-8", + ) as in_f, + ): + out_f.writelines(in_f) + self._writer_metrics.add_manifest_write_time(time.perf_counter() - write_t0) + self._final_shards_materialized.add(shard_key) + logger.info( + "Appended completed shard {} into final manifest {}", + shard_key, + self.final_manifest_path, + ) + + def _write_final_manifest_from_shards(self) -> None: + """Rebuild the aggregate final manifest from completed shard outputs.""" + if not self.final_manifest_path: + return + final_parent = os.path.dirname(self.final_manifest_path) + if final_parent: + os.makedirs(final_parent, exist_ok=True) + + shard_paths = self._completed_shard_manifest_paths() + tmp_path = f"{self.final_manifest_path}.tmp" + write_t0 = time.perf_counter() + with open(tmp_path, "w", encoding="utf-8") as out_f: + for shard_path in shard_paths: + with open(shard_path, encoding="utf-8") as in_f: + out_f.writelines(in_f) + os.replace(tmp_path, self.final_manifest_path) + self._writer_metrics.add_manifest_write_time(time.perf_counter() - write_t0) + self._final_shards_materialized = self._completed_shard_keys() + logger.info( + "Rebuilt final manifest {} from {} completed shard file(s)", + self.final_manifest_path, + len(shard_paths), + ) + + def teardown(self) -> None: + self._write_final_manifest_from_shards() + + total = self._writer_metrics.total_utterances + done = sum( + 1 + for k in self._writer_metrics.shard_keys + if os.path.exists(os.path.join(self.output_dir, f"{k}.jsonl.done")) + ) + logger.info( + f"ShardedManifestWriter: {total} utterances across " + f"{len(self._writer_metrics.shard_keys)} shards, {done} completed with .done" + ) + + if self.write_perf_stats and ( + self._writer_metrics.items_processed > 0 or self._writer_metrics.total_utterances > 0 + ): + self._write_perf_summary() + elif self.write_perf_stats: + logger.info("Skipping perf_summary.json write because no tasks were processed") + + def num_workers(self) -> int | None: + return 1 + + def ray_stage_spec(self) -> dict[str, Any]: + return {RayStageSpecKeys.IS_ACTOR_STAGE: True} + + def xenna_stage_spec(self) -> dict[str, Any]: + return {} diff --git a/nemo_curator/stages/audio/io/waveform_utils.py b/nemo_curator/stages/audio/io/waveform_utils.py new file mode 100644 index 0000000000..9c0ac15e88 --- /dev/null +++ b/nemo_curator/stages/audio/io/waveform_utils.py @@ -0,0 +1,91 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared waveform helpers for audio I/O stages.""" + +import hashlib +import os +from urllib.parse import urlparse + +import torch + +_WAVEFORM_2D_NDIM = 2 + + +def audio_item_id_from_path(audio_path: str) -> str: + parsed = urlparse(str(audio_path)) + basename = os.path.basename(parsed.path if parsed.scheme else str(audio_path)) + stem = os.path.splitext(basename)[0] or "audio" + path_hash = hashlib.sha256(str(audio_path).encode()).hexdigest()[:8] + return f"{stem}_{path_hash}" + + +def as_waveform_tensor(waveform: object) -> torch.Tensor: + if waveform is None: + msg = "waveform is required" + raise ValueError(msg) + if torch.is_tensor(waveform): + tensor = waveform.detach().to(dtype=torch.float32) + else: + tensor = torch.as_tensor(waveform, dtype=torch.float32) + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + if tensor.ndim != _WAVEFORM_2D_NDIM: + msg = f"waveform must be 1-D or 2-D, got shape {tuple(tensor.shape)}" + raise ValueError(msg) + return tensor.contiguous() + + +def convert_channels(waveform: torch.Tensor, target_nchannels: int) -> torch.Tensor: + if target_nchannels <= 0: + msg = f"target_nchannels must be > 0, got {target_nchannels}" + raise ValueError(msg) + if waveform.shape[0] == target_nchannels: + return waveform + if target_nchannels == 1: + return waveform.mean(dim=0, keepdim=True) + if waveform.shape[0] == 1: + return waveform.repeat(target_nchannels, 1) + msg = f"Cannot convert {waveform.shape[0]} channels to {target_nchannels} without a mixing policy" + raise ValueError(msg) + + +def resample_waveform(waveform: torch.Tensor, sample_rate: int, target_sample_rate: int) -> torch.Tensor: + if sample_rate <= 0: + msg = f"sample_rate must be > 0, got {sample_rate}" + raise ValueError(msg) + if target_sample_rate <= 0: + msg = f"target_sample_rate must be > 0, got {target_sample_rate}" + raise ValueError(msg) + if sample_rate == target_sample_rate: + return waveform + try: + from torchaudio.functional import resample + except ImportError as exc: + msg = "Resampling an in-memory waveform requires torchaudio" + raise RuntimeError(msg) from exc + return resample(waveform, orig_freq=sample_rate, new_freq=target_sample_rate) + + +def prepare_waveform( + waveform: object, + sample_rate: int, + *, + target_sample_rate: int, + target_nchannels: int, +) -> torch.Tensor: + tensor = as_waveform_tensor(waveform) + tensor = convert_channels(tensor, target_nchannels) + tensor = resample_waveform(tensor, int(sample_rate), int(target_sample_rate)) + return tensor.contiguous() diff --git a/nemo_curator/stages/audio/metrics/performance.py b/nemo_curator/stages/audio/metrics/performance.py new file mode 100644 index 0000000000..e38c0d1de0 --- /dev/null +++ b/nemo_curator/stages/audio/metrics/performance.py @@ -0,0 +1,972 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: C901, PLR0912, PLR0915 + +"""Reusable audio pipeline performance summary helpers. + +Audio stages emit counters/timings via ``_log_metrics()``; backends attach +them to ``Task._stage_perf`` as ``StagePerfStats``. Terminal stages feed those +into ``AudioPerformanceSummary`` to build the published ``perf_summary.json``. + +Key types: ``AudioStageMetrics`` (superset of every scalar custom metric), +``AudioStageSamples`` (per-invocation samples for percentiles), +``AudioStageCallerContext`` (GPU/actor fields not derivable from perf stats), +and ``AudioPerformanceSummary`` (the dedup'ing accumulator). All +post-processing lives in ``performance_utils.py``; this file only collects. +""" + +from __future__ import annotations + +import contextlib +import time +from collections import defaultdict +from dataclasses import dataclass, field, fields +from typing import TYPE_CHECKING, Any + +from nemo_curator.stages.audio.metrics.performance_utils import ( + add_ratio, + audio_hours_per_gpu_hour, + bytes_to_mb, + estimate_wallclock_s, + seconds_to_hours, + summarize_samples, +) +from nemo_curator.utils.gpu_sampler import norm_uuid + +if TYPE_CHECKING: + from nemo_curator.tasks import Task + from nemo_curator.utils.performance_utils import StagePerfStats + +# GPU-util metrics ride custom_metrics as ``::``, sampled per GPU +# and summarized as percentiles -- excluded from scalar totals so they are +# never summed into a meaningless aggregate. +_GPU_SAMPLE_KEYS = frozenset({"gpu_util_pct", "gpu_mem_used_pct"}) +_MAX_CUSTOM_METRIC_KEYS = frozenset( + {"expected_stage_gpu_count", "expected_stage_worker_count", "expected_worker_gpu_count"} +) + + +def _gpu_sample_base(key: str) -> str: + """Base metric name of a (possibly UUID-namespaced) GPU sample key.""" + return key.split("::", 1)[0] + + +@dataclass +class AudioStageMetrics: + """Superset of every scalar custom metric the audio pipeline emits. + + Stages populate only relevant fields via ``_log_metrics``; the accumulator + sums them and rebuilds an ``AudioStageMetrics`` per stage. Default 0.0 means + "not emitted"; ``to_dict()`` strips zeros so JSON only carries populated + keys. Adding a metric is one field here plus the producer's ``_log_metrics``. + """ + + # ----- universal counters ----- + input_tasks: float = 0.0 + output_tasks: float = 0.0 + # Actor-pattern fix for stages the framework's num_items_processed cannot + # count (e.g. discovery synthesises work from config, so input is seen as 0). + total_items_emitted: float = 0.0 + + # ----- audio volume scalars ----- + audio_duration_s: float = 0.0 + # Legacy aliases for older stages; new stages emit ``audio_duration_s``. + audio_duration: float = 0.0 + duration: float = 0.0 + input_duration: float = 0.0 + filtered_dur: float = 0.0 + waveform_bytes: float = 0.0 + # "Truthful bytes loaded" for stages that load data themselves (tar reader) + # where framework's input_data_size_mb is unavailable. Producer-opt-in. + bytes_loaded: float = 0.0 + + # ----- text/transcript output ----- + output_chars: float = 0.0 + output_tokens: float = 0.0 + turn1_output_tokens: float = 0.0 + turn2_output_tokens: float = 0.0 + + # ----- inference timing ----- + inference_time_s: float = 0.0 + inference_time: float = 0.0 # legacy alias + adapter_inference_calls: float = 0.0 + adapter_inference_items: float = 0.0 + + # ----- model-side internal timers / counters ----- + model_turn1_prep_time_s: float = 0.0 + model_turn1_generation_time_s: float = 0.0 + model_turn2_prep_time_s: float = 0.0 + model_turn2_generation_time_s: float = 0.0 + model_turn1_valid_inputs: float = 0.0 + model_turn2_valid_inputs: float = 0.0 + model_utterances_skipped_preprocess: float = 0.0 + + # ----- shard reader (NemoTarShardReaderStage) ----- + input_shards: float = 0.0 + output_utterances: float = 0.0 + utterances_emitted: float = 0.0 + manifest_entries: float = 0.0 + manifest_read_time_s: float = 0.0 + tar_members_seen: float = 0.0 + audio_members_decoded: float = 0.0 + tar_open_time_s: float = 0.0 + audio_decode_time_s: float = 0.0 + reader_total_time_s: float = 0.0 + corrupt_audio_count: float = 0.0 + duration_filtered_count: float = 0.0 + utterance_limit_hit: float = 0.0 + + # ----- shard discovery ----- + corpora_seen: float = 0.0 + shards_seen: float = 0.0 + shards_emitted: float = 0.0 + shards_skipped_completed: float = 0.0 + discovery_time_s: float = 0.0 + + # ----- inference / filter / tagging utterance accounting ----- + utterances_input: float = 0.0 + utterances_processed: float = 0.0 + utterances_skipped: float = 0.0 + utterances_selected: float = 0.0 + utterances_eligible: float = 0.0 + utterances_restored: float = 0.0 + utterances_kept_as_is: float = 0.0 + utterances_filtered: float = 0.0 + utterances_newly_flagged: float = 0.0 + utterances_recovered: float = 0.0 + + # ----- text-filter rejection reasons ----- + pnc_rejected: float = 0.0 + empty_after_regex: float = 0.0 + wrong_language: float = 0.0 + low_probability: float = 0.0 + + # ----- preserve-by-value / generic batch filter ----- + input_count: float = 0.0 + output_count: float = 0.0 + filtered_count: float = 0.0 + + # ----- manifest reader ----- + manifests_read: float = 0.0 + entries_read: float = 0.0 + + # ----- ALM data overlap / builder ----- + filter_time: float = 0.0 + input_windows: float = 0.0 + output_windows: float = 0.0 + segments_processed: float = 0.0 + windows_created: float = 0.0 + + # ----- audio split / merge / resample / NeMo ASR align ----- + splits_produced: float = 0.0 + splits_joined: float = 0.0 + words_aligned: float = 0.0 + segments_merged: float = 0.0 + skipped_conversion: float = 0.0 + entries_processed: float = 0.0 + files_transcribed: float = 0.0 + process_time: float = 0.0 # legacy custom timer some stages emit + + # ----- speaker diarization ----- + segments_detected: float = 0.0 + overlap_segments_detected: float = 0.0 + speakers_detected: float = 0.0 + + # ----- VAD ----- + vad_segments_detected: float = 0.0 + skipped_short: float = 0.0 + + # ----- sharded manifest writer (ShardedManifestWriterStage) ----- + writer_process_calls: float = 0.0 + writer_invocation_count: float = 0.0 + writer_items_processed: float = 0.0 + manifest_write_time_s: float = 0.0 + done_marker_write_time_s: float = 0.0 + perf_write_time_s: float = 0.0 + + # forward-compat: any emitted scalar this dataclass doesn't know + extras: dict[str, float] = field(default_factory=dict) + + @classmethod + def known_field_names(cls) -> set[str]: + return {f.name for f in fields(cls) if f.name != "extras"} + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> AudioStageMetrics: + known = cls.known_field_names() + kwargs: dict[str, float] = {} + extras: dict[str, float] = {} + for k, v in (d or {}).items(): + if not isinstance(v, (int, float, bool)): + continue + fv = float(v) + if k in known: + kwargs[k] = fv + else: + extras[k] = fv + return cls(extras=extras, **kwargs) + + def to_dict(self) -> dict[str, float]: + """Serialise only populated fields (zeros are omitted).""" + out: dict[str, float] = {} + for f in fields(self): + if f.name == "extras": + continue + v = getattr(self, f.name) + if v != 0.0: + out[f.name] = v + out.update({k: v for k, v in self.extras.items() if v != 0.0}) + return out + + +@dataclass +class AudioStageSamples: + """Per-invocation sample lists used for percentile derivation. + + Populated once per dedup'd invocation; only the accumulator writes these. + """ + + invocation_process_times_s: list[float] = field(default_factory=list) + actor_idle_times_s: list[float] = field(default_factory=list) + items_processed_per_invocation: list[float] = field(default_factory=list) + batch_sizes: list[float] = field(default_factory=list) + audio_duration_s_per_invocation: list[float] = field(default_factory=list) + + def add(self, perf: StagePerfStats) -> None: + """Record one dedup'd invocation's per-call samples. + + GPU util is sampled per device and accumulated separately, so it is + intentionally absent here -- these are actor/stage scalars only. + """ + self.invocation_process_times_s.append(float(perf.process_time)) + self.actor_idle_times_s.append(float(perf.actor_idle_time)) + self.items_processed_per_invocation.append(float(perf.num_items_processed)) + + custom = perf.custom_metrics or {} + # Batch size proxy; collapses to 1 for single-task-per-invocation stages. + batch_size = ( + custom.get("utterances_input") + or custom.get("input_count") + or custom.get("input_tasks") + or perf.num_items_processed + ) + with contextlib.suppress(TypeError, ValueError): + self.batch_sizes.append(float(batch_size)) + + audio_s = custom.get("audio_duration_s") or custom.get("audio_duration") or 0.0 + try: + audio_s_f = float(audio_s) + except (TypeError, ValueError): + audio_s_f = 0.0 + if audio_s_f > 0: + self.audio_duration_s_per_invocation.append(audio_s_f) + + def summarize(self, percentiles: tuple[int, ...] = (50, 95)) -> dict[str, float]: + """Render the percentile-derived view (only populated keys).""" + out: dict[str, float] = {} + out.update(summarize_samples(self.invocation_process_times_s, "invocation_process_time_s", percentiles)) + out.update(summarize_samples(self.actor_idle_times_s, "queue_wait_s", percentiles)) + out.update(summarize_samples(self.batch_sizes, "batch_size", percentiles)) + out.update(summarize_samples(self.audio_duration_s_per_invocation, "audio_duration_s", percentiles)) + return out + + +@dataclass +class AudioStageCallerContext: + """Optional caller-provided fields the accumulator cannot derive itself. + + A writer with NVML/DCGM/autoscaler snapshots passes these to populate the + GPU/actor fields; defaults cause those fields to be omitted. + """ + + actor_count_samples: list[float] = field(default_factory=list) + gpu_util_pct_samples: list[float] = field(default_factory=list) + gpu_hours: float = 0.0 + setup_time_s_total: float = 0.0 + wallclock_s: float | None = None # overrides estimate if provided + + +def serialize_stage_perf(stage_perf_list: list[StagePerfStats]) -> list[dict[str, Any]]: + """Serialise a task's stage performance chain to JSON-friendly dicts.""" + result: list[dict[str, Any]] = [] + for perf in stage_perf_list: + entry: dict[str, Any] = { + "invocation_id": getattr(perf, "invocation_id", ""), + "stage_name": perf.stage_name, + "process_time": perf.process_time, + "actor_idle_time": perf.actor_idle_time, + "num_items_processed": perf.num_items_processed, + } + # Identity labels (best-effort; empty when unresolved). + for identity_field in ("actor_id", "node_id", "gpu_id", "physical_address", "pod_ip", "hostname"): + identity_value = getattr(perf, identity_field, "") + if identity_value: + entry[identity_field] = identity_value + gpu_indices = getattr(perf, "gpu_indices", None) or [] + gpu_uuids = getattr(perf, "gpu_uuids", None) or [] + if gpu_indices: + entry["gpu_indices"] = [int(idx) for idx in gpu_indices] + if gpu_uuids: + entry["gpu_uuids"] = list(gpu_uuids) + if perf.custom_metrics: + entry["custom_metrics"] = dict(perf.custom_metrics) + result.append(entry) + return result + + +def _task_audio_seconds(task: Task, duration_key: str) -> float: + data = getattr(task, "data", {}) + if not isinstance(data, dict): + return 0.0 + try: + seconds = float(data.get(duration_key, 0.0)) + except (TypeError, ValueError): + return 0.0 + return seconds if seconds > 0 else 0.0 + + +def _build_stage_summary( # noqa: PLR0913 + stage_totals: dict[str, float], + custom_totals: dict[str, float], + samples: AudioStageSamples | None = None, + caller_context: AudioStageCallerContext | None = None, + stage_identity: dict[str, Any] | None = None, + actor_breakdown: dict[str, dict[str, Any]] | None = None, +) -> dict[str, Any]: + """Render one stage's summary in the proposed pipeline-perf shape. + + Combines framework scalar totals, the dedup'd custom-metric superset, + per-invocation sample percentiles, and caller-provided GPU/actor context. + """ + entry: dict[str, Any] = { + "total_process_time_s": stage_totals.get("process_time", 0.0), + "total_actor_idle_time_s": stage_totals.get("actor_idle_time", 0.0), + "total_items_processed": stage_totals.get("num_items_processed", 0.0), + "invocation_count": stage_totals.get("invocation_count", 0.0), + } + + invocation_count = stage_totals.get("invocation_count", 0.0) + total_time = stage_totals.get("process_time", 0.0) + total_items = stage_totals.get("num_items_processed", 0.0) + + metrics = AudioStageMetrics.from_dict(custom_totals) + custom_sums = metrics.to_dict() + + # Actor-pattern stages lack framework num_items_processed; fall back to + # total_items_emitted to keep throughput ratios meaningful. + if total_items == 0.0 and metrics.total_items_emitted > 0: + total_items = metrics.total_items_emitted + entry["total_items_processed"] = total_items + if metrics.total_items_emitted > 0: + entry["total_items_emitted"] = metrics.total_items_emitted + + add_ratio(entry, "avg_invocation_time_s", total_time, invocation_count) + add_ratio(entry, "throughput_items_per_s", total_items, total_time) + + # caller context: wallclock + GPU + actor + ctx = caller_context or AudioStageCallerContext() + actor_count_p50 = None + if ctx.actor_count_samples: + actor_count_p50 = summarize_samples(ctx.actor_count_samples, "actor_count").get("actor_count_p50") + + wallclock_s = ( + ctx.wallclock_s + if ctx.wallclock_s is not None + else estimate_wallclock_s( + total_process_time_s=total_time, + actor_count=actor_count_p50, + ) + ) + if wallclock_s is not None and wallclock_s > 0: + entry["wallclock_s"] = wallclock_s + + if ctx.gpu_hours > 0: + entry["gpu_hours"] = ctx.gpu_hours + if ctx.setup_time_s_total > 0: + entry["setup_time_s_total"] = ctx.setup_time_s_total + entry.update(summarize_samples(ctx.actor_count_samples, "actor_count")) + entry.update(summarize_samples(ctx.gpu_util_pct_samples, "gpu_util_pct")) + + # Identity-driven topology + per-actor scheduling breakdown (keyed by + # actor_id for GPU and CPU stages). Hardware gpu_hours/device_name deferred + # to the NVML/DCGM proposal. + if stage_identity: + entry.update(stage_identity) + if actor_breakdown: + entry["per_actor"] = actor_breakdown + + expected_gpu_count = custom_sums.get("expected_stage_gpu_count", 0.0) + if expected_gpu_count > 0: + active_gpu_count = float(entry.get("gpu_count", 0.0) or 0.0) + missing_gpu_count = max(0.0, expected_gpu_count - active_gpu_count) + entry["expected_gpu_count"] = expected_gpu_count + entry["active_sampled_gpu_count"] = active_gpu_count + entry["missing_or_unattributed_gpu_count"] = missing_gpu_count + entry["active_sampled_gpu_fraction"] = active_gpu_count / expected_gpu_count + expected_worker_count = custom_sums.get("expected_stage_worker_count", 0.0) + if expected_worker_count > 0: + entry["expected_worker_count"] = expected_worker_count + expected_worker_gpu_count = custom_sums.get("expected_worker_gpu_count", 0.0) + if expected_worker_gpu_count > 0: + entry["expected_worker_gpu_count"] = expected_worker_gpu_count + + if not custom_sums and not samples: + return entry + + if custom_sums: + entry["custom_metrics_sum"] = custom_sums + + if samples is not None: + entry.update(samples.summarize()) + + # ----- audio-domain throughput composites ----- + audio_seconds = metrics.audio_duration_s or metrics.audio_duration or metrics.duration + inference_time = metrics.inference_time_s or metrics.inference_time + output_tokens = metrics.output_tokens + output_chars = metrics.output_chars + waveform_mb = bytes_to_mb(metrics.waveform_bytes) + bytes_loaded_mb = bytes_to_mb(metrics.bytes_loaded) + + # Both default to the audio duration the stage saw; filter stages may + # override audio_hours_out via custom_metrics. + if audio_seconds > 0: + entry["audio_hours_in"] = seconds_to_hours(audio_seconds) + entry["audio_hours_out"] = seconds_to_hours(audio_seconds) + + if wallclock_s and actor_count_p50: + gpu_seconds = wallclock_s * actor_count_p50 + ah_per_gpu_h = audio_hours_per_gpu_hour(audio_seconds, gpu_seconds) + if ah_per_gpu_h is not None: + entry["audio_hours_per_gpu_hour"] = ah_per_gpu_h + + # Two efficiency views: overall (audio per total process-time, incl. overhead) + # and inference-only. inference_compute_fraction is the model-vs-overhead share. + add_ratio(entry, "throughput_audio_s_per_process_s", audio_seconds, total_time) + add_ratio(entry, "throughput_audio_s_per_inference_s", audio_seconds, inference_time) + add_ratio(entry, "inference_compute_fraction", inference_time, total_time) + add_ratio(entry, "avg_audio_s_per_item", audio_seconds, total_items) + add_ratio(entry, "throughput_output_tokens_per_process_s", output_tokens, total_time) + add_ratio(entry, "throughput_output_tokens_per_inference_s", output_tokens, inference_time) + add_ratio(entry, "throughput_output_chars_per_process_s", output_chars, total_time) + add_ratio(entry, "throughput_output_chars_per_inference_s", output_chars, inference_time) + add_ratio(entry, "throughput_waveform_mb_per_process_s", waveform_mb, total_time) + add_ratio(entry, "throughput_bytes_loaded_mb_per_process_s", bytes_loaded_mb, total_time) + if metrics.adapter_inference_calls > 0: + entry["adapter_inference_call_count"] = metrics.adapter_inference_calls + entry["adapter_inference_items"] = metrics.adapter_inference_items + add_ratio( + entry, + "avg_adapter_inference_batch_size", + metrics.adapter_inference_items, + metrics.adapter_inference_calls, + ) + add_ratio( + entry, + "avg_audio_s_per_adapter_inference_call", + audio_seconds, + metrics.adapter_inference_calls, + ) + add_ratio( + entry, + "adapter_inference_calls_per_stage_invocation", + metrics.adapter_inference_calls, + invocation_count, + ) + + # ----- pipeline-structure ratios ----- + add_ratio(entry, "output_tasks_per_input_task", metrics.output_tasks, metrics.input_tasks) + utterances_emitted = metrics.utterances_emitted or metrics.output_utterances + add_ratio(entry, "utterances_emitted_per_input_shard", utterances_emitted, metrics.input_shards) + + # Generic item-fate aliases: populate from whichever stage-specific + # counter is non-zero. + items_skipped = metrics.utterances_skipped or metrics.model_utterances_skipped_preprocess or metrics.skipped_short + items_filtered = ( + metrics.utterances_filtered + or metrics.filtered_count + or metrics.duration_filtered_count + or metrics.shards_skipped_completed + ) + items_recovered = metrics.utterances_recovered + if items_skipped > 0: + entry["items_skipped"] = items_skipped + if items_filtered > 0: + entry["items_filtered"] = items_filtered + if items_recovered > 0: + entry["items_recovered"] = items_recovered + if output_tokens > 0: + entry["output_tokens"] = output_tokens + + # filter/tagging stages: per-input-utterance ratios + utterances_input = metrics.utterances_input or metrics.input_tasks + if utterances_input > 0: + for metric_name in ( + "utterances_selected", + "utterances_skipped", + "utterances_processed", + "utterances_eligible", + "utterances_restored", + "utterances_kept_as_is", + "utterances_filtered", + "utterances_newly_flagged", + "utterances_recovered", + "pnc_rejected", + "empty_after_regex", + "wrong_language", + "low_probability", + ): + value = getattr(metrics, metric_name, 0.0) + add_ratio(entry, f"{metric_name}_per_input_utterance", value, utterances_input) + + return entry + + +@dataclass +class AudioPerformanceSummary: + """Accumulate and summarise audio task performance metrics. + + Writer-independent: a terminal stage calls ``record_task`` per output task, + then writes ``build_summary()`` wherever its output contract requires. + """ + + duration_key: str = "duration" + _stage_totals: dict[str, dict[str, float]] = field( + default_factory=lambda: defaultdict(lambda: defaultdict(float)), + repr=False, + ) + _stage_custom_totals: dict[str, dict[str, float]] = field( + default_factory=lambda: defaultdict(lambda: defaultdict(float)), + repr=False, + ) + _stage_samples: dict[str, AudioStageSamples] = field( + default_factory=lambda: defaultdict(AudioStageSamples), + repr=False, + ) + _seen_perf_invocations: set[str] = field(default_factory=set, repr=False) + # Per-(stage, actor) scheduling breakdown for any record with a resolved + # actor_id (GPU and CPU stages). GPU actors also carry physical address + + # NVML util/mem percentiles. + _stage_actor_samples: dict[str, dict[str, AudioStageSamples]] = field( + default_factory=lambda: defaultdict(lambda: defaultdict(AudioStageSamples)), + repr=False, + ) + _stage_actor_items: dict[str, dict[str, float]] = field( + default_factory=lambda: defaultdict(lambda: defaultdict(float)), + repr=False, + ) + _stage_actor_audio_s: dict[str, dict[str, float]] = field( + default_factory=lambda: defaultdict(lambda: defaultdict(float)), + repr=False, + ) + _stage_actor_location: dict[str, dict[str, dict[str, Any]]] = field( + default_factory=lambda: defaultdict(dict), + repr=False, + ) + # Per-GPU NVML samples nested stage -> actor -> address (":"), + # rolled up under each actor's ``gpus`` block. ``_gpu_unit_meta`` holds + # per-address metadata (gpu_index, gpu_uuid). + _stage_actor_gpu_util: dict[str, dict[str, dict[str, list[float]]]] = field( + default_factory=lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))), + repr=False, + ) + _stage_actor_gpu_mem: dict[str, dict[str, dict[str, list[float]]]] = field( + default_factory=lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))), + repr=False, + ) + _gpu_unit_meta: dict[str, dict[str, Any]] = field(default_factory=dict, repr=False) + # _stage_gpus: per-actor addresses (":"); _stage_gpu_units: + # individual devices (":") so gpu_count is true under tensor-parallel. + _stage_gpus: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set), repr=False) + _stage_gpu_units: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set), repr=False) + _stage_actors: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set), repr=False) + _actor_node: dict[str, str] = field(default_factory=dict, repr=False) + _shard_counts: dict[str, int] = field(default_factory=lambda: defaultdict(int), repr=False) + _shard_audio_seconds: dict[str, float] = field(default_factory=lambda: defaultdict(float), repr=False) + _total_utterances: int = field(default=0, repr=False) + _total_audio_seconds: float = field(default=0.0, repr=False) + _wall_start_s: float = field(default_factory=time.perf_counter, repr=False) + + @property + def total_utterances(self) -> int: + return self._total_utterances + + @property + def shard_keys(self) -> list[str]: + return sorted(self._shard_counts) + + def shard_count(self, shard_key: str) -> int: + return self._shard_counts.get(shard_key, 0) + + def reset_wall_timer(self) -> None: + self._wall_start_s = time.perf_counter() + + # ----------------------------------------------------------------------- + # Recording + # ----------------------------------------------------------------------- + + def record_task(self, task: Task, shard_key: str | None = None, *, include_stage_perf: bool = True) -> None: + """Record one audio task and optionally its attached stage perf chain.""" + audio_seconds = _task_audio_seconds(task, self.duration_key) + self._total_utterances += 1 + self._total_audio_seconds += audio_seconds + + if shard_key is not None: + self._shard_counts[shard_key] += 1 + self._shard_audio_seconds[shard_key] += audio_seconds + + if include_stage_perf: + self.record_stage_perf(getattr(task, "_stage_perf", []) or []) + + @staticmethod + def _fingerprint_perf(perf: StagePerfStats) -> str: + """Deterministic fingerprint of a ``StagePerfStats`` value tuple. + + Fallback dedup key when ``invocation_id`` is unset: the same record is + seen once per emitted downstream task, so an N-task invocation would be + counted N times. Collisions (distinct invocations with byte-equal + timings and custom metrics) are not a practical concern. + """ + custom = sorted((perf.custom_metrics or {}).items()) + return repr( + ( + perf.stage_name, + getattr(perf, "actor_id", ""), + getattr(perf, "node_id", ""), + getattr(perf, "gpu_id", ""), + getattr(perf, "physical_address", ""), + round(perf.process_time, 9), + round(perf.actor_idle_time, 9), + perf.num_items_processed, + tuple((k, round(float(v), 9)) for k, v in custom), + ) + ) + + def record_stage_perf(self, stage_perf_list: list[StagePerfStats]) -> None: + """Accumulate ``StagePerfStats``, deduplicating repeat sightings. + + Dedup key is ``invocation_id`` when wired, else a synthetic value-tuple + fingerprint. After dedup, each record feeds stage scalar totals, + custom-metric sums, and per-invocation samples (for p50/p95). + """ + for perf in stage_perf_list: + if not all( + hasattr(perf, attr) + for attr in ("stage_name", "process_time", "actor_idle_time", "num_items_processed") + ): + # Legacy/custom stages sometimes attach dictionaries. Perf is + # optional observability, so malformed records must not break + # terminal manifest writing. + continue + invocation_id = getattr(perf, "invocation_id", "") or self._fingerprint_perf(perf) + if invocation_id in self._seen_perf_invocations: + continue + self._seen_perf_invocations.add(invocation_id) + + totals = self._stage_totals[perf.stage_name] + totals["process_time"] += perf.process_time + totals["actor_idle_time"] += perf.actor_idle_time + totals["num_items_processed"] += perf.num_items_processed + totals["invocation_count"] += 1 + + for key, value in (perf.custom_metrics or {}).items(): + if _gpu_sample_base(key) in _GPU_SAMPLE_KEYS: + continue + if isinstance(value, (int, float, bool)): + if key in _MAX_CUSTOM_METRIC_KEYS: + self._stage_custom_totals[perf.stage_name][key] = max( + self._stage_custom_totals[perf.stage_name].get(key, 0.0), + float(value), + ) + else: + self._stage_custom_totals[perf.stage_name][key] += float(value) + + self._stage_samples[perf.stage_name].add(perf) + self._record_actor_breakdown(perf) + + def _record_actor_breakdown(self, perf: StagePerfStats) -> None: + """Accumulate the per-(stage, actor) scheduling breakdown. + + Keyed by ``actor_id`` so every actor-backed stage (GPU or CPU) reports + per-actor metrics; GPU actors also contribute their physical address and + device units. No-op for records without a resolved ``actor_id``. + """ + stage_name = perf.stage_name + actor_id = (getattr(perf, "actor_id", "") or "").strip() + if not actor_id: + return + node_id = (getattr(perf, "node_id", "") or "").strip() + self._stage_actors[stage_name].add(actor_id) + if node_id: + self._actor_node.setdefault(actor_id, node_id) + self._stage_actor_samples[stage_name][actor_id].add(perf) + self._stage_actor_items[stage_name][actor_id] += float(perf.num_items_processed) + custom = perf.custom_metrics or {} + audio_s = custom.get("audio_duration_s") or custom.get("audio_duration") or 0.0 + with contextlib.suppress(TypeError, ValueError): + self._stage_actor_audio_s[stage_name][actor_id] += float(audio_s) + # GPU topology: physical address + device units (gpu_count true under TP). + physical_address = (getattr(perf, "physical_address", "") or "").strip() + host = physical_address.rsplit(":", 1)[0] if physical_address else (node_id or "node") + if physical_address: + self._stage_gpus[stage_name].add(physical_address) + for idx in getattr(perf, "gpu_indices", None) or (): + self._stage_gpu_units[stage_name].add(f"{host}:{idx}") + self._record_gpu_samples(stage_name, actor_id, host, perf) + location = self._actor_location_fields(perf) + if location: + self._stage_actor_location[stage_name][actor_id] = location + + def _record_gpu_samples(self, stage_name: str, actor_id: str, host: str, perf: StagePerfStats) -> None: + """Fold per-GPU NVML samples (``::``) onto a physical address. + + Maps each sample's normalized UUID back to the actor's physical GPU index + (via parallel ``gpu_indices``/``gpu_uuids``) so it lands on the canonical + ``:`` address; unmappable UUIDs fall back to ``:``. + """ + custom = perf.custom_metrics or {} + if not any(_gpu_sample_base(k) in _GPU_SAMPLE_KEYS for k in custom): + return + gpu_indices = list(getattr(perf, "gpu_indices", None) or []) + gpu_uuids = list(getattr(perf, "gpu_uuids", None) or []) + uuid_to_index = {norm_uuid(u): idx for u, idx in zip(gpu_uuids, gpu_indices, strict=False)} + uuid_to_raw = {norm_uuid(u): u for u in gpu_uuids} + for key, value in custom.items(): + base = _gpu_sample_base(key) + if base not in _GPU_SAMPLE_KEYS or "::" not in key: + continue + try: + sample = float(value) + except (TypeError, ValueError): + continue + uuid_key = key.split("::", 1)[1] + index = uuid_to_index.get(uuid_key) + address = f"{host}:{index}" if index is not None else f"{host}:{uuid_key}" + self._stage_gpu_units[stage_name].add(address) + target = self._stage_actor_gpu_util if base == "gpu_util_pct" else self._stage_actor_gpu_mem + target[stage_name][actor_id][address].append(sample) + meta = self._gpu_unit_meta.setdefault(address, {}) + if index is not None and "gpu_index" not in meta: + meta["gpu_index"] = int(index) + if uuid_key in uuid_to_raw and "gpu_uuid" not in meta: + meta["gpu_uuid"] = uuid_to_raw[uuid_key] + + @staticmethod + def _actor_location_fields(perf: StagePerfStats) -> dict[str, Any]: + """Additive per-actor metadata (GPU actors carry physical address). + + ``node_id`` is folded in by the builder, not here. + """ + block: dict[str, Any] = {} + physical_address = getattr(perf, "physical_address", "") or "" + pod_ip = getattr(perf, "pod_ip", "") or "" + hostname = getattr(perf, "hostname", "") or "" + gpu_indices = getattr(perf, "gpu_indices", None) or [] + gpu_uuids = getattr(perf, "gpu_uuids", None) or [] + if physical_address: + block["physical_address"] = physical_address + if pod_ip: + block["pod_ip"] = pod_ip + if hostname: + block["hostname"] = hostname + if gpu_indices: + block["gpu_indices"] = [int(idx) for idx in gpu_indices] + if gpu_uuids: + block["gpu_uuids"] = list(gpu_uuids) + return block + + # ----------------------------------------------------------------------- + # Building the published summary + # ----------------------------------------------------------------------- + + def _stage_identity_meta(self, stage_name: str) -> dict[str, Any]: + """Topology labels for a stage: gpu_addresses, gpu_count, actor_count. + + ``gpu_count`` counts distinct physical devices (a TP actor on 2 GPUs + counts as 2). Keys are omitted for stages without resolved identity. + """ + meta: dict[str, Any] = {} + addresses = sorted(self._stage_gpus.get(stage_name, set())) + if addresses: + meta["gpu_addresses"] = addresses + meta["gpu_count"] = float(len(self._stage_gpu_units.get(stage_name, addresses))) + actors = self._stage_actors.get(stage_name, set()) + if actors: + meta["actor_count"] = float(len(actors)) + return meta + + def _build_per_actor(self, stage_name: str) -> dict[str, dict[str, Any]]: + """Per-actor scheduling breakdown for a stage (GPU and CPU alike). + + Keyed by ``actor_id``; empty when no actor identity was resolved. Each + entry carries node_id, items_processed, audio_hours_in, and + batch_size/queue_wait percentiles. GPU actors also carry physical_address, + gpu_indices/gpu_uuids, and a nested ``gpus`` map of per-device NVML + percentiles (only when the worker ran a GPU sampler). + """ + actor_samples = self._stage_actor_samples.get(stage_name, {}) + if not actor_samples: + return {} + per_actor: dict[str, dict[str, Any]] = {} + for actor_id in sorted(actor_samples): + block: dict[str, Any] = {} + node_id = self._actor_node.get(actor_id) + if node_id: + block["node_id"] = node_id + items = self._stage_actor_items.get(stage_name, {}).get(actor_id, 0.0) + if items: + block["items_processed"] = items + audio_s = self._stage_actor_audio_s.get(stage_name, {}).get(actor_id, 0.0) + if audio_s > 0: + block["audio_hours_in"] = seconds_to_hours(audio_s) + summary = actor_samples[actor_id].summarize() + for key in ("batch_size_p50", "batch_size_p95", "queue_wait_s_p50", "queue_wait_s_p95"): + if key in summary: + block[key] = summary[key] + location = self._stage_actor_location.get(stage_name, {}).get(actor_id) + if location: + block.update(location) + gpus = self._build_actor_gpus(stage_name, actor_id) + if gpus: + block["gpus"] = gpus + per_actor[actor_id] = block + return per_actor + + def _build_actor_gpus(self, stage_name: str, actor_id: str) -> dict[str, dict[str, Any]]: + """Per-physical-GPU NVML breakdown for one actor, keyed by ``:``. + + Each device carries gpu_index/gpu_uuid metadata and util/mem percentiles + from its own samples. Empty when the actor ran no GPU sampler. + """ + util_by_addr = self._stage_actor_gpu_util.get(stage_name, {}).get(actor_id, {}) + mem_by_addr = self._stage_actor_gpu_mem.get(stage_name, {}).get(actor_id, {}) + addresses = sorted(set(util_by_addr) | set(mem_by_addr)) + gpus: dict[str, dict[str, Any]] = {} + for address in addresses: + block: dict[str, Any] = dict(self._gpu_unit_meta.get(address, {})) + block.update(summarize_samples(util_by_addr.get(address, []), "gpu_util_pct")) + block.update(summarize_samples(mem_by_addr.get(address, []), "gpu_mem_used_pct")) + gpus[address] = block + return gpus + + def build_stage_summaries( + self, + stage_caller_context: dict[str, AudioStageCallerContext] | None = None, + ) -> dict[str, dict[str, Any]]: + """Build per-stage aggregate summaries from accumulated metrics.""" + ctx_by_stage = stage_caller_context or {} + return { + stage_name: _build_stage_summary( + dict(totals), + dict(self._stage_custom_totals.get(stage_name, {})), + samples=self._stage_samples.get(stage_name), + caller_context=ctx_by_stage.get(stage_name), + stage_identity=self._stage_identity_meta(stage_name), + actor_breakdown=self._build_per_actor(stage_name), + ) + for stage_name, totals in self._stage_totals.items() + } + + def build_summary( + self, + *, + extra_stage_summaries: dict[str, dict[str, Any]] | None = None, + wall_time_s: float | None = None, + run_id: str | None = None, + executor: str | None = None, + stage_caller_context: dict[str, AudioStageCallerContext] | None = None, + ) -> dict[str, Any]: + """Build the full audio pipeline performance summary. + + Top-level fields match the proposed pipeline-perf shape (run_id, + executor, input_hours, output_hours, rows_in, rows_out, stages). + Backward-compat keys (total_utterances, total_audio_seconds, shards, + etc.) are preserved verbatim for the protocol-doc baseline tables. + """ + resolved_wall_time_s = ( + max(time.perf_counter() - self._wall_start_s, 0.0) if wall_time_s is None else max(wall_time_s, 0.0) + ) + stages_summary = self.build_stage_summaries(stage_caller_context) + if extra_stage_summaries: + stages_summary.update(extra_stage_summaries) + + # Derive top-level input_hours from the first stage that has audio volume. + # Derive rows_in by priority so discovery's synthetic input_tasks=1 does + # not mask reader-level row counts. + input_hours = 0.0 + rows_in_by_key = { + "manifest_entries": 0.0, + "output_utterances": 0.0, + "input_shards": 0.0, + "input_tasks": 0.0, + } + for stage_dict in stages_summary.values(): + if input_hours == 0.0 and "audio_hours_in" in stage_dict: + input_hours = stage_dict["audio_hours_in"] + cm = stage_dict.get("custom_metrics_sum", {}) + for key, value in rows_in_by_key.items(): + if value == 0.0: + rows_in_by_key[key] = float(cm.get(key, 0.0) or 0.0) + + rows_in = next((value for value in rows_in_by_key.values() if value > 0.0), 0.0) + + output_hours = seconds_to_hours(self._total_audio_seconds) + rows_out = float(self._total_utterances) + + summary: dict[str, Any] = { + # proposed-structure top-level + "run_id": run_id or "", + "executor": executor or "", + "input_hours": input_hours, + "output_hours": output_hours, + "rows_in": rows_in, + "rows_out": rows_out, + # backward-compat top-level (protocol-doc baselines) + "total_utterances": self._total_utterances, + "total_audio_seconds": self._total_audio_seconds, + "total_audio_hours": output_hours, + "writer_wall_time_s": resolved_wall_time_s, + "pipeline_audio_s_per_wall_s": ( + self._total_audio_seconds / resolved_wall_time_s if resolved_wall_time_s > 0 else 0.0 + ), + "pipeline_utterances_per_wall_s": ( + self._total_utterances / resolved_wall_time_s if resolved_wall_time_s > 0 else 0.0 + ), + "perf_invocations_counted": len(self._seen_perf_invocations), + "shards": { + shard: { + "utterances": count, + "audio_seconds": self._shard_audio_seconds.get(shard, 0.0), + "audio_hours": self._shard_audio_seconds.get(shard, 0.0) / 3600.0, + } + for shard, count in sorted(self._shard_counts.items()) + }, + "stages": stages_summary, + } + + # Cluster-level rollup (scheduling only). Hardware rollups are deferred + # to the NVML/DCGM proposal; only identity-derivable fields emitted here. + pipeline_throughput: dict[str, Any] = {} + if resolved_wall_time_s > 0 and self._total_audio_seconds > 0: + pipeline_throughput["audio_hours_per_wallclock_hour"] = seconds_to_hours( + self._total_audio_seconds + ) / seconds_to_hours(resolved_wall_time_s) + all_addresses = sorted({addr for addrs in self._stage_gpus.values() for addr in addrs}) + if all_addresses: + all_units = {unit for units in self._stage_gpu_units.values() for unit in units} + pipeline_throughput["gpu_addresses"] = all_addresses + pipeline_throughput["gpu_count"] = float(len(all_units or all_addresses)) + if pipeline_throughput: + summary["pipeline_throughput"] = pipeline_throughput + + return summary diff --git a/nemo_curator/stages/audio/metrics/performance_utils.py b/nemo_curator/stages/audio/metrics/performance_utils.py new file mode 100644 index 0000000000..ba5b4b6069 --- /dev/null +++ b/nemo_curator/stages/audio/metrics/performance_utils.py @@ -0,0 +1,184 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure post-processing helpers for the audio pipeline performance summary. + +Owns everything that runs on top of the raw counters/samples collected by the +``performance.py`` accumulator: percentile computation, unit conversions, safe +ratio helpers, and audio-domain composites (``audio_hours_per_gpu_hour`` etc.). +Pure functions so they can be unit-tested and reused (CI checks, dashboards) +without the full accumulator. + +NOTE: shadows ``nemo_curator.utils.performance_utils`` by filename only; the +import paths differ so there is no conflict. That module owns ``StagePerfStats``; +this one owns audio-specific post-processing. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Iterable + +_MAX_PERCENTILE = 100.0 + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +SECONDS_PER_HOUR = 3600.0 +BYTES_PER_MB = 1024.0 * 1024.0 +BYTES_PER_GB = 1024.0 * 1024.0 * 1024.0 +DEFAULT_PERCENTILES: tuple[int, ...] = (50, 95) + + +# --------------------------------------------------------------------------- +# Unit conversions +# --------------------------------------------------------------------------- + + +def seconds_to_hours(seconds: float) -> float: + """Convert seconds to hours.""" + return float(seconds) / SECONDS_PER_HOUR + + +def hours_to_seconds(hours: float) -> float: + return float(hours) * SECONDS_PER_HOUR + + +def bytes_to_mb(b: float) -> float: + return float(b) / BYTES_PER_MB + + +def bytes_to_gb(b: float) -> float: + return float(b) / BYTES_PER_GB + + +# --------------------------------------------------------------------------- +# Safe ratio +# --------------------------------------------------------------------------- + + +def safe_ratio(numerator: float, denominator: float) -> float | None: + """Return numerator/denominator, or None if either input is non-positive. + + None (not 0/NaN) lets the consumer omit the field rather than carry a + misleading zero downstream. + """ + if numerator is None or denominator is None: + return None + if numerator <= 0 or denominator <= 0: + return None + return float(numerator) / float(denominator) + + +def add_ratio(entry: dict[str, Any], name: str, numerator: float, denominator: float) -> None: + """Add ``entry[name] = numerator/denominator`` only when both > 0.""" + value = safe_ratio(numerator, denominator) + if value is not None: + entry[name] = value + + +# --------------------------------------------------------------------------- +# Percentiles +# --------------------------------------------------------------------------- + + +def _percentile_sorted(sorted_values: list[float], p: float) -> float: + """``p``-th percentile of an already-sorted, non-empty list (linear interp).""" + if p < 0 or p > _MAX_PERCENTILE: + msg = f"percentile p must be in [0, 100], got {p}" + raise ValueError(msg) + if len(sorted_values) == 1: + return sorted_values[0] + rank = (len(sorted_values) - 1) * (p / 100.0) + lo = int(rank) + hi = min(lo + 1, len(sorted_values) - 1) + if lo == hi: + return sorted_values[lo] + frac = rank - lo + return sorted_values[lo] + frac * (sorted_values[hi] - sorted_values[lo]) + + +def percentile(values: Iterable[float], p: float) -> float | None: + """Compute the p-th percentile of ``values`` with linear interpolation. + + Mirrors numpy's default but avoids importing numpy so it works in writer + pods without it. Returns None when empty; ``p`` must be in ``[0, 100]``. + """ + materialized = [float(v) for v in values] + if not materialized: + return None + materialized.sort() + return _percentile_sorted(materialized, p) + + +def summarize_samples( + values: Iterable[float], + name: str, + percentiles: Iterable[int] = DEFAULT_PERCENTILES, +) -> dict[str, float]: + """Return ``{f"{name}_p{P}": value}`` for each requested percentile. + + Empty samples -> empty dict. Sorts once, then indexes each percentile off it. + """ + out: dict[str, float] = {} + materialized = [float(v) for v in values] + if not materialized: + return out + materialized.sort() + for p in percentiles: + out[f"{name}_p{p}"] = _percentile_sorted(materialized, p) + return out + + +# --------------------------------------------------------------------------- +# Audio-domain composites +# --------------------------------------------------------------------------- + + +def audio_hours_per_gpu_hour(audio_seconds: float, gpu_seconds: float) -> float | None: + """Hours of audio processed per GPU-hour spent. + + ``gpu_seconds`` is ``gpu_count * wallclock_s`` (caller-computed). Returns + None when either input is non-positive. + """ + return safe_ratio(seconds_to_hours(audio_seconds), seconds_to_hours(gpu_seconds)) + + +def items_per_hour(items: float, wall_seconds: float) -> float | None: + """Generic throughput in items / wallclock-hour.""" + if wall_seconds <= 0: + return None + return safe_ratio(items, seconds_to_hours(wall_seconds)) + + +def estimate_wallclock_s( + total_process_time_s: float, + actor_count: float | None = None, +) -> float | None: + """Best-effort stage wallclock estimate. + + ``process_time`` sums CPU time across an actor's invocations; true per-stage + wall would need first/last timestamps the framework doesn't expose. So: + divide by ``actor_count`` when positive (spread across concurrent actors), + else use ``total_process_time_s``. No ``max(invocation_times)`` fallback -- + it reads optimistically under parallel actors, so we stay conservative. + """ + if actor_count and actor_count > 0: + return float(total_process_time_s) / float(actor_count) + if total_process_time_s > 0: + return float(total_process_time_s) + return None diff --git a/nemo_curator/stages/audio/metrics/squim.py b/nemo_curator/stages/audio/metrics/squim.py index d1a3ec49de..f0ff10ef58 100644 --- a/nemo_curator/stages/audio/metrics/squim.py +++ b/nemo_curator/stages/audio/metrics/squim.py @@ -203,8 +203,8 @@ def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: in batches on GPU, then scatters results back to the originating task's segment. """ - if not tasks: - return tasks + if len(tasks) == 0: + return [] # Collect all valid waveforms with their origin (task_idx, segment_idx) all_waveform_metadata: list[tuple[int, int, torch.Tensor]] = [] diff --git a/nemo_curator/stages/audio/model_input_segmentation.py b/nemo_curator/stages/audio/model_input_segmentation.py new file mode 100644 index 0000000000..864eba873d --- /dev/null +++ b/nemo_curator/stages/audio/model_input_segmentation.py @@ -0,0 +1,104 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared audio model-input segmentation helpers. + +Segmentation creates model-safe work units. Duration-aware bucketing is a +separate packing step that consumes these bounded units. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass + + +@dataclass(frozen=True) +class AudioSegment: + """One contiguous audio model-input segment in sample coordinates.""" + + index: int + count: int + start_sample: int + stop_sample: int + duration_s: float + + +def resolve_max_model_input_duration( + *, + max_duration_s: float, + owner: str, +) -> float: + """Validate and normalize the model-input duration ceiling.""" + + maximum = float(max_duration_s) + if maximum <= 0: + msg = f"{owner}.max_inference_duration_s must be > 0 s, got {max_duration_s}" + raise ValueError(msg) + return maximum + + +def plan_audio_segments( + *, + num_samples: int, + sample_rate: int, + max_duration_s: float, + owner: str, +) -> tuple[AudioSegment, ...]: + """Create bounded contiguous segment specs for one audio input.""" + + maximum = resolve_max_model_input_duration( + max_duration_s=max_duration_s, + owner=owner, + ) + if sample_rate <= 0: + msg = f"{owner}.sample_rate must be > 0, got {sample_rate}" + raise ValueError(msg) + if num_samples <= 0: + return ( + AudioSegment( + index=0, + count=1, + start_sample=0, + stop_sample=0, + duration_s=0.0, + ), + ) + + max_samples = max(1, int(maximum * float(sample_rate))) + starts = list(range(0, int(num_samples), max_samples)) + count = max(1, len(starts)) + segments: list[AudioSegment] = [] + for index, start in enumerate(starts): + stop = min(start + max_samples, int(num_samples)) + duration_s = float(stop - start) / float(sample_rate) + segments.append( + AudioSegment( + index=index, + count=count, + start_sample=start, + stop_sample=stop, + duration_s=duration_s, + ) + ) + return tuple(segments) + + +def duration_to_num_samples(duration_s: float, sample_rate: int) -> int: + """Return ceil(duration_s * sample_rate) with non-negative duration.""" + + if sample_rate <= 0: + msg = f"sample_rate must be > 0, got {sample_rate}" + raise ValueError(msg) + return math.ceil(max(float(duration_s), 0.0) * float(sample_rate)) diff --git a/nemo_curator/stages/audio/preprocessing/mono_conversion.py b/nemo_curator/stages/audio/preprocessing/mono_conversion.py index ed60661235..0d89868335 100755 --- a/nemo_curator/stages/audio/preprocessing/mono_conversion.py +++ b/nemo_curator/stages/audio/preprocessing/mono_conversion.py @@ -60,9 +60,6 @@ class MonoConversionStage(ProcessingStage[AudioTask, AudioTask]): batch_size: int = 1 resources: Resources = field(default_factory=lambda: Resources(cpus=1.0)) - def __post_init__(self): - super().__init__() - def inputs(self) -> tuple[list[str], list[str]]: return [], [] @@ -103,11 +100,12 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]: else: mono_waveform = waveform + num_samples = int(mono_waveform.shape[-1]) task.data["waveform"] = mono_waveform task.data["sample_rate"] = sample_rate task.data["is_mono"] = True - task.data["duration"] = mono_waveform.shape[1] / sample_rate - task.data["num_samples"] = mono_waveform.shape[1] + task.data["duration"] = num_samples / sample_rate + task.data["num_samples"] = num_samples except (OSError, RuntimeError) as e: logger.error(f"Error processing {audio_filepath}: {e}") diff --git a/nemo_curator/stages/base.py b/nemo_curator/stages/base.py index 1688d679ff..db55235e20 100644 --- a/nemo_curator/stages/base.py +++ b/nemo_curator/stages/base.py @@ -119,6 +119,9 @@ class ProcessingStage(ABC, Generic[X, Y], metaclass=StageMeta): # resumability layer to mark the counter-decrement boundary. is_source_stage: bool = False is_sink_stage: bool = False + # Opt-in diagnostics used by benchmark pipelines. Existing stages retain + # main's performance record shape and avoid background GPU sampling. + extended_performance_metrics: bool = False @property @final @@ -253,6 +256,19 @@ def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: Work worker_metadata (WorkerMetadata, optional): Information about the worker (provided by some backends) """ + def setup_on_node_resources(self) -> Resources: + """Resources needed by the per-node setup task. + + Most stages need the same placement resources for setup as for steady + state processing. Stages that only prefetch/download per-node assets can + override this to avoid reserving GPUs during setup. + """ + if isinstance(self.resources, Resources): + return self.resources + if isinstance(self.resources, dict): + return Resources(**self.resources) + return Resources() + def setup(self, worker_metadata: WorkerMetadata | None = None) -> None: """Setup method called once before processing begins. Override this method to perform any initialization that should diff --git a/nemo_curator/stages/payload_lifecycle.py b/nemo_curator/stages/payload_lifecycle.py new file mode 100644 index 0000000000..5734af7b66 --- /dev/null +++ b/nemo_curator/stages/payload_lifecycle.py @@ -0,0 +1,1042 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: ANN401, BLE001, C901, EM101, EM102, PLR0912, S110, S112, TRY300, TRY301 + +from __future__ import annotations + +import os +import re +import threading +import time +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import torch +from loguru import logger + +from nemo_curator.backends.utils import RayStageSpecKeys +from nemo_curator.pipeline.payload_refs import ( + PayloadRef, + _get_named_actor, + heartbeat_payload_refs_batched, + release_payload_ref, + resolve_payload_refs_batched, + strip_payload_refs, + task_payload_refs, +) +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import AudioTask, Task + +if TYPE_CHECKING: + from nemo_curator.backends.base import NodeInfo + + +_DEFAULT_NODE_MEMORY_FRACTION = 0.70 +_DEFAULT_LEASE_TTL_S = 3600.0 +_DEFAULT_POLL_INTERVAL_S = 0.25 +_DEFAULT_ADMISSION_WAIT_TIMEOUT_S = 4 * 60 * 60 +_DEFAULT_MATERIALIZED_LEASE_TTL_S = 4 * 60 * 60 +_DEFAULT_SAMPLE_WIDTH_BYTES = 4 + + +def _ray_get(obj: Any) -> Any: + import ray + + return ray.get(obj) + + +def _resolve_node_id() -> str: + try: + import ray + + ctx = ray.get_runtime_context() + node_id = getattr(ctx, "get_node_id", lambda: None)() + if node_id: + return str(node_id) + except Exception: + pass + return os.uname().nodename + + +def _safe_actor_suffix(value: str) -> str: + suffix = re.sub(r"[^A-Za-z0-9_.-]+", "_", value) + return suffix or "unknown" + + +def _current_ray_namespace() -> str | None: + try: + import ray + + ctx = ray.get_runtime_context() + namespace = getattr(ctx, "namespace", None) + if callable(namespace): + namespace = namespace() + if not namespace: + get_namespace = getattr(ctx, "get_namespace", None) + if callable(get_namespace): + namespace = get_namespace() + if namespace: + return str(namespace) + except Exception: + pass + return None + + +def _parse_byte_limit(value: str | None, *, field_name: str = "byte limit") -> int | None: + if not value: + return None + text = value.strip().lower() + try: + if text.endswith("k"): + parsed = int(float(text[:-1]) * 1024) + elif text.endswith("m"): + parsed = int(float(text[:-1]) * 1024**2) + elif text.endswith("g"): + parsed = int(float(text[:-1]) * 1024**3) + else: + parsed = int(text) + except ValueError as exc: + msg = f"{field_name} must be an integer byte count or a k/m/g byte string, got {value!r}" + raise ValueError(msg) from exc + if parsed <= 0: + msg = f"{field_name} must be positive, got {value!r}" + raise ValueError(msg) + return parsed + + +def _detect_memory_limit_bytes() -> int | None: + cgroup_paths = ( + "/sys/fs/cgroup/memory.max", + "/sys/fs/cgroup/memory/memory.limit_in_bytes", + ) + for path in cgroup_paths: + try: + with open(path, encoding="utf-8") as f: + raw = f.read().strip() + if raw and raw != "max": + value = int(raw) + if value > 0 and value < 1 << 60: + return value + except Exception: + continue + + try: + pages = os.sysconf("SC_PHYS_PAGES") + page_size = os.sysconf("SC_PAGE_SIZE") + if pages > 0 and page_size > 0: + return int(pages * page_size) + except Exception: + return None + return None + + +def _detect_memory_usage_bytes() -> int: + cgroup_paths = ( + "/sys/fs/cgroup/memory.current", + "/sys/fs/cgroup/memory/memory.usage_in_bytes", + ) + for path in cgroup_paths: + try: + with open(path, encoding="utf-8") as f: + raw = f.read().strip() + if raw: + return max(0, int(raw)) + except Exception: + continue + try: + import resource + + # ru_maxrss is KiB on Linux. + return int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1024) + except Exception: + return 0 + + +def _resolve_node_payload_budget( + explicit_bytes: int | None, + memory_fraction: float, +) -> int: + if explicit_bytes is not None: + return max(1, int(explicit_bytes)) + memory_limit = _detect_memory_limit_bytes() + if memory_limit is None: + # Conservative fallback for development machines where cgroups are not visible. + return int(32 * 1024**3 * memory_fraction) + return max(1, int(memory_limit * memory_fraction)) + + +def _payload_object_bytes(payload: Any) -> int: + if isinstance(payload, torch.Tensor): + return int(payload.element_size() * payload.nelement()) + nbytes = getattr(payload, "nbytes", None) + if nbytes is not None: + return int(nbytes) + if isinstance(payload, (bytes, bytearray, memoryview)): + return len(payload) + return 0 + + +def _duration_to_payload_bytes( + duration_s: float, + sample_rate: int, + channels: int, + sample_width_bytes: int, +) -> int: + if duration_s <= 0: + raise ValueError("Audio payload byte admission requires positive duration before decode") + return max(1, int(duration_s * sample_rate * channels * sample_width_bytes)) + + +def _lease_expires_at(lease_ttl_s: float) -> float | None: + ttl = float(lease_ttl_s) + if ttl <= 0: + return None + return time.monotonic() + ttl + + +def _task_payload_estimate_bytes( + task: Task, + *, + duration_key: str, + sample_rate: int, + channels: int, + sample_width_bytes: int, +) -> int: + raw_duration = task.data.get(duration_key) + if raw_duration is None: + raise ValueError(f"Audio payload byte admission requires '{duration_key}' in each row before audio decode") + try: + duration_s = float(raw_duration) + except (TypeError, ValueError) as exc: + raise ValueError(f"Audio payload duration must be numeric, got {raw_duration!r}") from exc + return _duration_to_payload_bytes(duration_s, sample_rate, channels, sample_width_bytes) + + +class _PayloadAdmissionState: + def __init__( + self, + *, + default_node_budget_bytes: int, + default_cluster_budget_bytes: int | None = None, + default_lease_ttl_s: float = _DEFAULT_LEASE_TTL_S, + ) -> None: + self.default_node_budget_bytes = max(1, int(default_node_budget_bytes)) + self.default_cluster_budget_bytes = ( + max(1, int(default_cluster_budget_bytes)) if default_cluster_budget_bytes is not None else None + ) + self.default_lease_ttl_s = float(default_lease_ttl_s) + self._node_budget: dict[str, int] = {} + self._node_used: dict[str, int] = {} + self._cluster_used = 0 + self._leases: dict[tuple[str, str], tuple[int, float | None]] = {} + + def register_node(self, node_id: str, budget_bytes: int | None = None) -> None: + budget = self.default_node_budget_bytes if budget_bytes is None else int(budget_bytes) + self._node_budget[node_id] = max(1, budget) + self._node_used.setdefault(node_id, 0) + self._reap_expired() + + def try_acquire(self, node_id: str, owner_id: str, amount_bytes: int, lease_ttl_s: float | None = None) -> bool: + amount = int(amount_bytes) + if amount <= 0: + return True + self._reap_expired() + self.register_node(node_id) + used = self._node_used[node_id] + budget = self._node_budget[node_id] + cluster_budget = self._cluster_budget_bytes() + if amount > budget: + return False + if used + amount > budget: + return False + if amount > cluster_budget: + return False + if self._cluster_used + amount > cluster_budget: + return False + ttl = self.default_lease_ttl_s if lease_ttl_s is None else float(lease_ttl_s) + self._node_used[node_id] = used + amount + self._cluster_used += amount + self._leases[(node_id, owner_id)] = (amount, _lease_expires_at(ttl)) + return True + + def heartbeat(self, node_id: str, owner_id: str, lease_ttl_s: float | None = None) -> bool: + self._reap_expired() + return self._heartbeat(node_id, owner_id, lease_ttl_s) + + def heartbeat_many(self, requests: list[tuple[str, str, float | None]]) -> list[bool]: + """Refresh several admission leases in one actor RPC and one reap pass.""" + self._reap_expired() + return [self._heartbeat(node_id, owner_id, lease_ttl_s) for node_id, owner_id, lease_ttl_s in requests] + + def _heartbeat(self, node_id: str, owner_id: str, lease_ttl_s: float | None) -> bool: + key = (node_id, owner_id) + if key not in self._leases: + return False + amount, expires_at = self._leases[key] + if expires_at is not None: + ttl = self.default_lease_ttl_s if lease_ttl_s is None else float(lease_ttl_s) + expires_at = _lease_expires_at(ttl) + self._leases[key] = (amount, expires_at) + return True + + def release(self, node_id: str, owner_id: str, amount_bytes: int | None = None) -> None: + key = (node_id, owner_id) + lease = self._leases.pop(key, None) + if lease is None: + return + reserved, _ = lease + amount = reserved if amount_bytes is None else min(reserved, int(amount_bytes)) + self._node_used[node_id] = max(0, self._node_used.get(node_id, 0) - amount) + self._cluster_used = max(0, self._cluster_used - amount) + + def resize(self, node_id: str, owner_id: str, new_amount_bytes: int, lease_ttl_s: float | None = None) -> bool: + self._reap_expired() + key = (node_id, owner_id) + lease = self._leases.get(key) + if lease is None: + return self.try_acquire(node_id, owner_id, new_amount_bytes, lease_ttl_s) + + old_amount, _ = lease + new_amount = int(new_amount_bytes) + if new_amount <= old_amount: + delta = old_amount - new_amount + self._node_used[node_id] = max(0, self._node_used.get(node_id, 0) - delta) + self._cluster_used = max(0, self._cluster_used - delta) + ttl = self.default_lease_ttl_s if lease_ttl_s is None else float(lease_ttl_s) + expires_at = lease[1] + if expires_at is not None: + expires_at = _lease_expires_at(ttl) + self._leases[key] = (new_amount, expires_at) + return True + + delta = new_amount - old_amount + budget = self._node_budget.get(node_id, self.default_node_budget_bytes) + used = self._node_used.get(node_id, 0) + cluster_budget = self._cluster_budget_bytes() + if used + delta > budget: + return False + if self._cluster_used + delta > cluster_budget: + return False + self._node_used[node_id] = used + delta + self._cluster_used += delta + ttl = self.default_lease_ttl_s if lease_ttl_s is None else float(lease_ttl_s) + expires_at = lease[1] + if expires_at is not None: + expires_at = _lease_expires_at(ttl) + self._leases[key] = (new_amount, expires_at) + return True + + def snapshot(self) -> dict[str, Any]: + self._reap_expired() + return { + "node_budget": dict(self._node_budget), + "node_used": dict(self._node_used), + "cluster_budget": self._cluster_budget_bytes(), + "cluster_used": self._cluster_used, + "lease_count": len(self._leases), + } + + def _cluster_budget_bytes(self) -> int: + if self.default_cluster_budget_bytes is not None: + return self.default_cluster_budget_bytes + return max(1, sum(self._node_budget.values()) or self.default_node_budget_bytes) + + def _reap_expired(self) -> None: + now = time.monotonic() + expired = [key for key, (_, expires_at) in self._leases.items() if expires_at is not None and expires_at < now] + for node_id, owner_id in expired: + self.release(node_id, owner_id) + + +@dataclass +class _StoredPayload: + payload: Any + amount_bytes: int + expires_at: float | None + + +class _PayloadStoreState: + def __init__(self, *, default_lease_ttl_s: float = _DEFAULT_LEASE_TTL_S) -> None: + self.default_lease_ttl_s = float(default_lease_ttl_s) + self._payloads: dict[str, _StoredPayload] = {} + + def put(self, payload_id: str, payload: Any, amount_bytes: int, lease_ttl_s: float | None = None) -> None: + self._reap_expired() + ttl = self.default_lease_ttl_s if lease_ttl_s is None else float(lease_ttl_s) + self._payloads[payload_id] = _StoredPayload(payload, int(amount_bytes), _lease_expires_at(ttl)) + + def get(self, payload_id: str, lease_ttl_s: float | None = None) -> Any: + self._reap_expired() + return self._get(payload_id, lease_ttl_s) + + def get_many(self, requests: list[tuple[str, float | None]]) -> list[Any]: + """Resolve several payloads in request order in one actor RPC and one reap pass.""" + self._reap_expired() + return [self._get(payload_id, lease_ttl_s) for payload_id, lease_ttl_s in requests] + + def _get(self, payload_id: str, lease_ttl_s: float | None) -> Any: + stored = self._payloads[payload_id] + if stored.expires_at is not None: + ttl = self.default_lease_ttl_s if lease_ttl_s is None else float(lease_ttl_s) + stored.expires_at = _lease_expires_at(ttl) + return stored.payload + + def pin(self, payload_id: str, lease_ttl_s: float | None = None) -> bool: + self._reap_expired() + return self._pin(payload_id, lease_ttl_s) + + def pin_many(self, requests: list[tuple[str, float | None]]) -> list[bool]: + """Refresh several store leases in one actor RPC and one reap pass.""" + self._reap_expired() + return [self._pin(payload_id, lease_ttl_s) for payload_id, lease_ttl_s in requests] + + def _pin(self, payload_id: str, lease_ttl_s: float | None) -> bool: + stored = self._payloads.get(payload_id) + if stored is None: + return False + if stored.expires_at is not None: + ttl = self.default_lease_ttl_s if lease_ttl_s is None else float(lease_ttl_s) + stored.expires_at = _lease_expires_at(ttl) + return True + + def release(self, payload_id: str) -> int: + stored = self._payloads.pop(payload_id, None) + if stored is None: + return 0 + return stored.amount_bytes + + def snapshot(self) -> dict[str, Any]: + self._reap_expired() + return { + "payload_count": len(self._payloads), + "payload_bytes": sum(payload.amount_bytes for payload in self._payloads.values()), + } + + def _reap_expired(self) -> None: + now = time.monotonic() + expired = [ + payload_id + for payload_id, payload in self._payloads.items() + if payload.expires_at is not None and payload.expires_at < now + ] + for payload_id in expired: + self._payloads.pop(payload_id, None) + + +def _kill_named_actor(name: str, namespace: str | None = None) -> bool: + try: + import ray + + actor = _get_named_actor(name, namespace) + ray.kill(actor, no_restart=True) + return True + except ValueError: + return False + except Exception as exc: + logger.warning(f"Failed to kill payload actor {name!r}: {exc}") + return False + + +def _active_ray_node_ids() -> list[str]: + try: + import ray + + return [str(node["NodeID"]) for node in ray.nodes() if node.get("NodeID") and node.get("Alive", True)] + except Exception: + return [] + + +def _get_named_actor_or_create( + actor_cls: type, + name: str, + *, + node_id: str | None = None, + namespace: str | None = None, + **kwargs: Any, +) -> Any: + import ray + + try: + return _get_named_actor(name, namespace) + except ValueError: + options: dict[str, Any] = {"name": name, "get_if_exists": True, "lifetime": "detached"} + if namespace: + options["namespace"] = namespace + if node_id: + from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + + options["scheduling_strategy"] = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) + return ray.remote(actor_cls).options(**options).remote(**kwargs) + + +def _get_admission_actor( + actor_name: str, + *, + default_node_budget_bytes: int, + default_cluster_budget_bytes: int | None, + default_lease_ttl_s: float, + namespace: str | None, +) -> Any: + return _get_named_actor_or_create( + _PayloadAdmissionState, + actor_name, + namespace=namespace, + default_node_budget_bytes=default_node_budget_bytes, + default_cluster_budget_bytes=default_cluster_budget_bytes, + default_lease_ttl_s=default_lease_ttl_s, + ) + + +def _get_store_actor(actor_name: str, *, node_id: str, default_lease_ttl_s: float, namespace: str | None) -> Any: + return _get_named_actor_or_create( + _PayloadStoreState, + actor_name, + node_id=node_id, + namespace=namespace, + default_lease_ttl_s=default_lease_ttl_s, + ) + + +class _PayloadLeaseKeeper: + def __init__(self, payload_refs: list[PayloadRef], *, interval_s: float | None = None) -> None: + deduped: dict[tuple[str | None, str, str], PayloadRef] = {} + for payload_ref in payload_refs: + key = (payload_ref.actor_namespace, payload_ref.store_actor_name, payload_ref.payload_id) + deduped[key] = payload_ref + self._payload_refs = list(deduped.values()) + if interval_s is None: + ttl_s = min((payload_ref.lease_ttl_s for payload_ref in self._payload_refs), default=_DEFAULT_LEASE_TTL_S) + interval_s = min(30.0, max(1.0, ttl_s / 3.0)) + self._interval_s = float(interval_s) + self._stop = threading.Event() + self._thread: threading.Thread | None = None + self._warned = False + + def start(self) -> None: + if not self._payload_refs or self._thread is not None: + return + self._thread = threading.Thread( + target=self._run, + name="curator-payload-lease-keeper", + daemon=True, + ) + self._thread.start() + + def stop(self) -> None: + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=max(1.0, min(5.0, self._interval_s))) + self._thread = None + + def _run(self) -> None: + while not self._stop.wait(self._interval_s): + try: + heartbeat_payload_refs_batched(self._payload_refs) + except Exception as exc: + if not self._warned: + logger.warning( + "Payload lease heartbeat failed; one or more payloads may expire during long stage work: {}", + exc, + ) + self._warned = True + + +class PayloadAwareStageMixin: + """Mixin for stages that need waveform payload handles at ``process_batch`` time.""" + + waveform_ref_key: str | None + waveform_key: str + sample_rate_key: str + num_samples_key: str + + def payload_bindings(self) -> list[dict[str, str]]: + """Return payload-ref bindings consumed by this stage. + + Stages with one waveform can rely on the legacy ``waveform_*`` fields. + Multi-input stages can override this method and return one mapping per + payload, each with ``ref_key`` and ``waveform_key`` plus optional + ``sample_rate_key`` and ``num_samples_key``. + """ + + payload_ref_key = getattr(self, "waveform_ref_key", None) + if not payload_ref_key: + return [] + return [ + { + "ref_key": str(payload_ref_key), + "waveform_key": str(getattr(self, "waveform_key", "waveform")), + "sample_rate_key": str(getattr(self, "sample_rate_key", "sample_rate")), + "num_samples_key": str(getattr(self, "num_samples_key", "num_samples")), + } + ] + + def resolve_payload_refs_for_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + bindings = self.payload_bindings() + if not bindings: + return [] + self._stop_payload_lease_keeper() + inserted: list[AudioTask] = [] + payload_refs: list[PayloadRef] = [] + pending: list[tuple[AudioTask, dict[str, str], PayloadRef]] = [] + consumer_node_id = _resolve_node_id() + resolution_start = time.perf_counter() + same_node_count = 0 + cross_node_count = 0 + resolved_bytes = 0 + try: + for task in tasks: + task_inserted = False + for binding in bindings: + payload_ref_key = binding["ref_key"] + payload_key = binding["waveform_key"] + if payload_key in task.data: + continue + payload_ref = task.data.get(payload_ref_key) + if payload_ref is None: + continue + if not isinstance(payload_ref, PayloadRef): + msg = ( + f"Task {task.task_id} has non-PayloadRef '{payload_ref_key}' " + f"value: {type(payload_ref).__name__}" + ) + raise TypeError(msg) + if payload_ref.owner_node_id and payload_ref.owner_node_id == consumer_node_id: + same_node_count += 1 + else: + cross_node_count += 1 + resolved_bytes += int(payload_ref.amount_bytes) + pending.append((task, binding, payload_ref)) + payload_refs.append(payload_ref) + task_inserted = True + if task_inserted: + inserted.append(task) + + payloads = resolve_payload_refs_batched( + payload_refs, + max_batch_bytes=getattr(self, "payload_resolve_max_batch_bytes", None), + ) + for (task, binding, payload_ref), payload in zip(pending, payloads, strict=True): + task.data[binding["waveform_key"]] = payload + task.data[binding.get("sample_rate_key", "sample_rate")] = payload_ref.sample_rate + task.data.setdefault(binding.get("num_samples_key", "num_samples"), payload_ref.num_samples) + except Exception: + for task in inserted: + for binding in bindings: + task.data.pop(binding["waveform_key"], None) + self._stop_payload_lease_keeper() + raise + self._start_payload_lease_keeper(payload_refs) + if payload_refs: + log_metrics = getattr(self, "_log_metrics", None) + if callable(log_metrics): + log_metrics( + { + "payload_resolution_count": float(len(payload_refs)), + "payload_resolution_same_node_count": float(same_node_count), + "payload_resolution_cross_node_count": float(cross_node_count), + "payload_resolution_bytes": float(resolved_bytes), + "payload_resolution_time_s": time.perf_counter() - resolution_start, + } + ) + return inserted + + def drop_resolved_payloads(self, tasks: list[AudioTask]) -> None: + self._stop_payload_lease_keeper() + payload_keys = [binding["waveform_key"] for binding in self.payload_bindings()] + for task in tasks: + for payload_key in payload_keys: + task.data.pop(payload_key, None) + + def terminal_tombstone_drop_data_keys(self) -> tuple[str, ...]: + return tuple({binding["waveform_key"] for binding in self.payload_bindings()}) + + @staticmethod + def payload_consumer_node_id() -> str: + """Return the node currently resolving payloads for locality metrics.""" + return _resolve_node_id() + + def _start_payload_lease_keeper(self, payload_refs: list[PayloadRef]) -> None: + if not payload_refs: + return + keeper = _PayloadLeaseKeeper(payload_refs) + keeper.start() + self._payload_lease_keeper = keeper + + def _stop_payload_lease_keeper(self) -> None: + keeper = getattr(self, "_payload_lease_keeper", None) + if keeper is None: + return + keeper.stop() + self._payload_lease_keeper = None + + +@dataclass +class AudioPayloadMaterializeStage(ProcessingStage[AudioTask, AudioTask]): + """Read audio once, decode to memory, and replace the waveform with a payload handle.""" + + _curator_pipeline_helper_stage = True + + name: str = "AudioPayloadMaterializeStage" + target_sample_rate: int = 16000 + target_nchannels: int = 1 + audio_filepath_key: str = "audio_filepath" + duration_key: str = "duration" + segment_start_key: str = "segment_start_s" + segment_duration_key: str = "segment_duration_s" + waveform_key: str = "waveform" + waveform_ref_key: str = "waveform_ref" + sample_rate_key: str = "sample_rate" + num_samples_key: str = "num_samples" + skip_me_key: str = "_skip_me" + read_error_key: str = "audio_read_error" + skip_on_read_error: bool = False + node_memory_fraction: float = _DEFAULT_NODE_MEMORY_FRACTION + max_node_payload_bytes: int | str | None = None + max_cluster_payload_bytes: int | str | None = None + lease_ttl_s: float = _DEFAULT_LEASE_TTL_S + materialized_lease_ttl_s: float = _DEFAULT_MATERIALIZED_LEASE_TTL_S + admission_poll_interval_s: float = _DEFAULT_POLL_INTERVAL_S + admission_wait_timeout_s: float = _DEFAULT_ADMISSION_WAIT_TIMEOUT_S + admission_actor_name: str = "curator_payload_admission" + store_actor_prefix: str = "curator_payload_store" + run_id: str | None = None + sample_width_bytes: int = _DEFAULT_SAMPLE_WIDTH_BYTES + verbose: bool = False + + _reader: Any = field(init=False, default=None, repr=False) + _node_id: str = field(init=False, default="", repr=False) + _node_budget_bytes: int = field(init=False, default=0, repr=False) + _cluster_budget_bytes: int | None = field(init=False, default=None, repr=False) + _actor_run_suffix: str = field(init=False, default="", repr=False) + _admission_actor_name: str = field(init=False, default="", repr=False) + _store_actor_name: str = field(init=False, default="", repr=False) + _actor_namespace: str | None = field(init=False, default=None, repr=False) + _admission: Any = field(init=False, default=None, repr=False) + _store: Any = field(init=False, default=None, repr=False) + + def __post_init__(self) -> None: + if self.lease_ttl_s <= 0: + raise ValueError("lease_ttl_s must be positive while a payload is being materialized") + if self.materialized_lease_ttl_s <= 0: + raise ValueError("materialized_lease_ttl_s must be positive") + if self.admission_poll_interval_s <= 0: + raise ValueError("admission_poll_interval_s must be positive") + if self.admission_wait_timeout_s <= 0: + raise ValueError("admission_wait_timeout_s must be positive") + self.batch_size = 1 + if self.resources is None: + self.resources = {"cpus": 1.0} + self.run_id = str(self.run_id or uuid.uuid4().hex) + self._actor_run_suffix = _safe_actor_suffix(self.run_id) + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.audio_filepath_key, self.duration_key] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [self.waveform_ref_key, self.sample_rate_key, self.num_samples_key] + + def setup_on_node(self, node_info: NodeInfo, worker_metadata: dict[str, Any] | None = None) -> None: + self._node_id = node_info.node_id or _resolve_node_id() + self._ensure_ready() + self._reader.setup_on_node(node_info, worker_metadata) + + def setup(self, worker_metadata: dict[str, Any] | None = None) -> None: + self._ensure_ready() + self._reader.setup(worker_metadata) + + def teardown(self) -> None: + if self._reader is not None: + self._reader.teardown() + + def process(self, task: AudioTask) -> AudioTask: + self._ensure_ready() + payload_id = uuid.uuid4().hex + estimated_bytes = _task_payload_estimate_bytes( + task, + duration_key=self._estimate_duration_key(task), + sample_rate=self.target_sample_rate, + channels=self.target_nchannels, + sample_width_bytes=self.sample_width_bytes, + ) + admission_wait_s, admission_poll_count = self._acquire(payload_id, estimated_bytes) + reserved_bytes = estimated_bytes + stored = False + self._log_metrics( + { + "payload_admission_wait_s": admission_wait_s, + "payload_admission_poll_count": admission_poll_count, + "payload_estimated_bytes": float(estimated_bytes), + "payload_reserved_bytes": float(reserved_bytes), + "payload_node_budget_bytes": float(self._node_budget_bytes), + "payload_cluster_budget_bytes": float(self._cluster_budget_bytes or 0), + } + ) + try: + decoded = self._reader.process(task) + waveform = decoded.data.pop(self.waveform_key, None) + if waveform is None: + self._release(payload_id, reserved_bytes) + return decoded + if self._is_reader_skip_result(decoded): + self._release(payload_id, reserved_bytes) + decoded.data.pop(self.waveform_key, None) + return decoded + + actual_bytes = _payload_object_bytes(waveform) + if actual_bytes <= 0: + self._release(payload_id, reserved_bytes) + raise RuntimeError("Decoded audio waveform has unknown or zero byte size") + + if actual_bytes != reserved_bytes: + if not _ray_get( + self._admission.resize.remote( + self._node_id, + payload_id, + actual_bytes, + self.lease_ttl_s, + ) + ): + self._release(payload_id, reserved_bytes) + raise RuntimeError( + "Insufficient payload memory budget after audio decode " + f"(estimated={reserved_bytes}, actual={actual_bytes})" + ) + reserved_bytes = actual_bytes + + _ray_get(self._store.put.remote(payload_id, waveform, actual_bytes, self.materialized_lease_ttl_s)) + stored = True + self._log_metrics( + { + "payload_stored_bytes": float(actual_bytes), + "payload_materialized_count": 1.0, + } + ) + decoded.data[self.waveform_ref_key] = PayloadRef( + payload_id=payload_id, + owner_node_id=self._node_id, + store_actor_name=self._store_actor_name, + admission_actor_name=self._admission_actor_name, + amount_bytes=actual_bytes, + sample_rate=int(decoded.data[self.sample_rate_key]), + num_samples=int(decoded.data[self.num_samples_key]), + lease_ttl_s=self.lease_ttl_s, + actor_namespace=self._actor_namespace, + ) + decoded.data["_curator_payload_estimated_bytes"] = estimated_bytes + decoded.data["_curator_payload_bytes"] = actual_bytes + if not _ray_get( + self._admission.heartbeat.remote( + self._node_id, + payload_id, + self.materialized_lease_ttl_s, + ) + ): + raise RuntimeError(f"Payload reservation expired before materialization completed: {payload_id}") + return decoded + except Exception: + if stored: + try: + _ray_get(self._store.release.remote(payload_id)) + except Exception: + logger.debug("Failed to release stored payload {} after materialization error", payload_id) + self._release(payload_id, reserved_bytes) + task.data.pop(self.waveform_key, None) + raise + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + for task in tasks: + if not self.validate_input(task): + msg = f"Task {task!s} failed validation for stage {self}" + raise ValueError(msg) + return [self.process(task) for task in tasks] + + def _ensure_ready(self) -> None: + if self._reader is None: + from nemo_curator.stages.audio.io.audio_file_reader import AudioFileReaderStage + + self._reader = AudioFileReaderStage( + target_sample_rate=self.target_sample_rate, + target_nchannels=self.target_nchannels, + audio_filepath_key=self.audio_filepath_key, + duration_key=self.duration_key, + segment_start_key=self.segment_start_key, + segment_duration_key=self.segment_duration_key, + waveform_key=self.waveform_key, + sample_rate_key=self.sample_rate_key, + num_samples_key=self.num_samples_key, + skip_me_key=self.skip_me_key, + read_error_key=self.read_error_key, + skip_on_read_error=self.skip_on_read_error, + verbose=self.verbose, + ) + + if self._admission is None or self._store is None: + self._node_id = self._node_id or _resolve_node_id() + self._actor_namespace = _current_ray_namespace() + explicit_budget = self.max_node_payload_bytes + if isinstance(explicit_budget, str): + explicit_budget = _parse_byte_limit(explicit_budget, field_name="max_node_payload_bytes") + explicit_cluster_budget = self.max_cluster_payload_bytes + if isinstance(explicit_cluster_budget, str): + explicit_cluster_budget = _parse_byte_limit( + explicit_cluster_budget, + field_name="max_cluster_payload_bytes", + ) + self._cluster_budget_bytes = explicit_cluster_budget + self._node_budget_bytes = _resolve_node_payload_budget( + explicit_budget, + self.node_memory_fraction, + ) + self._admission_actor_name = f"{self.admission_actor_name}_{self._actor_run_suffix}" + self._store_actor_name = ( + f"{self.store_actor_prefix}_{self._actor_run_suffix}_{_safe_actor_suffix(self._node_id)}" + ) + self._admission = _get_admission_actor( + self._admission_actor_name, + default_node_budget_bytes=self._node_budget_bytes, + default_cluster_budget_bytes=explicit_cluster_budget, + default_lease_ttl_s=self.lease_ttl_s, + namespace=self._actor_namespace, + ) + self._store = _get_store_actor( + self._store_actor_name, + node_id=self._node_id, + default_lease_ttl_s=self.lease_ttl_s, + namespace=self._actor_namespace, + ) + _ray_get(self._admission.register_node.remote(self._node_id, self._node_budget_bytes)) + + def _estimate_duration_key(self, task: AudioTask) -> str: + if self.segment_duration_key in task.data and task.data.get(self.segment_duration_key) is not None: + return self.segment_duration_key + return self.duration_key + + def _is_reader_skip_result(self, task: AudioTask) -> bool: + if not self.skip_on_read_error: + return False + return self.skip_me_key in task.data or self.read_error_key in task.data + + def _acquire(self, payload_id: str, amount_bytes: int) -> tuple[float, int]: + if amount_bytes > self._node_budget_bytes: + raise RuntimeError( + f"Single audio payload estimate {amount_bytes} bytes exceeds node payload budget " + f"{self._node_budget_bytes} bytes" + ) + if self._cluster_budget_bytes is not None and amount_bytes > self._cluster_budget_bytes: + raise RuntimeError( + f"Single audio payload estimate {amount_bytes} bytes exceeds cluster payload budget " + f"{self._cluster_budget_bytes} bytes" + ) + start = time.perf_counter() + polls = 0 + while True: + polls += 1 + acquired = _ray_get( + self._admission.try_acquire.remote( + self._node_id, + payload_id, + amount_bytes, + self.lease_ttl_s, + ) + ) + if acquired: + return time.perf_counter() - start, polls + elapsed_s = time.perf_counter() - start + if elapsed_s >= self.admission_wait_timeout_s: + snapshot = _ray_get(self._admission.snapshot.remote()) + raise RuntimeError( + "Timed out waiting for payload admission " + f"after {elapsed_s:.3f}s for {amount_bytes} bytes; admission={snapshot}" + ) + time.sleep(self.admission_poll_interval_s) + + def _release(self, payload_id: str, amount_bytes: int) -> None: + try: + _ray_get(self._admission.release.remote(self._node_id, payload_id, amount_bytes)) + except Exception: + logger.debug("Failed to release payload admission tokens for {}", payload_id) + + def cleanup_run_resources(self) -> None: + suffix = _safe_actor_suffix(str(self.run_id)) + namespace = self._actor_namespace or _current_ray_namespace() + _kill_named_actor(f"{self.admission_actor_name}_{suffix}", namespace) + + store_prefix = f"{self.store_actor_prefix}_{suffix}_" + for node_id in _active_ray_node_ids(): + _kill_named_actor(f"{store_prefix}{_safe_actor_suffix(node_id)}", namespace) + + def ray_stage_spec(self) -> dict[str, Any]: + spec = super().ray_stage_spec() + spec[RayStageSpecKeys.IS_ACTOR_STAGE] = False + return spec + + def xenna_stage_spec(self) -> dict[str, Any]: + return { + "is_actor_stage": False, + "is_fanout_stage": False, + "is_repartition_stage": False, + } + + +@dataclass +class PayloadReleaseStage(ProcessingStage[AudioTask, AudioTask]): + _curator_pipeline_helper_stage = True + + name: str = "PayloadReleaseStage" + payload_ref_key: str = "waveform_ref" + waveform_key: str = "waveform" + remove_payload_metadata: bool = True + + def __post_init__(self) -> None: + self.batch_size = 1 + if self.resources is None: + self.resources = {"cpus": 0.1} + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: AudioTask) -> AudioTask: + released_ids: set[str] = set() + released_bytes = 0 + for payload_ref in task_payload_refs(task): + if payload_ref.payload_id in released_ids: + continue + release_payload_ref(payload_ref) + released_ids.add(payload_ref.payload_id) + released_bytes += int(payload_ref.amount_bytes) + if released_ids: + self._log_metrics( + { + "payload_release_count": float(len(released_ids)), + "payload_release_bytes": float(released_bytes), + } + ) + if isinstance(task.data, dict): + stripped_data = strip_payload_refs(task.data) + task.data.clear() + task.data.update(stripped_data) + task.data.pop(self.waveform_key, None) + if self.remove_payload_metadata: + for key in tuple(task.data): + if str(key).startswith("_curator_payload_"): + task.data.pop(key, None) + return task + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + return [self.process(task) for task in tasks] + + def ray_stage_spec(self) -> dict[str, Any]: + spec = super().ray_stage_spec() + spec[RayStageSpecKeys.IS_ACTOR_STAGE] = False + return spec + + def xenna_stage_spec(self) -> dict[str, Any]: + return { + "is_actor_stage": False, + "is_fanout_stage": False, + "is_repartition_stage": False, + } diff --git a/nemo_curator/tasks/task_terminals.py b/nemo_curator/tasks/task_terminals.py new file mode 100644 index 0000000000..ea7a2ee95f --- /dev/null +++ b/nemo_curator/tasks/task_terminals.py @@ -0,0 +1,169 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: BLE001, C901, PLR0912 + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from nemo_curator.stages.base import ProcessingStage + from nemo_curator.tasks import Task + + +TERMINAL_GROUP_ID_KEY = "_curator_terminal_group_id" +TERMINAL_INDEX_KEY = "_curator_terminal_idx" +TERMINAL_COUNT_KEY = "_curator_terminal_count" +TERMINAL_SOURCE_INDEX_KEY = "_curator_terminal_source_index" +TERMINAL_DROPPED_KEY = "_curator_terminal_dropped" +TERMINAL_DROPPED_BY_STAGE_KEY = "_curator_terminal_dropped_by_stage" +TERMINAL_DROP_REASON_KEY = "_curator_terminal_drop_reason" + + +def preserve_dropped_terminal_tasks( + stage: ProcessingStage, + input_tasks: list[Task], + output_tasks: list[Task | None], +) -> list[Task | None]: + """Preserve terminal records required by downstream aggregators. + + Normal filters should still drop rows. A planner may, however, split a + logical row into ordered terminal records that a later assembler must see + exactly once. If an intermediate stage filters such a terminal record, turn + it into a lightweight tombstone so the assembler can finish the parent + instead of buffering forever. + + Terminal records are identified by the generic ``_curator_terminal_*`` + fields. Audio global planning also carries ``_curator_segment_*`` debug + aliases, but the terminal fields are the backend-preserved contract. + """ + if getattr(stage, "_curator_consumes_segment_rows", False): + return output_tasks + if not any(_terminal_row_key(task) is not None for task in input_tasks): + return output_tasks + + if len(output_tasks) == len(input_tasks): + changed = False + preserved: list[Task | None] = [] + for input_task, output_task in zip(input_tasks, output_tasks, strict=True): + if output_task is None and _terminal_row_key(input_task) is not None: + preserved.append(_terminal_row_tombstone(stage, input_task)) + changed = True + else: + preserved.append(output_task) + if changed: + return preserved + + output_by_key: dict[tuple[str, int, int], Task] = {} + duplicate_output_key = False + non_segment_outputs: list[Task] = [] + for output_task in output_tasks: + if output_task is None: + continue + key = _terminal_row_key(output_task) + if key is None: + non_segment_outputs.append(output_task) + continue + if key in output_by_key: + duplicate_output_key = True + output_by_key[key] = output_task + + input_keys = [_terminal_row_key(task) for task in input_tasks] + if not any(key is not None and key not in output_by_key for key in input_keys): + return output_tasks + + if not duplicate_output_key and not non_segment_outputs and all(key is not None for key in input_keys): + return [ + output_by_key.get(key) or _terminal_row_tombstone(stage, input_task) + for input_task, key in zip(input_tasks, input_keys, strict=True) + ] + + preserved = [task for task in output_tasks if task is not None] + for input_task, key in zip(input_tasks, input_keys, strict=True): + if key is not None and key not in output_by_key: + preserved.append(_terminal_row_tombstone(stage, input_task)) + return preserved + + +def _terminal_row_key(task: object) -> tuple[str, int, int] | None: + data = getattr(task, "data", None) + if not isinstance(data, dict): + return None + if TERMINAL_GROUP_ID_KEY in data: + try: + return ( + str(data[TERMINAL_GROUP_ID_KEY]), + int(data.get(TERMINAL_INDEX_KEY, 0)), + int(data.get(TERMINAL_COUNT_KEY, 1)), + ) + except (TypeError, ValueError): + return None + if "_curator_segment_parent_id" not in data: + return None + try: + return ( + str(data["_curator_segment_parent_id"]), + int(data.get("_curator_segment_idx", 0)), + int(data.get("_curator_segment_count", 1)), + ) + except (TypeError, ValueError): + return None + + +def _terminal_row_tombstone(stage: ProcessingStage, task: Task) -> Task: + data = dict(getattr(task, "data", {}) or {}) + data = _strip_payload_refs(data) + for key in _terminal_tombstone_drop_data_keys(stage): + data.pop(key, None) + skip_key = str(getattr(stage, "skip_me_key", "_skip_me") or "_skip_me") + data.setdefault(skip_key, "dropped_segment_row") + stage_name = str(getattr(stage, "name", type(stage).__name__)) + data[TERMINAL_DROPPED_KEY] = True + data[TERMINAL_DROPPED_BY_STAGE_KEY] = stage_name + data.setdefault(TERMINAL_DROP_REASON_KEY, "dropped_before_terminal_assembly") + if "_curator_segment_parent_id" in data: + data["_curator_segment_dropped"] = True + data["_curator_segment_dropped_by_stage"] = stage_name + data.setdefault("_curator_segment_drop_reason", "dropped_before_segment_assembly") + try: + tombstone = task.__class__( + dataset_name=task.dataset_name, + data=data, + _stage_perf=list(task._stage_perf), + _metadata=dict(task._metadata), + ) + except TypeError: + tombstone = copy.copy(task) + tombstone.data = data + tombstone._stage_perf = list(task._stage_perf) + tombstone._metadata = dict(task._metadata) + tombstone.task_id = task.task_id + return tombstone + + +def _strip_payload_refs(data: dict) -> dict: + try: + from nemo_curator.pipeline.payload_refs import strip_payload_refs + except Exception: + return data + stripped = strip_payload_refs(data) + return stripped if isinstance(stripped, dict) else data + + +def _terminal_tombstone_drop_data_keys(stage: ProcessingStage) -> tuple[str, ...]: + drop_keys = getattr(stage, "terminal_tombstone_drop_data_keys", None) + if not callable(drop_keys): + return () + return tuple(str(key) for key in drop_keys() if str(key)) diff --git a/nemo_curator/utils/gpu_sampler.py b/nemo_curator/utils/gpu_sampler.py new file mode 100644 index 0000000000..6c41a5fd0b --- /dev/null +++ b/nemo_curator/utils/gpu_sampler.py @@ -0,0 +1,173 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Background NVML GPU-utilization sampler (adapter-agnostic). + +A worker-local daemon thread polls NVML for SM utilization and memory-used +percent. By default it samples all NVML-visible devices on the node so hardware +metrics are independent of which processor/actor owns a GPU. Callers may pass +``gpu_uuids`` with ``sample_all_visible=False`` when they need actor-local +attribution only. Per physical GPU: ``window_stats`` returns a +``{normalized_uuid: {gpu_util_pct, gpu_mem_used_pct}}`` mean over ``[t0, t1]``. + +No-op when ``pynvml`` is unavailable, no UUIDs match, or NVML raises (fields +simply omitted). ``gpu`` here is NVML SM duty-cycle percent, not FLOP efficiency. +""" + +from __future__ import annotations + +import threading +import time +from collections import deque + +from loguru import logger + + +def norm_uuid(value: object) -> str: + """Normalize a GPU UUID for comparison (drop ``GPU-`` prefix, lowercase).""" + text = value.decode() if isinstance(value, bytes) else str(value) + return text.strip().lower().removeprefix("gpu-") + + +class GpuUtilSampler: + """Polls NVML in a background thread; reports windowed mean util/mem.""" + + def __init__( + self, + gpu_uuids: tuple[str, ...] = (), + interval_s: float = 0.2, + *, + sample_all_visible: bool = True, + ) -> None: + self._target_uuids = {norm_uuid(u) for u in (gpu_uuids or ()) if str(u).strip()} + self._sample_all_visible = bool(sample_all_visible) + self._interval_s = max(float(interval_s), 0.02) + self._handles: list[object] = [] + # normalized UUID per handle -- the consumer's per-GPU attribution key. + self._handle_keys: list[str] = [] + # (t, [util% per handle], [mem% per handle]) aligned to ``_handles``; + # time-ordered and pruned by ``window_stats`` to stay bounded. + self._samples: deque[tuple[float, list[float | None], list[float | None]]] = deque() + self._lock = threading.Lock() + self._stop = threading.Event() + self._thread: threading.Thread | None = None + self._pynvml = None + self._read_error_count = 0 + + def _resolve_handles(self) -> None: + import pynvml + + self._pynvml = pynvml + pynvml.nvmlInit() + for idx in range(pynvml.nvmlDeviceGetCount()): + handle = pynvml.nvmlDeviceGetHandleByIndex(idx) + key = norm_uuid(pynvml.nvmlDeviceGetUUID(handle)) + if self._sample_all_visible or key in self._target_uuids: + self._handles.append(handle) + self._handle_keys.append(key) + + def start(self) -> None: + try: + self._resolve_handles() + except Exception as exc: # noqa: BLE001 + logger.debug("GPU sampler disabled: NVML handle resolution failed: {}", exc) + self._handles = [] + if not self._handles: + if self._target_uuids: + logger.debug( + "GPU sampler disabled: no NVML handles matched target UUIDs {}", sorted(self._target_uuids) + ) + else: + logger.debug("GPU sampler disabled: no target GPU UUIDs were provided") + return + self._thread = threading.Thread(target=self._loop, name="gpu-util-sampler", daemon=True) + self._thread.start() + + def _loop(self) -> None: + pynvml = self._pynvml + n = len(self._handles) + while not self._stop.is_set(): + # Position-aligned to ``_handles`` (None on read error so a transient + # failure on one GPU never shifts the others). + utils: list[float | None] = [None] * n + mems: list[float | None] = [None] * n + for k, handle in enumerate(self._handles): + try: + utils[k] = float(pynvml.nvmlDeviceGetUtilizationRates(handle).gpu) + mem = pynvml.nvmlDeviceGetMemoryInfo(handle) + mems[k] = 100.0 * float(mem.used) / float(mem.total) if mem.total else 0.0 + except Exception as exc: # noqa: BLE001 + self._read_error_count += 1 + if self._read_error_count == 1 or self._read_error_count % 100 == 0: + logger.debug("GPU sampler NVML read failed for handle {}: {}", k, exc) + continue + with self._lock: + self._samples.append((time.time(), utils, mems)) + self._stop.wait(self._interval_s) + + def window_stats(self, t0: float, t1: float) -> dict[str, dict[str, float]]: + """Per-GPU mean util/mem over ``[t0, t1]``, keyed by normalized UUID. + + Returns ``{uuid: {gpu_util_pct, gpu_mem_used_pct}}`` (empty if no samples + landed in the window); the consumer maps each UUID back to a physical index. + """ + n = len(self._handles) + util_sum = [0.0] * n + util_cnt = [0] * n + mem_sum = [0.0] * n + mem_cnt = [0] * n + with self._lock: + # Windows advance monotonically (batches run sequentially), so drop + # anything older than ``t0`` -- never reused, keeps the deque bounded. + while self._samples and self._samples[0][0] < t0: + self._samples.popleft() + for ts, utils, mems in self._samples: + if ts > t1: + break # time-ordered: no later sample falls in the window + for k in range(n): + if utils[k] is not None: + util_sum[k] += utils[k] + util_cnt[k] += 1 + if mems[k] is not None: + mem_sum[k] += mems[k] + mem_cnt[k] += 1 + result: dict[str, dict[str, float]] = {} + for k, key in enumerate(self._handle_keys): + if not util_cnt[k]: + continue + result[key] = { + "gpu_util_pct": util_sum[k] / util_cnt[k], + "gpu_mem_used_pct": (mem_sum[k] / mem_cnt[k]) if mem_cnt[k] else 0.0, + } + return result + + def diagnostics(self) -> dict[str, float]: + """Small scalar state so perf summaries explain missing GPU samples.""" + return { + "gpu_sampler_active": float(self._thread is not None and bool(self._handles)), + "gpu_sampler_handle_count": float(len(self._handles)), + "gpu_sampler_target_uuid_count": float(len(self._target_uuids)), + "gpu_sampler_sample_all_visible": float(self._sample_all_visible), + "gpu_sampler_error_count": float(self._read_error_count), + } + + def stop(self) -> None: + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=2.0) + try: + if self._pynvml is not None: + self._pynvml.nvmlShutdown() + except Exception as exc: # noqa: BLE001 + logger.debug("GPU sampler NVML shutdown failed: {}", exc) diff --git a/nemo_curator/utils/performance_utils.py b/nemo_curator/utils/performance_utils.py index 897b79b4c4..7e815c487f 100644 --- a/nemo_curator/utils/performance_utils.py +++ b/nemo_curator/utils/performance_utils.py @@ -17,7 +17,7 @@ import contextlib import statistics import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import attrs from loguru import logger @@ -28,9 +28,27 @@ from nemo_curator.stages.base import ProcessingStage +#: Identity fields on ``StagePerfStats``: best-effort string labels for the +#: actor/node/GPU that produced the record. They are metadata, not numeric +#: metrics, so they MUST be excluded from ``items()`` -- downstream collection +#: calls ``float()`` on every yielded value and would crash on a string. +_IDENTITY_FIELDS = ( + "invocation_id", + "actor_id", + "node_id", + "gpu_id", + "physical_address", + "pod_ip", + "hostname", + "gpu_indices", + "gpu_uuids", +) + + @attrs.define class StagePerfStats: """Statistics for tracking stage performance metrics. + Attributes: stage_name: Name of the processing stage. process_time: Total processing time in seconds. @@ -38,6 +56,14 @@ class StagePerfStats: input_data_size_mb: Size of input data in megabytes. num_items_processed: Number of items processed in this stage. custom_metrics: Custom metrics to track. + invocation_id: Unique id for ONE ``process_batch`` call. The same record + is attached to every output task of that call, so the audio summary + dedups on it. Empty when unset -- consumers then fall back to a + value-tuple fingerprint. + actor_id: Best-effort label of the producing actor. Empty when unknown. + node_id: Best-effort node label. Empty when unknown. + gpu_id: Best-effort GPU label ``":"``. Empty for + CPU stages / when unknown. """ stage_name: str @@ -46,9 +72,28 @@ class StagePerfStats: input_data_size_mb: float = 0.0 num_items_processed: int = 0 custom_metrics: dict[str, float] = attrs.field(factory=dict) + # identity (metadata, never a numeric metric -- see _IDENTITY_FIELDS) + invocation_id: str = "" + actor_id: str = "" + node_id: str = "" + gpu_id: str = "" + physical_address: str = "" + pod_ip: str = "" + hostname: str = "" + gpu_indices: list[int] = attrs.field(factory=list) + gpu_uuids: list[str] = attrs.field(factory=list) def __add__(self, other: StagePerfStats) -> StagePerfStats: - """Add two StagePerfStats.""" + """Add two StagePerfStats, summing scalars and custom metrics. + + Identity is per-worker, so it survives only when both operands share it; + a cross-worker sum clears identity + invocation_id rather than mis-attribute. + """ + same_worker = ( + self.actor_id == other.actor_id + and self.node_id == other.node_id + and self.physical_address == other.physical_address + ) return StagePerfStats( stage_name=self.stage_name, process_time=self.process_time + other.process_time, @@ -59,6 +104,16 @@ def __add__(self, other: StagePerfStats) -> StagePerfStats: key: self.custom_metrics.get(key, 0.0) + other.custom_metrics.get(key, 0.0) for key in set(self.custom_metrics.keys()) | set(other.custom_metrics.keys()) }, + # invocation_id identifies a single call -- a sum is not one call. + invocation_id="", + actor_id=self.actor_id if same_worker else "", + node_id=self.node_id if same_worker else "", + gpu_id=self.gpu_id if same_worker else "", + physical_address=self.physical_address if same_worker else "", + pod_ip=self.pod_ip if same_worker else "", + hostname=self.hostname if same_worker else "", + gpu_indices=list(self.gpu_indices) if same_worker else [], + gpu_uuids=list(self.gpu_uuids) if same_worker else [], ) def __radd__(self, other: int | StagePerfStats) -> StagePerfStats: @@ -77,9 +132,29 @@ def reset(self) -> None: self.input_data_size_mb = 0.0 self.num_items_processed = 0 self.custom_metrics = {} + self.invocation_id = "" + self.actor_id = "" + self.node_id = "" + self.gpu_id = "" + self.physical_address = "" + self.pod_ip = "" + self.hostname = "" + self.gpu_indices = [] + self.gpu_uuids = [] def to_dict(self) -> dict[str, float | int]: - """Convert the stats to a dictionary.""" + """Convert to the stable main-branch public dictionary schema.""" + return { + "stage_name": self.stage_name, + "process_time": self.process_time, + "actor_idle_time": self.actor_idle_time, + "input_data_size_mb": self.input_data_size_mb, + "num_items_processed": self.num_items_processed, + "custom_metrics": dict(self.custom_metrics), + } + + def to_extended_dict(self) -> dict[str, Any]: + """Convert to the complete observability schema, including identity.""" return attrs.asdict(self) def items(self) -> list[tuple[str, float | int]]: @@ -88,7 +163,10 @@ def items(self) -> list[tuple[str, float | int]]: """ res = self.to_dict() res.pop("stage_name", None) - # Extract and drop the raw custom_metrics dict from the flattened output + # Identity fields are string metadata; downstream collectors call float() + # on every yielded value, so they MUST be dropped here. + for identity_field in _IDENTITY_FIELDS: + res.pop(identity_field, None) custom_metrics = res.pop("custom_metrics", {}) # Flatten custom_metrics with a stable prefix for key, value in custom_metrics.items(): @@ -97,9 +175,7 @@ def items(self) -> list[tuple[str, float | int]]: class StageTimer: - """Tracker for stage performance stats. - Tracks processing time and other metrics at a per process_data call level. - """ + """Tracks processing time and other metrics per process_data call.""" def __init__(self, stage: ProcessingStage) -> None: """Initialize the stage timer. @@ -123,7 +199,6 @@ def _reset(self) -> None: def reinit(self, stage_input_size: int = 1) -> None: """Reinitialize the stage timer. Args: - stage: The stage to reinitialize the timer for. stage_input_size: The size of the stage input. """ self._reset() diff --git a/nemo_curator/utils/pipeline_hardware_sampler.py b/nemo_curator/utils/pipeline_hardware_sampler.py new file mode 100644 index 0000000000..fd66542207 --- /dev/null +++ b/nemo_curator/utils/pipeline_hardware_sampler.py @@ -0,0 +1,185 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: C901, PLR0912 + +"""Run-level, observational hardware sampling helpers. + +Stage-level GPU samples answer "was this actor busy during this invocation?" +This module samples visible GPUs over the whole pipeline run so benchmark +summaries can also answer "were the requested devices generally occupied?". +It is intentionally fail-open and does not influence placement or autoscaling. +""" + +from __future__ import annotations + +import contextlib +import time +from typing import Any + +from loguru import logger + +from nemo_curator.utils.gpu_sampler import GpuUtilSampler + + +class _PipelineHardwareSamplerActor: + def __init__(self, interval_s: float = 0.5) -> None: + import ray + + self._node_id = str(ray.get_runtime_context().get_node_id()) + self._started_at = time.time() + self._sampler = GpuUtilSampler(interval_s=interval_s, sample_all_visible=True) + self._sampler.start() + + def node_id(self) -> str: + return self._node_id + + def stop(self) -> dict[str, float]: + stopped_at = time.time() + stats = self._sampler.window_stats(self._started_at, stopped_at) + diagnostics = self._sampler.diagnostics() + self._sampler.stop() + metrics: dict[str, float] = { + "pipeline_hardware_wall_time_s": stopped_at - self._started_at, + "pipeline_hardware_sampler_node_count": 1.0, + "pipeline_hardware_sampler_active_node_count": float(diagnostics.get("gpu_sampler_active", 0.0) > 0), + "pipeline_hardware_gpu_device_count": float(len(stats)), + "pipeline_hardware_gpu_sampler_error_count": diagnostics.get("gpu_sampler_error_count", 0.0), + } + util_sum = 0.0 + mem_sum = 0.0 + for gpu_uuid, gpu_stats in sorted(stats.items()): + safe_key = f"{self._node_id[:8]}_{gpu_uuid[:12]}" + util = float(gpu_stats.get("gpu_util_pct", 0.0)) + mem = float(gpu_stats.get("gpu_mem_used_pct", 0.0)) + metrics[f"pipeline_hardware_gpu_util_pct_{safe_key}"] = util + metrics[f"pipeline_hardware_gpu_mem_used_pct_{safe_key}"] = mem + util_sum += util + mem_sum += mem + if stats: + metrics["pipeline_hardware_gpu_util_pct_mean_all_sampled"] = util_sum / len(stats) + metrics["pipeline_hardware_gpu_mem_used_pct_mean_all_sampled"] = mem_sum / len(stats) + return metrics + + +def start_pipeline_hardware_samplers(*, interval_s: float = 0.5, startup_timeout_s: float = 5.0) -> list[Any]: + """Start one sampler actor per live Ray node, best effort.""" + + import ray + + remote_cls = ray.remote(num_cpus=0)(_PipelineHardwareSamplerActor) + pending: dict[Any, Any] = {} + for node in ray.nodes(): + if not node.get("Alive"): + continue + node_id = str(node.get("NodeID", "")) + if not node_id: + continue + resource_key = f"node:{node_id}" + if resource_key not in node.get("Resources", {}): + logger.debug("Skipping pipeline hardware sampler on node {} without resource {}", node_id, resource_key) + continue + resources = {resource_key: 0.001} + try: + actor = remote_cls.options(resources=resources).remote(interval_s) + pending[actor.node_id.remote()] = actor + except Exception as exc: # noqa: BLE001 + logger.debug("Failed to start pipeline hardware sampler on node {}: {}", node_id, exc) + if not pending: + return [] + + ready_refs, pending_refs = ray.wait(list(pending), num_returns=len(pending), timeout=max(0.0, startup_timeout_s)) + actors: list[Any] = [] + for ref in ready_refs: + actor = pending[ref] + try: + ray.get(ref) + except Exception as exc: # noqa: BLE001 + logger.debug("Pipeline hardware sampler actor failed during startup: {}", exc) + with contextlib.suppress(Exception): + ray.kill(actor, no_restart=True) + continue + actors.append(actor) + for ref in pending_refs: + actor = pending[ref] + logger.debug("Skipping pipeline hardware sampler actor that did not start within {}s", startup_timeout_s) + with contextlib.suppress(Exception): + ray.kill(actor, no_restart=True) + return actors + + +def stop_pipeline_hardware_samplers(actors: list[Any], *, stop_timeout_s: float = 10.0) -> dict[str, float]: + """Stop sampler actors and aggregate scalar metrics.""" + + if not actors: + return { + "pipeline_hardware_sampler_node_count": 0.0, + "pipeline_hardware_sampler_active_node_count": 0.0, + "pipeline_hardware_gpu_device_count": 0.0, + } + + import ray + + metrics: dict[str, float] = {} + pending: dict[Any, Any] = {} + for actor in actors: + try: + pending[actor.stop.remote()] = actor + except Exception as exc: # noqa: BLE001 + logger.debug("Failed to request pipeline hardware sampler stop: {}", exc) + if not pending: + return metrics + ready_refs, pending_refs = ray.wait(list(pending), num_returns=len(pending), timeout=max(0.0, stop_timeout_s)) + for ref in pending_refs: + logger.debug("Killing pipeline hardware sampler actor that did not stop within {}s", stop_timeout_s) + with contextlib.suppress(Exception): + ray.kill(pending[ref], no_restart=True) + for ref in ready_refs: + try: + result = ray.get(ref) + except Exception as exc: # noqa: BLE001 + logger.debug("Pipeline hardware sampler stop failed: {}", exc) + continue + for key, value in result.items(): + if key == "pipeline_hardware_wall_time_s": + metrics[key] = max(metrics.get(key, 0.0), float(value)) + continue + if key.startswith(("pipeline_hardware_gpu_util_pct_", "pipeline_hardware_gpu_mem_used_pct_")): + metrics[key] = float(value) + continue + metrics[key] = metrics.get(key, 0.0) + float(value) + + device_count = metrics.get("pipeline_hardware_gpu_device_count", 0.0) + if device_count: + # Node means were summed above; normalize them to a run-wide sampled-device mean. + util_keys = [ + key + for key in metrics + if key.startswith("pipeline_hardware_gpu_util_pct_") + and key != "pipeline_hardware_gpu_util_pct_mean_all_sampled" + ] + mem_keys = [ + key + for key in metrics + if key.startswith("pipeline_hardware_gpu_mem_used_pct_") + and key != "pipeline_hardware_gpu_mem_used_pct_mean_all_sampled" + ] + if util_keys: + metrics["pipeline_hardware_gpu_util_pct_mean_all_sampled"] = sum(metrics[key] for key in util_keys) / len( + util_keys + ) + if mem_keys: + metrics["pipeline_hardware_gpu_mem_used_pct_mean_all_sampled"] = sum( + metrics[key] for key in mem_keys + ) / len(mem_keys) + return metrics diff --git a/pyproject.toml b/pyproject.toml index cedc8c48bf..60f47c2006 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,14 @@ audio_cuda12 = [ "torchcodec; platform_machine == 'x86_64' and platform_system != 'Darwin'", ] +# Qwen-Omni/Qwen-ASR is an opt-in CUDA stack. Keeping it separate preserves +# the existing audio_cuda12 environment for every other audio pipeline. +audio_qwen = [ + "nemo_curator[audio_cuda12]", + "nemo_curator[vllm]", + "qwen-omni-utils", +] + image_cpu = [ "Pillow", "torchvision" diff --git a/tests/backends/ray_data/test_utils.py b/tests/backends/ray_data/test_utils.py index 8dda7625b6..1fc184705a 100644 --- a/tests/backends/ray_data/test_utils.py +++ b/tests/backends/ray_data/test_utils.py @@ -12,38 +12,114 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable from unittest.mock import MagicMock, Mock, patch +import numpy as np import pytest from ray.data import ActorPoolStrategy +from nemo_curator.backends.ray_data.adapter import RayDataStageAdapter +from nemo_curator.backends.ray_data.executor import RayDataExecutor from nemo_curator.backends.ray_data.utils import ( + coerce_batch_tasks, get_actor_compute_strategy_for_stage, ) from nemo_curator.backends.utils import RayStageSpecKeys, get_available_cpu_gpu_resources +from nemo_curator.stages.audio.inference.batch_policy import BatchPolicy +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask from tests.backends.test_utils import reset_head_node_cache # noqa: F401 +class _PreplannedEchoStage(ProcessingStage[AudioTask, AudioTask]): + name = "preplanned_echo" + resources = Resources(cpus=1.0) + batch_size = 99 + + def process(self, task: AudioTask) -> AudioTask: + return task + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + for task in tasks: + task.data["seen_batch_size"] = len(tasks) + return tasks + + +class _CentralizedPlanningStage(ProcessingStage[AudioTask, AudioTask]): + name = "centralized_planning" + resources = Resources(cpus=1.0) + batch_size = 2 + + def __init__(self) -> None: + self.batch_policy = BatchPolicy( + buckets_sec=[0, 30, 1200], + max_items_per_batch_by_bucket=[2, 1, 1], + max_audio_sec_per_batch=None, + ) + + def process(self, task: AudioTask) -> AudioTask: + return task + + def build_prebucketed_tasks(self, tasks: list[AudioTask]) -> list[AudioTask]: + return list(tasks) + + def scheduler_task_cost(self, task: AudioTask) -> float: + return float(task.data["duration"]) + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + for task in tasks: + task.data["processed_batch_size"] = len(tasks) + return tasks + + def assemble_prebucketed_task_results( + self, + tasks: list[AudioTask], + _processed_tasks: list[AudioTask], + ) -> list[AudioTask]: + return list(tasks) + + +class _FakeDataset: + def __init__(self, sample_items: list[object] | None = None) -> None: + self.repartition_calls: list[tuple[tuple, dict]] = [] + self.map_batches_calls: list[tuple[object, dict]] = [] + self.sample_output: dict | None = None + self.sample_items = sample_items + + def repartition(self, *args, **kwargs) -> "_FakeDataset": + self.repartition_calls.append((args, kwargs)) + return self + + def map_batches(self, fn: Callable[[dict[str, object]], dict[str, object]], **kwargs) -> "_FakeDataset": + self.map_batches_calls.append((fn, kwargs)) + sample_items = self.sample_items + if sample_items is None: + first = AudioTask(data={"duration": 5.0}) + second = AudioTask(data={"duration": 600.0}) + sample_items = [first, second] + self.sample_output = fn({"item": sample_items}) + return self + + class TestGetAvailableCpuGpuResources: # TODO: Move this to tests/backends/test_utils.py """Test class for utility functions in ray_data backend.""" def test_get_available_cpu_gpu_resources_conftest(self, shared_ray_client: None): """Test get_available_cpu_gpu_resources function.""" - # Test with Ray resources from conftest.py cpus, gpus = get_available_cpu_gpu_resources() assert cpus == 11 - # GPU count depends on whether GPU tests are running in this session - # Can be 0 (CPU-only) or 2 (GPU-enabled) depending on test selection - assert gpus in [0.0, 2.0] + # GPU count depends on local hardware and whether GPU tests are selected. + assert gpus in [0.0, 1.0, 2.0] @pytest.mark.usefixtures("reset_head_node_cache") def test_get_resources_with_ignore_head_node( self, shared_ray_client: None, ): - """Test get_available_cpu_gpu_resources with ignore_head_node=True to skip head node. - Since this test is run with the head node, the resources should be 0.""" + """ignore_head_node=True skips the head node; running on the head node, resources are 0.""" cpus_without_head, gpus_without_head = get_available_cpu_gpu_resources(ignore_head_node=True) assert cpus_without_head == 0 assert gpus_without_head == 0 @@ -98,7 +174,7 @@ def test_actor_compute_strategy( ray_stage_spec: dict[str, object], expected: ActorPoolStrategy, expected_warning: str | None, - ): + ) -> None: mock_stage = Mock(num_workers=lambda: num_workers, ray_stage_spec=lambda: ray_stage_spec) mock_stage.name = "stage" @@ -111,7 +187,7 @@ def test_actor_compute_strategy( mock_warning.assert_called_once() assert expected_warning in mock_warning.call_args.args[0] - def test_actor_compute_strategy_rejects_invalid_sizing(self): + def test_actor_compute_strategy_rejects_invalid_sizing(self) -> None: mock_stage = Mock( num_workers=lambda: None, ray_stage_spec=lambda: { @@ -124,3 +200,90 @@ def test_actor_compute_strategy_rejects_invalid_sizing(self): with pytest.raises(ValueError, match="Invalid Ray Data actor pool sizing for stage stage"): get_actor_compute_strategy_for_stage(mock_stage) + + +class TestCoerceBatchTasks: + def test_coerce_batch_tasks_from_numpy_object_array(self) -> None: + sentinel = object() + batch = np.array([sentinel], dtype=object) + assert coerce_batch_tasks(batch) == [sentinel] + + def test_coerce_batch_tasks_empty(self) -> None: + assert coerce_batch_tasks([]) == [] + assert coerce_batch_tasks(np.array([], dtype=object)) == [] + assert coerce_batch_tasks(None) == [] + + +def test_ray_data_adapter_passes_backend_batch_to_stage_process_batch() -> None: + dataset = _FakeDataset() + adapter = RayDataStageAdapter(_CentralizedPlanningStage()) + + out = adapter.process_dataset(dataset) + + assert out is dataset + assert dataset.repartition_calls == [] + assert len(dataset.map_batches_calls) == 1 + assert dataset.map_batches_calls[0][1]["batch_size"] == 2 + assert dataset.sample_output is not None + processed_durations = [task.data["duration"] for task in dataset.sample_output["item"]] + processed_batch_sizes = [task.data["processed_batch_size"] for task in dataset.sample_output["item"]] + assert processed_durations == [5.0, 600.0] + assert processed_batch_sizes == [2, 2] + + +def test_ray_data_executor_keeps_centralized_stage_in_ray_data( + monkeypatch: pytest.MonkeyPatch, +) -> None: + executor = RayDataExecutor(ignore_head_node=True) + stage = _CentralizedPlanningStage() + input_dataset = object() + output_dataset = object() + calls: dict[str, object] = {} + + class FakeRayDataStageAdapter: + def __init__(self, stage_arg: ProcessingStage) -> None: + calls["stage"] = stage_arg + + def process_dataset(self, dataset_arg: object) -> object: + calls["dataset"] = dataset_arg + return output_dataset + + monkeypatch.setattr( + executor, + "_dataset_to_tasks", + Mock(side_effect=AssertionError("centralized stages should not materialize Ray Data datasets")), + ) + monkeypatch.setattr( + executor, + "_tasks_to_dataset", + Mock(side_effect=AssertionError("centralized stages should not rebuild Ray Data datasets")), + ) + monkeypatch.setattr("nemo_curator.backends.ray_data.executor.RayDataStageAdapter", FakeRayDataStageAdapter) + + out = executor._process_stage_dataset(stage, input_dataset) + + assert out is output_dataset + assert calls == {"stage": stage, "dataset": input_dataset} + + +def test_ray_data_executor_keeps_noncentral_stage_in_ray_data(monkeypatch: pytest.MonkeyPatch) -> None: + executor = RayDataExecutor(ignore_head_node=True) + stage = _PreplannedEchoStage() + input_dataset = object() + output_dataset = object() + calls: dict[str, object] = {} + + class FakeRayDataStageAdapter: + def __init__(self, stage_arg: ProcessingStage) -> None: + calls["stage"] = stage_arg + + def process_dataset(self, dataset_arg: object) -> object: + calls["dataset"] = dataset_arg + return output_dataset + + monkeypatch.setattr("nemo_curator.backends.ray_data.executor.RayDataStageAdapter", FakeRayDataStageAdapter) + + out = executor._process_stage_dataset(stage, input_dataset) + + assert out is output_dataset + assert calls == {"stage": stage, "dataset": input_dataset} diff --git a/tests/backends/test_task_id_postprocess.py b/tests/backends/test_task_id_postprocess.py index 905272e9fa..d685982a5b 100644 --- a/tests/backends/test_task_id_postprocess.py +++ b/tests/backends/test_task_id_postprocess.py @@ -18,14 +18,21 @@ end-to-end against real backends in tests/backends/test_integration.py (``test_task_ids``). This file keeps only the cases that are awkward or impossible to trigger through a real pipeline: filter-``None`` positional -alignment, the ambiguous-cardinality ``"r"``-uuid fallback, in-place -re-derivation, and source content-id vs. positional-index selection.""" +alignment, the ambiguous-cardinality ``"r"``-uuid fallback, preservation of + framework overwrite semantics, and source content-id selection.""" from dataclasses import dataclass from nemo_curator.backends.base import BaseStageAdapter +from nemo_curator.pipeline.payload_refs import PayloadRef from nemo_curator.stages.base import ProcessingStage -from nemo_curator.tasks import EmptyTask, FileGroupTask, Task +from nemo_curator.tasks import AudioTask, EmptyTask, FileGroupTask, Task +from nemo_curator.tasks.task_terminals import ( + TERMINAL_COUNT_KEY, + TERMINAL_DROPPED_KEY, + TERMINAL_GROUP_ID_KEY, + TERMINAL_INDEX_KEY, +) @dataclass @@ -42,6 +49,27 @@ def process(self, task: Task) -> Task: return task +@dataclass +class _DropSegmentRowStage(ProcessingStage[AudioTask, AudioTask]): + name: str = "drop_segment_row" + skip_me_key: str = "_skip_me" + _curator_preserves_terminal_tasks: bool = True + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def terminal_tombstone_drop_data_keys(self) -> tuple[str, ...]: + return ("large_payload",) + + def process(self, task: AudioTask) -> AudioTask | None: + if task.data.get("drop"): + return None + return task + + @dataclass class _SimpleTask(Task[list[int]]): @property @@ -77,8 +105,6 @@ def test_filter_stage_keeps_positional_alignment(self) -> None: assert c2.task_id == "0_2_0" # child of p2, not p1 def test_in_place_return_is_reassigned(self) -> None: - # A 1:1 stage that returns its input unchanged still gets a fresh - # segment appended (ids are re-derived at each stage boundary). t = _task("0_5") out = _assign([t], [t]) assert out == [t] @@ -97,11 +123,86 @@ def test_ambiguous_batch_fanout_falls_back_to_uuid(self) -> None: assert all(t.task_id.startswith("r") for t in out) assert all("_" not in t.task_id for t in out) + def test_dropped_segment_row_is_preserved_as_tombstone(self) -> None: + keep = AudioTask( + dataset_name="d", + data={ + "_curator_segment_parent_id": "manifest:0:0", + "_curator_segment_idx": 0, + "_curator_segment_count": 2, + }, + ) + drop = AudioTask( + dataset_name="d", + data={ + "drop": True, + "large_payload": object(), + "payload_ref": PayloadRef( + payload_id="p", + owner_node_id="node", + store_actor_name="store", + admission_actor_name="admission", + amount_bytes=1, + sample_rate=16000, + num_samples=1, + ), + "_curator_segment_parent_id": "manifest:0:0", + "_curator_segment_idx": 1, + "_curator_segment_count": 2, + }, + ) + keep.task_id = "0_0" + drop.task_id = "0_1" + + out = BaseStageAdapter(_DropSegmentRowStage()).process_batch([keep, drop]) + + assert len(out) == 2 + assert out[0] is keep + tombstone = out[1] + assert tombstone.data["_skip_me"] == "dropped_segment_row" + assert tombstone.data["_curator_segment_dropped"] is True + assert tombstone.data["_curator_segment_idx"] == 1 + assert "large_payload" not in tombstone.data + assert "payload_ref" not in tombstone.data + assert tombstone.task_id == "0_1_0" + + def test_dropped_generic_terminal_row_is_preserved_as_tombstone(self) -> None: + keep = AudioTask( + dataset_name="d", + data={ + TERMINAL_GROUP_ID_KEY: "parent-0", + TERMINAL_INDEX_KEY: 0, + TERMINAL_COUNT_KEY: 2, + }, + ) + drop = AudioTask( + dataset_name="d", + data={ + "drop": True, + "large_payload": object(), + TERMINAL_GROUP_ID_KEY: "parent-0", + TERMINAL_INDEX_KEY: 1, + TERMINAL_COUNT_KEY: 2, + }, + ) + keep.task_id = "0_0" + drop.task_id = "0_1" + + out = BaseStageAdapter(_DropSegmentRowStage()).process_batch([keep, drop]) + + assert len(out) == 2 + tombstone = out[1] + assert tombstone.data["_skip_me"] == "dropped_segment_row" + assert tombstone.data[TERMINAL_DROPPED_KEY] is True + assert "_curator_segment_dropped" not in tombstone.data + assert "large_payload" not in tombstone.data + assert tombstone.task_id == "0_1_0" + class TestSourceStage: def test_uses_content_id_rooted_at_input(self) -> None: - # FileGroupTask.get_deterministic_id() hashes its files; the source - # output is rooted at the EmptyTask input id "0" → "0_". + # FileGroupTask.get_deterministic_id() hashes its files; output with no + # content ids are rooted at the framework EmptyTask id. empty = EmptyTask(dataset_name="empty", data=None) a = FileGroupTask(dataset_name="d", data=["a.parquet"]) b = FileGroupTask(dataset_name="d", data=["b.parquet"]) diff --git a/tests/backends/test_utils.py b/tests/backends/test_utils.py index 39ef35d1ad..8bde59fc32 100644 --- a/tests/backends/test_utils.py +++ b/tests/backends/test_utils.py @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ruff: noqa: ANN202 +import sys +import types import uuid from collections.abc import Iterator from contextlib import contextmanager +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING @@ -61,11 +65,9 @@ def test_merge_nested_dicts(self): result = merge_executor_configs(base, override) - # Check that nested dicts are merged assert result["runtime_env"]["env_vars"]["A"] == "1" assert result["runtime_env"]["env_vars"]["B"] == "3" assert result["runtime_env"]["env_vars"]["C"] == "4" - # Check that other keys are preserved assert result["runtime_env"]["pip"] == ["package1"] assert result["runtime_env"]["working_dir"] == "." assert result["other_config"] == "value1" @@ -126,11 +128,9 @@ def setup_on_node( stage1 = MockStage1() stage2 = MockStage1().with_(name="mock_stage_2", resources=Resources(cpus=0.5, gpus=0.0)) - # Test execute_setup_on_node([stage1, stage2]) - # Check the files written to the temp directory - # Verify that NodeInfo and WorkerMetadata were passed correctly + # Verify NodeInfo / WorkerMetadata were passed via the per-call files. for stage_name in ["mock_stage_1", "mock_stage_2"]: stage_files = list(tmp_path.glob(f"{stage_name}_*.txt")) assert len(stage_files) == len(ray.nodes()), ( @@ -149,7 +149,7 @@ def setup_on_node( f"Expected node IDs to be the same as the Ray nodes, got {node_ids}" ) - # Check that there are exactly two log records that start with "Executing setup on node" and end with "for 2 stages" + # Log records starting with "Executing setup on node" and ending with "for 2 stages". matching_logs = [ record.message for record in caplog.records @@ -160,6 +160,41 @@ def setup_on_node( f"Expected {len(ray.nodes())} logs for setup on node for 2 stages, got {len(matching_logs)}: {matching_logs}" ) + def test_execute_setup_on_node_uses_stage_setup_resources(self, monkeypatch: pytest.MonkeyPatch) -> None: + """setup_on_node should use the stage's setup resource contract.""" + + class GPUStage(ProcessingStage): + name = "gpu_stage" + resources = Resources(cpus=8.0, gpus=4.0) + + def process(self, task: "Task") -> "Task": + return task + + class CPUOnlySetupStage(GPUStage): + name = "cpu_only_setup_stage" + + def setup_on_node_resources(self) -> Resources: + return Resources(cpus=1.0, gpus=0.0) + + submitted_kwargs = [] + + def fake_nodes(): + return [{"Alive": True, "NodeID": "node-1"}] + + def fake_submit_on_each_node(_remote_fn, _stage, **kwargs): # noqa: ANN001 + submitted_kwargs.append(kwargs) + return [] + + monkeypatch.setattr(ray, "nodes", fake_nodes) + monkeypatch.setattr("nemo_curator.backends.utils.submit_on_each_node", fake_submit_on_each_node) + + execute_setup_on_node([GPUStage(), CPUOnlySetupStage()]) + + assert submitted_kwargs == [ + {"ignore_head_node": False, "num_cpus": 8.0, "num_gpus": 4.0}, + {"ignore_head_node": False, "num_cpus": 1.0, "num_gpus": 0.0}, + ] + def test_execute_setup_on_node_ignore_head_node( self, shared_ray_client: None, @@ -189,7 +224,6 @@ def setup_on_node( stage = MockStage1() - # Test with ignore_head_node=True execute_setup_on_node([stage], ignore_head_node=True) # Verify the cache variable is set directly (not using the lazy function) @@ -302,3 +336,169 @@ def test_capacity_check(self, available_gpus: float, needed: int, should_raise: check_total_gpu_capacity(needed) else: check_total_gpu_capacity(needed) + + +@dataclass +class _FakeGpuAllocation: + """Mirror of cosmos_xenna ``GpuAllocation`` (only ``index`` is read).""" + + index: int + used_fraction: float = 1.0 + + +@dataclass +class _FakeWorkerResources: + """Mirror of cosmos_xenna ``WorkerResources`` (only ``gpus`` is read).""" + + node: str + gpus: list[_FakeGpuAllocation] + + +class _FakeRayContext: + def __init__( + self, + *, + node_id: str = "nodeabcdef123", + actor_id: str = "", + worker_id: str = "worker123456", + ) -> None: + self._node_id = node_id + self._actor_id = actor_id + self._worker_id = worker_id + + def get_node_id(self) -> str: + return self._node_id + + def get_actor_id(self) -> str: + return self._actor_id + + def get_worker_id(self) -> str: + return self._worker_id + + +class TestBackendPerfIdentity: + """Backend-specific GPU label resolvers (no cross-backend fallbacks).""" + + def test_xenna_allocation_index(self) -> None: + from nemo_curator.backends.perf_identity import build_xenna_perf_identity + + alloc = _FakeWorkerResources(node="ray-node-abc", gpus=[_FakeGpuAllocation(index=3)]) + with pytest.MonkeyPatch.context() as mp: + mp.setenv("CUDA_VISIBLE_DEVICES", "7") + identity = build_xenna_perf_identity( + "QwenOmni_inference", + worker_id="worker-abc", + node_id="node-0", + allocation=alloc, + requires_gpu=True, + ) + assert identity.gpu_id == "node-0:3" + assert identity.node_id == "node-0" + assert identity.actor_id == "QwenOmni_inference:actor-worker-a" + assert identity.gpu_indices == (3,) + + def test_xenna_physical_address_uses_pod_ip_and_all_allocation_gpus(self) -> None: + from nemo_curator.backends.perf_identity import build_xenna_perf_identity + + alloc = _FakeWorkerResources( + node="ray-node-abc", + gpus=[_FakeGpuAllocation(index=0), _FakeGpuAllocation(index=1)], + ) + with pytest.MonkeyPatch.context() as mp: + mp.setenv("POD_IP", "10.244.181.136") + identity = build_xenna_perf_identity( + "QwenOmni_inference", + worker_id="worker-abc", + node_id="node-0", + allocation=alloc, + requires_gpu=True, + ) + assert identity.gpu_id == "node-0:0" + assert identity.pod_ip == "10.244.181.136" + assert identity.physical_address == "10.244.181.136:0,1" + assert identity.gpu_indices == (0, 1) + + def test_xenna_cpu_stage_with_empty_allocation_is_blank_gpu(self) -> None: + from nemo_curator.backends.perf_identity import build_xenna_perf_identity + + alloc = _FakeWorkerResources(node="ray-node-abc", gpus=[]) + with pytest.MonkeyPatch.context() as mp: + mp.delenv("CUDA_VISIBLE_DEVICES", raising=False) + identity = build_xenna_perf_identity( + "reader", + worker_id="w1", + node_id="node-0", + allocation=alloc, + requires_gpu=False, + ) + assert identity.gpu_id == "" + + def test_xenna_bare_gpu_index_when_node_unknown(self) -> None: + from nemo_curator.backends.perf_identity import build_xenna_perf_identity + + alloc = _FakeWorkerResources(node="", gpus=[_FakeGpuAllocation(index=2)]) + identity = build_xenna_perf_identity( + "infer", + worker_id="w1", + node_id="", + allocation=alloc, + requires_gpu=True, + ) + assert identity.gpu_id == "2" + + def test_ray_does_not_parse_cuda_visible_devices(self) -> None: + from nemo_curator.backends.perf_identity import build_ray_perf_identity + + with pytest.MonkeyPatch.context() as mp: + mp.setenv("CUDA_VISIBLE_DEVICES", "5,6") + identity = build_ray_perf_identity("infer", requires_gpu=True) + # Driver-side test has no Ray actor GPU assignment — must stay blank, not CVD. + assert identity.gpu_id == "" + + def test_ray_runtime_context_resolves_gpu_without_worker_env(self) -> None: + from nemo_curator.backends.perf_identity import build_ray_perf_identity + + fake_ray = types.SimpleNamespace( + is_initialized=lambda: True, + get_runtime_context=lambda: _FakeRayContext(worker_id="workerabcdef999"), + get_gpu_ids=lambda: [0, 1], + util=types.SimpleNamespace(get_node_ip_address=lambda: "10.0.0.5"), + ) + with pytest.MonkeyPatch.context() as mp: + mp.delenv("RAY_WORKER_ID", raising=False) + mp.setitem(sys.modules, "ray", fake_ray) + identity = build_ray_perf_identity("infer", requires_gpu=True) + + assert identity.actor_id == "infer:actor-workerab" + assert identity.node_id == "node-nodeabcd" + assert identity.gpu_id == "node-nodeabcd:0" + assert identity.physical_address == "10.0.0.5:0,1" + assert identity.gpu_indices == (0, 1) + + def test_ray_runtime_context_maps_uuid_gpu_assignments_with_nvml(self) -> None: + from nemo_curator.backends.perf_identity import build_ray_perf_identity + + fake_ray = types.SimpleNamespace( + is_initialized=lambda: True, + get_runtime_context=lambda: _FakeRayContext(actor_id="actorabcdef999"), + get_gpu_ids=lambda: ["GPU-aaaa", "GPU-bbbb"], + util=types.SimpleNamespace(get_node_ip_address=lambda: "10.0.0.5"), + ) + fake_pynvml = types.SimpleNamespace( + nvmlInit=lambda: None, + nvmlShutdown=lambda: None, + nvmlDeviceGetCount=lambda: 3, + nvmlDeviceGetHandleByIndex=lambda index: index, + nvmlDeviceGetUUID=lambda handle: ["GPU-zzzz", "GPU-aaaa", b"GPU-bbbb"][handle], + ) + with pytest.MonkeyPatch.context() as mp: + mp.delenv("RAY_WORKER_ID", raising=False) + mp.setitem(sys.modules, "ray", fake_ray) + mp.setitem(sys.modules, "pynvml", fake_pynvml) + identity = build_ray_perf_identity("infer", requires_gpu=True) + + assert identity.actor_id == "infer:actor-actorabc" + assert identity.gpu_id == "node-nodeabcd:1" + assert identity.physical_address == "10.0.0.5:1,2" + assert identity.gpu_indices == (1, 2) + assert identity.gpu_uuids == ("GPU-aaaa", "GPU-bbbb") diff --git a/tests/backends/xenna/__init__.py b/tests/backends/xenna/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/backends/xenna/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/backends/xenna/test_executor.py b/tests/backends/xenna/test_executor.py new file mode 100644 index 0000000000..96c9ee6450 --- /dev/null +++ b/tests/backends/xenna/test_executor.py @@ -0,0 +1,164 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import pytest +from cosmos_xenna.utils.verbosity import VerbosityLevel + +from nemo_curator.backends.xenna.executor import XennaExecutor +from nemo_curator.stages.audio.common import ManifestWriterStage +from nemo_curator.stages.audio.inference.asr.stage import ASRStage +from nemo_curator.stages.audio.inference.batch_policy import BatchPolicy +from nemo_curator.stages.audio.io.audio_file_reader import AudioFileReaderStage +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask, Task + + +class _PassthroughStage(ProcessingStage[AudioTask, AudioTask]): + name = "passthrough" + resources = Resources(cpus=1.0) + batch_size = 2 + + def process(self, task: AudioTask) -> AudioTask: + return task + + +class _CentralizedStage(_PassthroughStage): + name = "centralized" + + def __init__(self) -> None: + self.batch_policy = BatchPolicy( + buckets_sec=[0, 30, 1200], + max_items_per_batch_by_bucket=[2, 1, 1], + max_audio_sec_per_batch=None, + ) + + def build_prebucketed_tasks(self, tasks: list[AudioTask]) -> list[AudioTask]: + return list(tasks) + + def scheduler_task_cost(self, task: AudioTask) -> float: + return float(task.data.get("duration", 0.0)) + + def assemble_prebucketed_task_results( + self, + _tasks: list[AudioTask], + processed_tasks: list[AudioTask], + ) -> list[AudioTask]: + return processed_tasks + + +class _WorkerSizedStage(_PassthroughStage): + name = "worker_sized" + + def __init__(self, workers: int | None = None, stage_spec: dict[str, Any] | None = None) -> None: + self._workers = workers + self._stage_spec = stage_spec or {} + + def num_workers(self) -> int | None: + return self._workers + + def xenna_stage_spec(self) -> dict[str, Any]: + return dict(self._stage_spec) + + +def test_xenna_executor_keeps_centralized_stage_inside_one_pipeline(monkeypatch) -> None: # noqa: ANN001 + executor = XennaExecutor() + stages: list[ProcessingStage[Any, Any]] = [ + _PassthroughStage(), + _CentralizedStage(), + _PassthroughStage(), + ] + initial_tasks = [AudioTask(data={"duration": 5.0})] + calls: list[tuple[list[ProcessingStage[Any, Any]], list[Task]]] = [] + + def fake_run_xenna_pipeline( + stages_arg: list[ProcessingStage[Any, Any]], + initial_tasks_arg: list[Task], + ) -> list[Task]: + calls.append((stages_arg, initial_tasks_arg)) + return initial_tasks_arg + + monkeypatch.setattr(executor, "_run_xenna_pipeline", fake_run_xenna_pipeline) + + out = executor.execute(stages, initial_tasks) + + assert out == initial_tasks + assert calls == [(stages, initial_tasks)] + + +def test_xenna_verbosity_none_uses_default() -> None: + executor = XennaExecutor(config={"actor_pool_verbosity_level": None}) + + assert executor._get_verbosity_config("actor_pool_verbosity_level") is VerbosityLevel.INFO + + +def test_xenna_verbosity_bad_string_has_helpful_error() -> None: + executor = XennaExecutor(config={"actor_pool_verbosity_level": "loud"}) + + with pytest.raises(ValueError, match="Invalid Xenna verbosity config actor_pool_verbosity_level='loud'"): + executor._get_verbosity_config("actor_pool_verbosity_level") + + +def test_xenna_stage_spec_falls_back_to_stage_num_workers() -> None: + stage_spec = XennaExecutor()._build_stage_spec(_WorkerSizedStage(workers=3)) + + assert stage_spec.num_workers == 3 + assert stage_spec.num_workers_per_node is None + + +def test_real_audio_stages_use_main_worker_sizing_contract(tmp_path) -> None: # noqa: ANN001 + executor = XennaExecutor() + + asr_spec = executor._build_stage_spec( + ASRStage( + adapter_target="tests.fake.Adapter", + model_id="fake-model", + xenna_num_workers=2, + ) + ) + reader_spec = executor._build_stage_spec(AudioFileReaderStage(xenna_num_workers=3)) + writer_spec = executor._build_stage_spec(ManifestWriterStage(output_path=str(tmp_path / "out.jsonl"))) + + assert asr_spec.num_workers == 2 + assert asr_spec.num_workers_per_node is None + assert reader_spec.num_workers == 3 + assert reader_spec.num_workers_per_node is None + assert writer_spec.num_workers == 1 + assert writer_spec.num_workers_per_node is None + + +def test_xenna_stage_spec_num_workers_is_rejected() -> None: + with pytest.raises(ValueError, match=r"Use num_workers\(\) instead"): + XennaExecutor()._build_stage_spec(_WorkerSizedStage(stage_spec={"num_workers": 4})) + + +def test_xenna_num_workers_per_node_is_rejected_with_stage_num_workers() -> None: + with pytest.raises(ValueError, match=r"num_workers\(\).*num_workers_per_node"): + XennaExecutor()._build_stage_spec(_WorkerSizedStage(workers=3, stage_spec={"num_workers_per_node": 2})) + + +def test_xenna_num_workers_per_node_is_rejected_with_legacy_num_workers() -> None: + stage = _WorkerSizedStage(stage_spec={"num_workers": 4, "num_workers_per_node": 2}) + + with pytest.raises(ValueError, match=r"Use num_workers\(\) instead"): + XennaExecutor()._build_stage_spec(stage) + + +def test_xenna_rejects_conflicting_cluster_worker_counts() -> None: + stage = _WorkerSizedStage(workers=3, stage_spec={"num_workers": 4}) + + with pytest.raises(ValueError, match=r"Use num_workers\(\) instead"): + XennaExecutor()._build_stage_spec(stage) diff --git a/tests/models/asr/__init__.py b/tests/models/asr/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/asr/test_base.py b/tests/models/asr/test_base.py new file mode 100644 index 0000000000..459ab11e05 --- /dev/null +++ b/tests/models/asr/test_base.py @@ -0,0 +1,107 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the generic ASRStage<->ASRAdapter split contract via a non-Qwen fake adapter.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import patch + +import numpy as np + +from nemo_curator.models.asr.base import ASRAdapter, ASRResult +from nemo_curator.stages.audio.inference.asr import ASRStage +from nemo_curator.tasks import AudioTask + +_SR = 16000 + + +class _FakeASRAdapter: + """Minimal non-Qwen ``ASRAdapter`` that echoes each item's language as the transcription.""" + + def __init__(self, model_id: str, revision: str | None = None, **adapter_kwargs: object) -> None: + self.model_id = model_id + self.revision = revision + self.adapter_kwargs = adapter_kwargs + self.last_metrics: dict[str, float] = {} + self.setup_called = False + self.seen_items: list[dict[str, Any]] = [] + + @classmethod + def prefetch_weights(cls, _model_id: str, _revision: str | None = None) -> None: + return None + + def setup(self) -> None: + self.setup_called = True + + def teardown(self) -> None: + return None + + def transcribe_batch(self, items: list[dict[str, Any]]) -> list[ASRResult]: + self.seen_items = items + return [ASRResult(text=f"fake:{it.get('language')}", model_id=self.model_id) for it in items] + + +# ---------------------------------------------------------------------- +# ASRResult: canonical adapter-output shape +# ---------------------------------------------------------------------- + + +def test_asr_result_defaults() -> None: + """The shape every adapter returns and the stage reads must stay stable.""" + r = ASRResult(text="hello") + assert r.text == "hello" + assert r.secondary_text is None + assert r.skipped is False + assert r.model_id == "" + assert r.extras == {} + + +# ---------------------------------------------------------------------- +# Protocol conformance for an arbitrary (non-Qwen) adapter +# ---------------------------------------------------------------------- + + +def test_fake_adapter_conforms_to_asr_protocol() -> None: + """A minimal hand-written adapter satisfies ASRAdapter (requires @runtime_checkable).""" + adapter = _FakeASRAdapter(model_id="fake/model") + assert isinstance(adapter, ASRAdapter) + + +# ---------------------------------------------------------------------- +# Swappability: ASRStage drives ANY conforming adapter end-to-end +# ---------------------------------------------------------------------- + + +def test_asr_stage_drives_arbitrary_conforming_adapter() -> None: + """ASRStage resolves adapter_target, constructs+sets up the adapter, and delegates process_batch.""" + stage = ASRStage( + adapter_target="tests.models.asr.test_base._FakeASRAdapter", + model_id="fake/model", + pred_text_key="pred_text", + ) + + # Patch resolution so the dotted string need not be importable. + with patch("hydra.utils.get_class", return_value=_FakeASRAdapter): + stage.setup() + + assert isinstance(stage._adapter, _FakeASRAdapter) + assert stage._adapter.setup_called is True + + task = AudioTask(data={"waveform": np.zeros(_SR, dtype=np.float32), "sample_rate": _SR, "source_lang": "es"}) + results = stage.process_batch([task]) + + # Stage mapped "es" -> "Spanish" and packaged the result under the configured key. + assert results[0].data["pred_text"] == "fake:Spanish" diff --git a/tests/models/asr/test_package_lazy_import.py b/tests/models/asr/test_package_lazy_import.py new file mode 100644 index 0000000000..9e559f4d77 --- /dev/null +++ b/tests/models/asr/test_package_lazy_import.py @@ -0,0 +1,62 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests that ``nemo_curator.models.asr`` does not eagerly import GPU adapters.""" + +from __future__ import annotations + +import builtins +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pytest + + +def test_importing_asr_subpackage_does_not_load_qwen_omni(monkeypatch: pytest.MonkeyPatch) -> None: + """``import nemo_curator.models.asr`` must not pull in ``qwen_omni`` at init time.""" + original_import = builtins.__import__ + blocked: list[str] = [] + + def tracking_import( + name: str, + globals_: object | None = None, + locals_: object | None = None, + fromlist: tuple[str, ...] = (), + level: int = 0, + ) -> object: + if name.endswith("nemo_curator.models.asr.qwen_omni") or name == "nemo_curator.models.asr.qwen_omni": + blocked.append(name) + msg = f"blocked eager import of {name}" + raise ImportError(msg) + return original_import(name, globals_, locals_, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", tracking_import) + + for mod_name in list(sys.modules): + if mod_name in {"nemo_curator.models.asr", "nemo_curator.models.asr.base"}: + del sys.modules[mod_name] + + import nemo_curator.models.asr as asr_pkg + + assert blocked == [] + assert asr_pkg.ASRAdapter is not None + assert asr_pkg.ASRResult is not None + assert "QwenOmniASRAdapter" in asr_pkg._LAZY + + +def test_asr_subpackage_lazy_getattr_resolves_qwen_adapter() -> None: + from nemo_curator.models.asr import QwenOmniASRAdapter + + assert QwenOmniASRAdapter.__name__ == "QwenOmniASRAdapter" diff --git a/tests/models/asr/test_qwen_omni.py b/tests/models/asr/test_qwen_omni.py new file mode 100644 index 0000000000..8bf5c3f9e6 --- /dev/null +++ b/tests/models/asr/test_qwen_omni.py @@ -0,0 +1,353 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the concrete ``QwenOmniASRAdapter`` internals (no GPU / no real vLLM required).""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch + +from nemo_curator.models.asr.base import ASRAdapter +from nemo_curator.models.asr.qwen_omni import QwenOmniASRAdapter + +_SR = 16000 + + +# ---------------------------------------------------------------------- +# Protocol conformance (requires @runtime_checkable) +# ---------------------------------------------------------------------- + + +def test_qwen_adapter_conforms_to_asr_protocol() -> None: + """QwenOmniASRAdapter satisfies the ASRAdapter contract (requires @runtime_checkable).""" + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni") + assert isinstance(adapter, ASRAdapter) + + +# ---------------------------------------------------------------------- +# QwenOmniASRAdapter helpers (no GPU, no vLLM required) +# ---------------------------------------------------------------------- + + +def test_qwen_adapter_first_output_text_handles_empty_vllm_output() -> None: + assert QwenOmniASRAdapter._first_output_text(SimpleNamespace(outputs=[])) == "" + + +def test_qwen_adapter_count_output_tokens_handles_empty_vllm_output() -> None: + assert QwenOmniASRAdapter._count_output_tokens([SimpleNamespace(outputs=[])]) == 0.0 + + +def test_qwen_adapter_infer_turn_scatters_outputs_by_index() -> None: + """``_infer_turn`` scatters vLLM outputs back to original positions and reports time + tokens.""" + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni") + + def _fake_generate(inputs: list[dict[str, object]]) -> list[SimpleNamespace]: + return [ + SimpleNamespace(outputs=[SimpleNamespace(text=f"t{i}", token_ids=[0, 1])]) for i, _ in enumerate(inputs) + ] + + adapter._generate = _fake_generate # type: ignore[method-assign] + + # Length-4 batch where only positions 1 and 3 produced valid inputs. + texts, generation_s, tokens = adapter._infer_turn( + inputs=[{"prompt": "a"}, {"prompt": "b"}], + indices=[1, 3], + n=4, + ) + + assert texts == ["", "t0", "", "t1"] + assert tokens == 4.0 # 2 outputs x 2 token_ids each + assert generation_s >= 0.0 + + +def test_qwen_adapter_infer_turn_raises_on_vllm_count_mismatch() -> None: + """A short vLLM result list must fail loud (strict=True), not silently drop utterances.""" + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni") + + def _short_generate(_inputs: list[dict[str, object]]) -> list[SimpleNamespace]: + # vLLM returns fewer outputs than inputs (e.g. a scheduler drop). + return [SimpleNamespace(outputs=[SimpleNamespace(text="only-one", token_ids=[0])])] + + adapter._generate = _short_generate # type: ignore[method-assign] + + with pytest.raises(ValueError, match="zip"): + adapter._infer_turn(inputs=[{"prompt": "a"}, {"prompt": "b"}], indices=[0, 1], n=2) + + +def test_qwen_adapter_turn2_extends_shared_audio_prompt_messages() -> None: + adapter = QwenOmniASRAdapter( + model_id="mock/qwen-omni", + prompt_text="Transcribe {language}.", + followup_prompt="Refine {language}.", + system_prompt="System {language}.", + ) + waveform = np.zeros(_SR, dtype=np.float32) + + turn1_messages = adapter._build_messages(waveform, "English") + turn2_messages = adapter._build_turn2_messages(waveform, "draft text", "English") + + assert [message["role"] for message in turn2_messages[:2]] == [message["role"] for message in turn1_messages] + assert turn2_messages[0]["content"][0]["text"] == turn1_messages[0]["content"][0]["text"] + assert turn2_messages[1]["content"][0]["text"] == turn1_messages[1]["content"][0]["text"] + assert turn2_messages[1]["content"][1]["audio"] is waveform + assert turn2_messages[2] == {"role": "assistant", "content": [{"type": "text", "text": "draft text"}]} + assert turn2_messages[3] == {"role": "user", "content": [{"type": "text", "text": "Refine English."}]} + + +def test_qwen_adapter_prompt_replaces_language_and_reference_transcript() -> None: + adapter = QwenOmniASRAdapter( + model_id="mock/qwen-omni", + prompt_text="Transcribe {language}: {transcript}", + en_prompt_text="English prompt {transcript}", + followup_prompt="Refine {language}: {transcript}", + ) + waveform = np.zeros(_SR, dtype=np.float32) + + turn1_messages = adapter._build_messages(waveform, "English", "hello reference") + turn2_messages = adapter._build_turn2_messages(waveform, "draft", "Spanish", "hola ref") + + assert turn1_messages[-1]["content"][0]["text"] == "English prompt hello reference" + assert turn2_messages[0]["content"][0]["text"] == "Transcribe Spanish: hola ref" + assert turn2_messages[2]["content"][0]["text"] == "Refine Spanish: hola ref" + + +def test_qwen_adapter_transcribe_batch_packages_results() -> None: + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni", followup_prompt="refine") + adapter._run_two_turn = MagicMock( # type: ignore[method-assign] + return_value=( + ["text-a", "text-b", ""], + ["refined-a", "", ""], + {2}, + ), + ) + items = [ + { + "waveform": np.zeros(_SR, dtype=np.float32), + "sample_rate": _SR, + "language": "English", + "reference_text": "ref-a", + }, + { + "waveform": np.zeros(_SR, dtype=np.float32), + "sample_rate": _SR, + "language": "English", + "reference_text": "ref-b", + }, + {"waveform": np.zeros(0, dtype=np.float32), "sample_rate": _SR, "language": None}, + ] + results = adapter.transcribe_batch(items) + + assert [r.text for r in results] == ["text-a", "text-b", ""] + assert [r.secondary_text for r in results] == ["refined-a", "", ""] + assert [r.skipped for r in results] == [False, False, True] + assert all(r.model_id == "mock/qwen-omni" for r in results) + + adapter._run_two_turn.assert_called_once() + _waveforms, _srs, langs, refs = adapter._run_two_turn.call_args[0] + assert langs == ["English", "English", None] + assert refs == ["ref-a", "ref-b", None] + + +def test_qwen_adapter_single_turn_drops_secondary_text() -> None: + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni", followup_prompt=None) + adapter._run_two_turn = MagicMock( # type: ignore[method-assign] + return_value=(["text-a"], [""], set()), + ) + results = adapter.transcribe_batch( + [ + {"waveform": np.zeros(_SR, dtype=np.float32), "sample_rate": _SR}, + ] + ) + assert results[0].secondary_text is None + + +def test_qwen_adapter_prepare_single_accepts_canonical_torch_2d_waveform() -> None: + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni") + adapter._build_messages = MagicMock(return_value=[{"role": "user", "content": []}]) # type: ignore[method-assign] + adapter._pack_vllm_inputs = MagicMock(return_value={"prompt": "p"}) # type: ignore[method-assign] + waveform = torch.stack([torch.ones(_SR), torch.zeros(_SR)]) + + prepared = adapter._prepare_single(waveform, _SR, "English") + + assert prepared is not None + inputs, waveform_16k = prepared + assert inputs == {"prompt": "p"} + assert waveform_16k.shape == (_SR,) + assert waveform_16k.dtype == np.float32 + np.testing.assert_allclose(waveform_16k, np.full(_SR, 0.5, dtype=np.float32)) + + +def test_qwen_adapter_prepare_single_skips_too_short_waveform_before_preprocess() -> None: + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni") + adapter._build_messages = MagicMock(return_value=[{"role": "user", "content": []}]) # type: ignore[method-assign] + adapter._pack_vllm_inputs = MagicMock(return_value={"prompt": "p"}) # type: ignore[method-assign] + + assert adapter._prepare_single(np.zeros(100, dtype=np.float32), _SR, "English") is None + adapter._build_messages.assert_not_called() + adapter._pack_vllm_inputs.assert_not_called() + + +# ---------------------------------------------------------------------- +# Elevated vLLM knobs +# ---------------------------------------------------------------------- + + +def test_qwen_adapter_has_elevated_vllm_knobs_as_dataclass_fields() -> None: + """vLLM knobs are dataclass fields settable from YAML ``adapter_kwargs``.""" + adapter = QwenOmniASRAdapter( + model_id="mock/qwen-omni", + enable_prefix_caching=False, + prefix_caching_hash_algo="sha256", + limit_mm_per_prompt_audio=1, + max_num_batched_tokens=49152, + seed=99, + ) + assert adapter.enable_prefix_caching is False + assert adapter.prefix_caching_hash_algo == "sha256" + assert adapter.limit_mm_per_prompt_audio == 1 + assert adapter.max_num_batched_tokens == 49152 + assert adapter.seed == 99 + + +def test_qwen_adapter_vllm_knob_defaults_match_doc() -> None: + """Default vLLM knob values match the tutorial when YAML omits overrides.""" + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni") + assert adapter.enable_prefix_caching is True + assert adapter.prefix_caching_hash_algo == "xxhash" + assert adapter.limit_mm_per_prompt_audio == 2 + assert adapter.max_num_batched_tokens is None + assert adapter.seed == 1234 + + +def test_qwen_adapter_rejects_invalid_max_num_batched_tokens() -> None: + with pytest.raises(ValueError, match="max_num_batched_tokens must be positive"): + QwenOmniASRAdapter(model_id="mock/qwen-omni", max_num_batched_tokens=0) + + +def test_qwen_adapter_setup_threads_vllm_knobs_into_llm_ctor() -> None: + """setup() forwards the elevated knobs to the vLLM LLM ctor.""" + adapter = QwenOmniASRAdapter( + model_id="mock/qwen-omni", + enable_prefix_caching=False, + prefix_caching_hash_algo="sha256", + limit_mm_per_prompt_audio=3, + max_num_batched_tokens=49152, + seed=42, + tensor_parallel_size=1, + ) + fake_llm = MagicMock() + fake_processor = MagicMock() + with ( + patch("nemo_curator.models.asr.qwen_omni.VLLM_AVAILABLE", new=True), + patch("nemo_curator.models.asr.qwen_omni.process_mm_info", MagicMock()), + patch("nemo_curator.models.vllm_model.LLM", return_value=fake_llm) as llm_ctor, + patch( + "nemo_curator.models.asr.qwen_omni.Qwen3OmniMoeProcessor.from_pretrained", + return_value=fake_processor, + ), + patch("nemo_curator.models.vllm_model.SamplingParams"), + ): + adapter.setup() + + llm_ctor.assert_called_once() + kwargs = llm_ctor.call_args.kwargs + assert kwargs["enable_prefix_caching"] is False + assert kwargs["prefix_caching_hash_algo"] == "sha256" + assert kwargs["limit_mm_per_prompt"] == {"image": 1, "video": 1, "audio": 3} + assert kwargs["max_num_batched_tokens"] == 49152 + assert kwargs["seed"] == 42 + assert "revision" not in kwargs + + +def test_qwen_adapter_setup_forwards_revision_to_llm_and_processor() -> None: + """Tier-1 revision must reach inference loaders, not only prefetch_weights.""" + adapter = QwenOmniASRAdapter( + model_id="mock/qwen-omni", + revision="abc123", + tensor_parallel_size=1, + ) + fake_llm = MagicMock() + fake_processor = MagicMock() + with ( + patch("nemo_curator.models.asr.qwen_omni.VLLM_AVAILABLE", new=True), + patch("nemo_curator.models.asr.qwen_omni.process_mm_info", MagicMock()), + patch("nemo_curator.models.vllm_model.LLM", return_value=fake_llm) as llm_ctor, + patch( + "nemo_curator.models.asr.qwen_omni.Qwen3OmniMoeProcessor.from_pretrained", + return_value=fake_processor, + ) as proc_ctor, + patch("nemo_curator.models.vllm_model.SamplingParams"), + ): + adapter.setup() + + assert llm_ctor.call_args.kwargs["revision"] == "abc123" + proc_ctor.assert_called_once_with("mock/qwen-omni", revision="abc123") + + +def test_qwen_adapter_setup_cleans_up_partial_engine_when_processor_fails() -> None: + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni", tensor_parallel_size=1) + fake_llm = MagicMock() + with ( + patch("nemo_curator.models.asr.qwen_omni.VLLM_AVAILABLE", new=True), + patch("nemo_curator.models.asr.qwen_omni.process_mm_info", MagicMock()), + patch("nemo_curator.models.vllm_model.LLM", return_value=fake_llm), + patch( + "nemo_curator.models.asr.qwen_omni.Qwen3OmniMoeProcessor.from_pretrained", + side_effect=RuntimeError("processor failed"), + ), + patch("nemo_curator.models.vllm_model.SamplingParams"), + pytest.raises(RuntimeError, match="processor failed"), + ): + adapter.setup() + + assert adapter._llm is None + assert adapter._sampling_params is None + assert adapter._processor is None + assert adapter._prep_pool is None + + +def test_qwen_adapter_marks_empty_turn1_outputs_skipped_and_excludes_turn2() -> None: + adapter = QwenOmniASRAdapter(model_id="mock/qwen-omni", followup_prompt="refine") + waveform_a = np.ones(_SR, dtype=np.float32) + waveform_b = np.ones(_SR, dtype=np.float32) + adapter._prepare_batch = MagicMock( # type: ignore[method-assign] + return_value=[ + ({"prompt": "a"}, waveform_a), + ({"prompt": "b"}, waveform_b), + ], + ) + adapter._prepare_turn2_batch = MagicMock(return_value=[{"prompt": "turn2-b"}]) # type: ignore[method-assign] + adapter._infer_turn = MagicMock( # type: ignore[method-assign] + side_effect=[ + (["", "text-b"], 0.1, 2.0), + (["", "refined-b"], 0.2, 3.0), + ], + ) + + pred_texts, disfluency_texts, skipped_indices = adapter._run_two_turn( + [waveform_a, waveform_b], + [_SR, _SR], + ["English", "English"], + ) + + assert pred_texts == ["", "text-b"] + assert disfluency_texts == ["", "refined-b"] + assert skipped_indices == {0} + adapter._prepare_turn2_batch.assert_called_once_with([waveform_b], ["text-b"], ["English"], [None]) + assert adapter.last_metrics["utterances_skipped_empty_output"] == 1.0 diff --git a/tests/pipeline/__init__.py b/tests/pipeline/__init__.py new file mode 100644 index 0000000000..06506ed8d6 --- /dev/null +++ b/tests/pipeline/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pipeline unit tests.""" diff --git a/tests/pipeline/test_payload_refs.py b/tests/pipeline/test_payload_refs.py new file mode 100644 index 0000000000..4bbae51ac6 --- /dev/null +++ b/tests/pipeline/test_payload_refs.py @@ -0,0 +1,90 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Callable + +import pytest + +from nemo_curator.pipeline import payload_refs +from nemo_curator.pipeline.payload_refs import PayloadRef + + +class _RemoteMethod: + def __init__(self, function: Callable[..., object]) -> None: + self._function = function + + def remote(self, *args: object, **kwargs: object) -> object: + return self._function(*args, **kwargs) + + +class _AdmissionActor: + def __init__(self) -> None: + self.calls: list[list[tuple[str, str, float | None]]] = [] + self.heartbeat_many = _RemoteMethod(self._heartbeat_many) + + def _heartbeat_many(self, requests: list[tuple[str, str, float | None]]) -> list[bool]: + self.calls.append(requests) + return [True] * len(requests) + + +class _StoreActor: + def __init__(self, values: dict[str, object]) -> None: + self.values = values + self.pin_calls: list[list[tuple[str, float | None]]] = [] + self.get_calls: list[list[tuple[str, float | None]]] = [] + self.pin_many = _RemoteMethod(self._pin_many) + self.get_many = _RemoteMethod(self._get_many) + + def _pin_many(self, requests: list[tuple[str, float | None]]) -> list[bool]: + self.pin_calls.append(requests) + return [payload_id in self.values for payload_id, _ttl in requests] + + def _get_many(self, requests: list[tuple[str, float | None]]) -> list[object]: + self.get_calls.append(requests) + return [self.values[payload_id] for payload_id, _ttl in requests] + + +def _ref(payload_id: str, *, amount_bytes: int = 6) -> PayloadRef: + return PayloadRef( + payload_id=payload_id, + owner_node_id="node-a", + store_actor_name="store", + admission_actor_name="admission", + amount_bytes=amount_bytes, + sample_rate=16_000, + num_samples=100, + ) + + +def test_resolve_payload_refs_batched_is_byte_bounded_and_ordered(monkeypatch: pytest.MonkeyPatch) -> None: + admission = _AdmissionActor() + store = _StoreActor({"a": "payload-a", "b": "payload-b"}) + actors = {"admission": admission, "store": store} + monkeypatch.setattr(payload_refs, "_get_named_actor", lambda name, _namespace=None: actors[name]) + monkeypatch.setattr(payload_refs, "_ray_get", lambda value: value) + + resolved = payload_refs.resolve_payload_refs_batched( + [_ref("b"), _ref("a"), _ref("b")], + max_batch_bytes=10, + ) + + assert resolved == ["payload-b", "payload-a", "payload-b"] + assert [[request[1] for request in call] for call in admission.calls] == [["b"], ["a"]] + assert [[request[0] for request in call] for call in store.pin_calls] == [["b"], ["a"]] + assert [[request[0] for request in call] for call in store.get_calls] == [["b"], ["a"]] + + +def test_resolve_payload_refs_batched_rejects_missing_store_payload(monkeypatch: pytest.MonkeyPatch) -> None: + admission = _AdmissionActor() + store = _StoreActor({}) + actors = {"admission": admission, "store": store} + monkeypatch.setattr(payload_refs, "_get_named_actor", lambda name, _namespace=None: actors[name]) + monkeypatch.setattr(payload_refs, "_ray_get", lambda value: value) + + with pytest.raises(KeyError, match="no longer present"): + payload_refs.resolve_payload_refs_batched([_ref("missing")]) + + +def test_resolve_payload_refs_batched_rejects_boolean_byte_limit() -> None: + with pytest.raises(ValueError, match="max_batch_bytes must be positive"): + payload_refs.resolve_payload_refs_batched([_ref("payload")], max_batch_bytes=True) diff --git a/tests/pipeline/test_prefetch.py b/tests/pipeline/test_prefetch.py new file mode 100644 index 0000000000..4115bb2382 --- /dev/null +++ b/tests/pipeline/test_prefetch.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from threading import Event + +import pytest + +from nemo_curator.pipeline.prefetch import BoundedOneAheadPrefetchIterator + + +def test_one_ahead_prefetch_overlaps_next_load_and_preserves_order() -> None: + second_started = Event() + release_second = Event() + + def load(value: int) -> str: + if value == 2: + second_started.set() + assert release_second.wait(timeout=2.0) + return f"loaded-{value}" + + iterator = iter( + BoundedOneAheadPrefetchIterator( + [1, 2], + loader=load, + size_bytes=lambda _value: 4, + max_inflight_bytes=8, + ) + ) + + assert next(iterator) == (1, "loaded-1") + assert second_started.wait(timeout=2.0) + release_second.set() + assert next(iterator) == (2, "loaded-2") + with pytest.raises(StopIteration): + next(iterator) + + +def test_one_ahead_prefetch_respects_combined_byte_bound() -> None: + loaded: list[int] = [] + + def load(value: int) -> int: + loaded.append(value) + return value + + iterator = iter( + BoundedOneAheadPrefetchIterator( + [1, 2], + loader=load, + size_bytes=lambda _value: 6, + max_inflight_bytes=10, + ) + ) + + assert next(iterator) == (1, 1) + assert loaded == [1] + assert next(iterator) == (2, 2) + assert loaded == [1, 2] + + +def test_one_ahead_prefetch_propagates_loader_errors() -> None: + def load(value: int) -> int: + if value == 2: + msg = "cannot load second item" + raise RuntimeError(msg) + return value + + iterator = iter( + BoundedOneAheadPrefetchIterator( + [1, 2], + loader=load, + size_bytes=lambda _value: 1, + max_inflight_bytes=2, + ) + ) + + assert next(iterator) == (1, 1) + with pytest.raises(RuntimeError, match="cannot load second item"): + next(iterator) diff --git a/tests/pipelines/audio/__init__.py b/tests/pipelines/audio/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/pipelines/audio/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/pipelines/audio/test_qwen_omni_inprocess.py b/tests/pipelines/audio/test_qwen_omni_inprocess.py new file mode 100644 index 0000000000..e3d6bb3773 --- /dev/null +++ b/tests/pipelines/audio/test_qwen_omni_inprocess.py @@ -0,0 +1,326 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: ANN001, ANN202, S108 + +import pytest +from omegaconf import OmegaConf + +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.audio.common import ManifestReader, ManifestReaderStage, ManifestWriterStage +from nemo_curator.stages.audio.inference.asr.stage import ASRStage +from nemo_curator.stages.file_partitioning import FilePartitioningStage +from nemo_curator.stages.payload_lifecycle import AudioPayloadMaterializeStage, PayloadReleaseStage + + +class _DualPayloadConsumer(ASRStage): + def payload_bindings(self) -> list[dict[str, str]]: + return [ + { + "ref_key": "audio_ref", + "waveform_key": "audio", + "sample_rate_key": "sr", + "num_samples_key": "samples", + }, + { + "ref_key": "reference_ref", + "waveform_key": "reference_audio", + "sample_rate_key": "reference_sr", + "num_samples_key": "reference_samples", + }, + ] + + +def _cfg(*, consumers: list[str] | None = None, release_after: str = "qwen_omni"): + return OmegaConf.create( + { + "audio_reader_skip_on_read_error": True, + "payload_lifecycle": { + "enabled": True, + "materialize_after": "manifest_reader", + "payload_keys": ["audio_filepath"], + "ref_key": "audio_ref", + "consumers": consumers or ["qwen_omni"], + "release_after": release_after, + "target_sample_rate": 16000, + "target_nchannels": 1, + "node_memory_fraction": 0.55, + "max_node_payload_bytes": "10g", + "lease_ttl_s": 1234, + "admission_actor_name": "test_payload_admission", + "admission_poll_interval_s": 0.5, + }, + } + ) + + +def _logical_stages(): + reader = ManifestReader( + manifest_path="/tmp/input.jsonl", + ) + asr = ASRStage( + adapter_target="nemo_curator.models.asr.QwenOmniASRAdapter", + model_id="test/model", + name="qwen_omni", + waveform_key="audio", + waveform_ref_key="audio_ref", + sample_rate_key="sr", + pred_text_key="prediction", + disfluency_text_key="raw_prediction", + skip_me_key="skip_me", + max_inference_duration_s=120, + adapter_batch_size=1, + ) + writer = ManifestWriterStage(output_path="/tmp/output.jsonl") + return reader, asr, writer + + +def _asr_stage(name: str, *, pred_text_key: str = "prediction"): + return ASRStage( + adapter_target="nemo_curator.models.asr.QwenOmniASRAdapter", + model_id="test/model", + name=name, + waveform_key="audio", + waveform_ref_key="audio_ref", + sample_rate_key="sr", + pred_text_key=pred_text_key, + disfluency_text_key=None, + skip_me_key="skip_me", + max_inference_duration_s=120, + adapter_batch_size=1, + ) + + +def _dual_stage(name: str = "dual_gpu_stage"): + return _DualPayloadConsumer( + adapter_target="nemo_curator.models.asr.QwenOmniASRAdapter", + model_id="test/model", + name=name, + waveform_key="audio", + waveform_ref_key="audio_ref", + sample_rate_key="sr", + pred_text_key="prediction", + disfluency_text_key=None, + max_inference_duration_s=120, + adapter_batch_size=1, + ) + + +def _expanded(stages, cfg=None): + cfg = cfg or _cfg() + pipeline = Pipeline( + name="test_pipeline", + stages=list(stages), + config=OmegaConf.to_container(cfg, resolve=True), + ) + pipeline.build() + return pipeline.stages + + +def test_payload_lifecycle_build_is_idempotent() -> None: + pipeline = Pipeline( + name="test_pipeline", + stages=list(_logical_stages()), + config=OmegaConf.to_container(_cfg(), resolve=True), + ) + + pipeline.build() + first_names = [stage.name for stage in pipeline.stages] + pipeline.build() + + assert [stage.name for stage in pipeline.stages] == first_names + + +def test_payload_lifecycle_helpers_use_fresh_pipeline_run_id_without_mutating_config() -> None: + cfg = OmegaConf.to_container(_cfg(), resolve=True) + cfg["_curator_pipeline_run_id"] = "stale-run" + first_pipeline = Pipeline( + name="test_pipeline", + stages=list(_logical_stages()), + config=cfg, + ) + second_pipeline = Pipeline( + name="test_pipeline", + stages=list(_logical_stages()), + config=cfg, + ) + + first_pipeline.build() + second_pipeline.build() + first_materialize = next( + stage for stage in first_pipeline.stages if isinstance(stage, AudioPayloadMaterializeStage) + ) + second_materialize = next( + stage for stage in second_pipeline.stages if isinstance(stage, AudioPayloadMaterializeStage) + ) + + assert first_materialize.run_id != "stale-run" + assert second_materialize.run_id != "stale-run" + assert first_materialize.run_id != second_materialize.run_id + assert cfg["_curator_pipeline_run_id"] == "stale-run" + + +def test_payload_lifecycle_add_stage_after_build_replans_from_logical_graph() -> None: + reader, asr, writer = _logical_stages() + pipeline = Pipeline( + name="test_pipeline", + stages=[reader, asr], + config=OmegaConf.to_container(_cfg(), resolve=True), + ) + + pipeline.build() + assert [type(stage) for stage in pipeline.stages] == [ + FilePartitioningStage, + ManifestReaderStage, + AudioPayloadMaterializeStage, + ASRStage, + PayloadReleaseStage, + ] + + pipeline.add_stage(writer) + pipeline.build() + + assert [type(stage) for stage in pipeline.stages] == [ + FilePartitioningStage, + ManifestReaderStage, + AudioPayloadMaterializeStage, + ASRStage, + PayloadReleaseStage, + ManifestWriterStage, + ] + assert [stage.name for stage in pipeline.stages if stage.is_source_stage] == ["file_partitioning"] + assert [stage.name for stage in pipeline.stages if stage.is_sink_stage] == ["manifest_writer"] + + +def test_logical_graph_expands_to_payload_lifecycle() -> None: + expanded = _expanded(_logical_stages()) + + assert [type(stage) for stage in expanded] == [ + FilePartitioningStage, + ManifestReaderStage, + AudioPayloadMaterializeStage, + ASRStage, + PayloadReleaseStage, + ManifestWriterStage, + ] + + materialize = expanded[2] + assert materialize.target_sample_rate == 16000 + assert materialize.target_nchannels == 1 + assert materialize.duration_key == "duration" + assert materialize.num_samples_key == "num_samples" + assert materialize.waveform_key == "audio" + assert materialize.waveform_ref_key == "audio_ref" + assert materialize.sample_rate_key == "sr" + assert materialize.skip_on_read_error is True + assert materialize.node_memory_fraction == 0.55 + assert materialize.max_node_payload_bytes == "10g" + assert materialize.lease_ttl_s == 1234 + assert materialize.admission_actor_name == "test_payload_admission" + assert materialize.admission_poll_interval_s == 0.5 + assert materialize.run_id + + release = expanded[4] + assert release.payload_ref_key == "audio_ref" + assert release.waveform_key == "audio" + + +def test_payload_lifecycle_can_span_multiple_backend_visible_consumers() -> None: + reader = ManifestReader(manifest_path="/tmp/input.jsonl") + first = _asr_stage("gpu_stage_1", pred_text_key="stage_1_text") + second = _asr_stage("gpu_stage_2", pred_text_key="stage_2_text") + writer = ManifestWriterStage(output_path="/tmp/output.jsonl") + + expanded = _expanded( + [reader, first, second, writer], + _cfg(consumers=["gpu_stage_1", "gpu_stage_2"], release_after="gpu_stage_2"), + ) + + assert [type(stage) for stage in expanded] == [ + FilePartitioningStage, + ManifestReaderStage, + AudioPayloadMaterializeStage, + ASRStage, + ASRStage, + PayloadReleaseStage, + ManifestWriterStage, + ] + assert [stage.name for stage in expanded] == [ + "file_partitioning", + "manifest_reader_stage", + "audio_payload_materialize", + "gpu_stage_1", + "gpu_stage_2", + "payload_release", + "manifest_writer", + ] + + +def test_payload_lifecycle_supports_multiple_source_keys() -> None: + reader = ManifestReader(manifest_path="/tmp/input.jsonl") + consumer = _dual_stage() + writer = ManifestWriterStage(output_path="/tmp/output.jsonl") + cfg = OmegaConf.create( + { + "payload_lifecycle": { + "enabled": True, + "materialize_after": "manifest_reader", + "payloads": [ + { + "source_key": "audio_filepath", + "ref_key": "audio_ref", + "waveform_key": "audio", + "sample_rate_key": "sr", + "num_samples_key": "samples", + }, + { + "source_key": "reference_audio_filepath", + "ref_key": "reference_ref", + "waveform_key": "reference_audio", + "sample_rate_key": "reference_sr", + "num_samples_key": "reference_samples", + "duration_key": "reference_duration", + }, + ], + "consumers": ["dual_gpu_stage"], + "release_after": "dual_gpu_stage", + } + } + ) + + expanded = _expanded([reader, consumer, writer], cfg) + + assert [type(stage) for stage in expanded] == [ + FilePartitioningStage, + ManifestReaderStage, + AudioPayloadMaterializeStage, + AudioPayloadMaterializeStage, + _DualPayloadConsumer, + PayloadReleaseStage, + ManifestWriterStage, + ] + assert expanded[2].audio_filepath_key == "audio_filepath" + assert expanded[2].waveform_ref_key == "audio_ref" + assert expanded[3].audio_filepath_key == "reference_audio_filepath" + assert expanded[3].waveform_ref_key == "reference_ref" + assert expanded[3].duration_key == "reference_duration" + + +def test_raw_qwen_config_rejects_explicit_helper_stages() -> None: + reader, asr, writer = _logical_stages() + + with pytest.raises(ValueError, match="logical stages only"): + _expanded([reader, AudioPayloadMaterializeStage(), asr, writer]) + + with pytest.raises(ValueError, match="logical stages only"): + _expanded([reader, asr, PayloadReleaseStage(), writer]) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index bc6c380e13..66194ed148 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -18,7 +18,7 @@ import pytest from nemo_curator.pipeline.pipeline import Pipeline, assign_root_task_ids -from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.base import CompositeStage, ProcessingStage from nemo_curator.stages.resources import Resources from nemo_curator.tasks import EmptyTask, Task @@ -37,6 +37,17 @@ def process(self, task: Task) -> Task: return task +@dataclass +class _SingleStageComposite(CompositeStage[Task, Task]): + name: str = "single" + + def __post_init__(self) -> None: + super().__init__() + + def decompose(self) -> list[ProcessingStage]: + return [_NoopStage(name="leaf")] + + @dataclass class _SimpleTask(Task[list[int]]): @property @@ -113,6 +124,34 @@ def test_default_first_source_last_sink_stage(self) -> None: assert lone.is_source_stage is True assert lone.is_sink_stage is True + def test_add_stage_after_build_reassigns_default_sink(self) -> None: + s0, s1, s2 = _NoopStage(name="s0"), _NoopStage(name="s1"), _NoopStage(name="s2") + pipeline = Pipeline(name="t", stages=[s0, s1]) + + pipeline.build() + assert [s.is_source_stage for s in (s0, s1)] == [True, False] + assert [s.is_sink_stage for s in (s0, s1)] == [False, True] + + pipeline.add_stage(s2) + pipeline.build() + + assert [s.is_source_stage for s in (s0, s1, s2)] == [True, False, False] + assert [s.is_sink_stage for s in (s0, s1, s2)] == [False, False, True] + + def test_direct_stage_append_after_build_replans_and_reassigns_sink(self) -> None: + s0, s1, s2 = _NoopStage(name="s0"), _NoopStage(name="s1"), _NoopStage(name="s2") + pipeline = Pipeline(name="t", stages=[s0, s1]) + + pipeline.build() + assert [s.is_source_stage for s in (s0, s1)] == [True, False] + assert [s.is_sink_stage for s in (s0, s1)] == [False, True] + + pipeline.stages.append(s2) + pipeline.build() + + assert [s.is_source_stage for s in (s0, s1, s2)] == [True, False, False] + assert [s.is_sink_stage for s in (s0, s1, s2)] == [False, False, True] + def test_explicit_marks_override_defaults(self) -> None: s0, s1, s2 = _NoopStage(name="s0"), _NoopStage(name="s1"), _NoopStage(name="s2") s1.is_source_stage = True @@ -135,25 +174,36 @@ def test_multiple_explicit_marks_raise(self) -> None: with pytest.raises(ValueError, match="multiple sink stages marked"): Pipeline(name="t", stages=[t0, t1]).build() + def test_single_stage_composite_preserves_main_behavior(self) -> None: + pipeline = Pipeline(name="t", stages=[_SingleStageComposite()]) + pipeline.build() + + assert [type(stage) for stage in pipeline.stages] == [_SingleStageComposite] + assert pipeline.decomposition_info == {} + class TestRootTaskIds: - """``assign_root_task_ids`` roots user-provided initial tasks under the - implicit ``EmptyTask`` root id ``"0"``.""" + """``assign_root_task_ids`` follows the framework-owned main contract.""" def test_empty_task_id_is_zero(self) -> None: assert EmptyTask().task_id == "0" assert EmptyTask(dataset_name="d", data=None).task_id == "0" + def test_rewrites_existing_internal_task_ids(self) -> None: + tasks = [_SimpleTask(dataset_name="d", data=[1]) for _ in range(3)] + for i, task in enumerate(tasks): + task.task_id = f"t{i}" + assign_root_task_ids(tasks) + assert [t.task_id for t in tasks] == ["0_0", "0_1", "0_2"] + def test_roots_user_tasks_at_zero(self) -> None: tasks = [_SimpleTask(dataset_name="d", data=[1]) for _ in range(3)] assign_root_task_ids(tasks) - # User-provided initial tasks are children of root "0", by position. assert [t.task_id for t in tasks] == ["0_0", "0_1", "0_2"] def test_skips_empty_tasks(self) -> None: et = EmptyTask(dataset_name="d", data=None) real = _SimpleTask(dataset_name="d", data=[1]) assign_root_task_ids([et, real]) - # EmptyTask stays "0"; the real task is rooted by its position. assert et.task_id == "0" assert real.task_id == "0_1" diff --git a/tests/stages/audio/inference/asr/__init__.py b/tests/stages/audio/inference/asr/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/stages/audio/inference/test_asr_stage.py b/tests/stages/audio/inference/test_asr_stage.py new file mode 100644 index 0000000000..4c5b76617a --- /dev/null +++ b/tests/stages/audio/inference/test_asr_stage.py @@ -0,0 +1,777 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: S108 + +"""Tests for the generic ``ASRStage`` exercised against a mock ``ASRAdapter`` (no real model load).""" + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch + +from nemo_curator.backends.base import BaseStageAdapter +from nemo_curator.models.asr.base import ASRResult +from nemo_curator.pipeline.payload_refs import PayloadRef +from nemo_curator.stages.audio.inference.asr import ASRStage +from nemo_curator.stages.audio.inference.asr import stage as asr_stage_module +from nemo_curator.stages.audio.inference.batch_policy import BatchPolicy +from nemo_curator.tasks import AudioTask + +_QWEN_ADAPTER_TARGET = "nemo_curator.models.asr.qwen_omni.QwenOmniASRAdapter" +_SR = 16000 + + +def _make_stage( # noqa: PLR0913 + *, + disfluency_text_key: str | None = None, + keep_waveform: bool = True, + default_language: str | None = None, + max_inference_duration_s: float = 2400.0, + batch_policy: BatchPolicy | None = None, + batch_size: int = 32, + reference_text_key: str | None = None, + supported_language_codes: list[str] | None = None, + payload_prefetch_enabled: bool = False, + payload_prefetch_max_bytes: int | None = None, + payload_resolve_max_batch_bytes: int | None = None, +) -> ASRStage: + """Build an ASRStage wired to a mock adapter (no real model load).""" + stage = ASRStage( + adapter_target=_QWEN_ADAPTER_TARGET, + model_id="mock/qwen-omni", + pred_text_key="qwen3_prediction_s1", + disfluency_text_key=disfluency_text_key, + keep_waveform=keep_waveform, + default_language=default_language, + max_inference_duration_s=max_inference_duration_s, + batch_policy=batch_policy, + batch_size=batch_size, + reference_text_key=reference_text_key, + supported_language_codes=supported_language_codes, + payload_prefetch_enabled=payload_prefetch_enabled, + payload_prefetch_max_bytes=payload_prefetch_max_bytes, + payload_resolve_max_batch_bytes=payload_resolve_max_batch_bytes, + ) + mock_adapter = MagicMock() + mock_adapter.last_metrics = {} + stage._adapter = mock_adapter + return stage + + +def _make_task(waveform_len: int = _SR, source_lang: str | None = "en") -> AudioTask: + data: dict[str, object] = { + "waveform": np.zeros(waveform_len, dtype=np.float32), + "sample_rate": _SR, + } + if source_lang is not None: + data["source_lang"] = source_lang + return AudioTask(data=data) + + +def _chunking_policy() -> BatchPolicy: + return BatchPolicy( + buckets_sec=[0], + max_items_per_batch_by_bucket=[32], + max_audio_sec_per_batch=None, + ) + + +# ---------------------------------------------------------------------- +# Stage-level: process / process_batch contract +# ---------------------------------------------------------------------- + + +def test_process_raises_not_implemented() -> None: + stage = _make_stage() + with pytest.raises(NotImplementedError): + stage.process(_make_task()) + + +def test_empty_batch() -> None: + stage = _make_stage() + assert stage.process_batch([]) == [] + + +def test_basic_inference_single_turn() -> None: + stage = _make_stage(keep_waveform=False) + stage._adapter.transcribe_batch.return_value = [ASRResult(text="hello world")] + + results = stage.process_batch([_make_task()]) + + assert results[0].data["qwen3_prediction_s1"] == "hello world" + assert "waveform" not in results[0].data # keep_waveform=False -> dropped + + +def test_keep_waveform_default_is_true() -> None: + """Default ``keep_waveform`` is True so downstream stages can reuse the waveform.""" + stage = _make_stage() # no keep_waveform override + stage._adapter.transcribe_batch.return_value = [ASRResult(text="hello world")] + + results = stage.process_batch([_make_task()]) + assert "waveform" in results[0].data + + +def test_disfluency_text_key_stores_secondary() -> None: + stage = _make_stage(disfluency_text_key="qwen3_prediction_s2") + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="hello world", secondary_text="hello world cleaned"), + ] + + results = stage.process_batch([_make_task()]) + + assert results[0].data["qwen3_prediction_s1"] == "hello world" + assert results[0].data["qwen3_prediction_s2"] == "hello world cleaned" + assert "qwen3_prediction_s2" in stage.outputs()[1] + + +def test_disfluency_text_key_none_is_normalised_to_empty_string() -> None: + stage = _make_stage(disfluency_text_key="qwen3_prediction_s2") + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="hello world", secondary_text=None), + ] + + results = stage.process_batch([_make_task()]) + assert results[0].data["qwen3_prediction_s2"] == "" + + +def test_keep_waveform_false_drops_waveform() -> None: + stage = _make_stage(keep_waveform=False) + stage._adapter.transcribe_batch.return_value = [ASRResult(text="text")] + + results = stage.process_batch([_make_task()]) + assert "waveform" not in results[0].data + + +def test_adapter_not_initialized_raises() -> None: + stage = ASRStage(adapter_target=_QWEN_ADAPTER_TARGET, model_id="mock/model") + with pytest.raises(RuntimeError, match="setup"): + stage.process_batch([_make_task()]) + + +def test_multi_task_batch_preserves_order() -> None: + stage = _make_stage() + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="text1"), + ASRResult(text="text2"), + ] + results = stage.process_batch([_make_task(), _make_task()]) + + assert results[0].data["qwen3_prediction_s1"] == "text1" + assert results[1].data["qwen3_prediction_s1"] == "text2" + + +def test_adapter_result_length_mismatch_raises() -> None: + stage = _make_stage() + stage._adapter.transcribe_batch.return_value = [ASRResult(text="x")] # 1 result + with pytest.raises(RuntimeError, match=r"returned 1 results for 2 items"): + stage.process_batch([_make_task(), _make_task()]) + + +# ---------------------------------------------------------------------- +# Model-input segmentation + stitch-back +# ---------------------------------------------------------------------- + + +def test_pre_slice_short_clip_passes_through_unchanged() -> None: + """A clip under max_inference_duration_s yields one sub-chunk; no stitching.""" + stage = _make_stage( + max_inference_duration_s=2400.0, + batch_policy=_chunking_policy(), + ) + stage._adapter.transcribe_batch.return_value = [ASRResult(text="single")] + + BaseStageAdapter(stage).process_batch([_make_task(waveform_len=_SR * 30)]) # 30 s clip + + items = stage._adapter.transcribe_batch.call_args[0][0] + assert len(items) == 1 + assert items[0]["chunk_count"] == 1 + assert items[0]["chunk_idx"] == 0 + + +def test_pre_slice_over_long_clip_into_contiguous_sub_chunks() -> None: + """A 95-s clip with max_inference_duration_s=30 s slices into [30, 30, 30, 5] sub-chunks.""" + stage = _make_stage( + max_inference_duration_s=30.0, + batch_policy=_chunking_policy(), + ) + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="chunk0"), + ASRResult(text="chunk1"), + ASRResult(text="chunk2"), + ASRResult(text="chunk3"), + ] + waveform = np.arange(_SR * 95, dtype=np.float32) + task = AudioTask(data={"waveform": waveform, "sample_rate": _SR, "source_lang": "en"}) + BaseStageAdapter(stage).process_batch([task]) + + items = stage._adapter.transcribe_batch.call_args[0][0] + assert len(items) == 4 + assert [it["chunk_idx"] for it in items] == [0, 1, 2, 3] + assert all(it["chunk_count"] == 4 for it in items) + chunk_lengths = [int(it["waveform"].shape[0]) for it in items] + assert chunk_lengths == [_SR * 30, _SR * 30, _SR * 30, _SR * 5] + assert sum(chunk_lengths) == int(waveform.shape[0]) # no audio lost / repeated + # Sub-chunks are the contiguous prefix of waveform. + np.testing.assert_array_equal(items[0]["waveform"], waveform[: _SR * 30]) + np.testing.assert_array_equal(items[1]["waveform"], waveform[_SR * 30 : _SR * 60]) + np.testing.assert_array_equal(items[2]["waveform"], waveform[_SR * 60 : _SR * 90]) + np.testing.assert_array_equal(items[3]["waveform"], waveform[_SR * 90 :]) + + +def test_pre_slice_canonical_torch_waveform_uses_sample_axis() -> None: + """Canonical ``(channels, samples)`` tensors are sliced along the sample axis.""" + stage = _make_stage( + max_inference_duration_s=30.0, + batch_policy=_chunking_policy(), + ) + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="chunk0"), + ASRResult(text="chunk1"), + ASRResult(text="chunk2"), + ASRResult(text="chunk3"), + ] + waveform = torch.arange(_SR * 95, dtype=torch.float32).reshape(1, -1) + task = AudioTask(data={"waveform": waveform, "sample_rate": _SR, "source_lang": "en"}) + BaseStageAdapter(stage).process_batch([task]) + + items = stage._adapter.transcribe_batch.call_args[0][0] + assert len(items) == 4 + assert [tuple(it["waveform"].shape) for it in items] == [ + (1, _SR * 30), + (1, _SR * 30), + (1, _SR * 30), + (1, _SR * 5), + ] + torch.testing.assert_close(items[0]["waveform"], waveform[:, : _SR * 30]) + + +def test_pre_slice_stitch_back_joins_per_parent_with_single_space() -> None: + """Stitch-back joins sub-chunk texts (and secondary texts) with a single space; one row per parent.""" + stage = _make_stage( + disfluency_text_key="qwen3_prediction_s2", + max_inference_duration_s=30.0, + batch_policy=_chunking_policy(), + ) + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="hello", secondary_text="hello clean"), + ASRResult(text="world", secondary_text="world clean"), + ] + waveform = np.zeros(_SR * 50, dtype=np.float32) # 50s -> 2 sub-chunks + task = AudioTask(data={"waveform": waveform, "sample_rate": _SR, "source_lang": "en"}) + results = BaseStageAdapter(stage).process_batch([task]) + + assert len(results) == 1 # one parent in, one parent out + assert results[0].data["qwen3_prediction_s1"] == "hello world" + assert results[0].data["qwen3_prediction_s2"] == "hello clean world clean" + + +def test_pre_slice_marks_parent_skipped_only_if_all_chunks_skipped() -> None: + """A parent is marked skipped only if every sub-chunk was skipped.""" + stage = _make_stage( + max_inference_duration_s=30.0, + batch_policy=_chunking_policy(), + ) + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="good", skipped=False), + ASRResult(text="", skipped=True), + # task2: every chunk skipped + ASRResult(text="", skipped=True), + ASRResult(text="", skipped=True), + ] + task_partial = AudioTask( + data={ + "waveform": np.zeros(_SR * 50, dtype=np.float32), + "sample_rate": _SR, + } + ) + task_all_skip = AudioTask( + data={ + "waveform": np.zeros(_SR * 50, dtype=np.float32), + "sample_rate": _SR, + } + ) + results = BaseStageAdapter(stage).process_batch([task_partial, task_all_skip]) + + assert results[0].data["qwen3_prediction_s1"] == "good" + assert results[0].data.get("_skip_me") != "empty_audio" + assert results[1].data["qwen3_prediction_s1"] == "" + assert results[1].data["_skip_me"] == "empty_audio" + + +def test_segment_metrics_count_model_items_and_parent_rows() -> None: + """Metrics count model segments while stitch-back restores parent rows.""" + stage = _make_stage( + max_inference_duration_s=30.0, + batch_policy=_chunking_policy(), + ) + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="a"), + ASRResult(text="b"), + ASRResult(text="c"), + ASRResult(text="d"), + ] + stage._log_metrics = MagicMock() # type: ignore[method-assign] + + task_short = AudioTask(data={"waveform": np.zeros(_SR * 10, dtype=np.float32), "sample_rate": _SR}) # 1 chunk + task_long = AudioTask(data={"waveform": np.zeros(_SR * 75, dtype=np.float32), "sample_rate": _SR}) # 3 chunks + BaseStageAdapter(stage).process_batch([task_short, task_long]) + + metrics = stage._log_metrics.call_args[0][0] + assert metrics["utterances_input"] == 2.0 + assert metrics["utterances_processed"] == 2.0 + assert metrics["sub_chunks_generated"] == 4.0 + + +def test_model_input_segmentation_without_batch_policy_slices_normal_flow() -> None: + """Chunking is independent from scheduler bucketing and works in normal process_batch.""" + stage = _make_stage( + max_inference_duration_s=30.0, + batch_policy=BatchPolicy( + enabled=False, + strategy="placeholder", + buckets_sec=[], + max_items_per_batch_by_bucket=[], + ), + ) + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="chunk0"), + ASRResult(text="chunk1"), + ASRResult(text="chunk2"), + ASRResult(text="chunk3"), + ] + waveform = np.arange(_SR * 95, dtype=np.float32) + task = AudioTask(data={"waveform": waveform, "sample_rate": _SR, "source_lang": "en"}) + + result = stage.process_batch([task]) + + items = stage._adapter.transcribe_batch.call_args[0][0] + assert [it["chunk_idx"] for it in items] == [0, 1, 2, 3] + assert all(it["chunk_count"] == 4 for it in items) + assert result[0].data["qwen3_prediction_s1"] == "chunk0 chunk1 chunk2 chunk3" + + +def test_model_input_segmentation_normal_flow_caps_adapter_calls_by_batch_size() -> None: + stage = _make_stage( + max_inference_duration_s=30.0, + batch_policy=None, + batch_size=2, + ) + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="chunk0"), ASRResult(text="chunk1")], + [ASRResult(text="chunk2"), ASRResult(text="chunk3")], + ] + waveform = np.arange(_SR * 95, dtype=np.float32) + task = AudioTask(data={"waveform": waveform, "sample_rate": _SR, "source_lang": "en"}) + + result = stage.process_batch([task]) + + assert [len(call.args[0]) for call in stage._adapter.transcribe_batch.call_args_list] == [2, 2] + assert result[0].data["qwen3_prediction_s1"] == "chunk0 chunk1 chunk2 chunk3" + + +def test_payload_prefetch_plans_from_metadata_resolves_parent_once_and_slices_per_call( + monkeypatch: pytest.MonkeyPatch, +) -> None: + waveform = np.arange(_SR * 75, dtype=np.float32) + payload_ref = PayloadRef( + payload_id="payload-1", + owner_node_id="node-a", + store_actor_name="store", + admission_actor_name="admission", + amount_bytes=int(waveform.nbytes), + sample_rate=_SR, + num_samples=len(waveform), + ) + resolve_calls: list[list[str]] = [] + + def resolve(refs: list[PayloadRef], *, max_batch_bytes: int | None = None) -> list[np.ndarray]: + assert max_batch_bytes == 8_000_000 + resolve_calls.append([ref.payload_id for ref in refs]) + return [waveform for _ref in refs] + + monkeypatch.setattr(asr_stage_module, "resolve_payload_refs_batched", resolve) + stage = _make_stage( + max_inference_duration_s=30.0, + batch_size=1, + payload_prefetch_enabled=True, + payload_prefetch_max_bytes=10_000_000, + payload_resolve_max_batch_bytes=8_000_000, + ) + stage._start_payload_lease_keeper = MagicMock() # type: ignore[method-assign] + stage._stop_payload_lease_keeper = MagicMock() # type: ignore[method-assign] + stage.payload_consumer_node_id = MagicMock(return_value="node-a") # type: ignore[method-assign] + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="chunk0")], + [ASRResult(text="chunk1")], + [ASRResult(text="chunk2")], + ] + task = AudioTask(data={"waveform_ref": payload_ref, "sample_rate": _SR, "source_lang": "en"}) + + result = stage.process_batch([task]) + + assert resolve_calls == [["payload-1"]] + assert [len(call.args[0][0]["waveform"]) for call in stage._adapter.transcribe_batch.call_args_list] == [ + _SR * 30, + _SR * 30, + _SR * 15, + ] + assert result[0].data["qwen3_prediction_s1"] == "chunk0 chunk1 chunk2" + assert "waveform" not in result[0].data + + +def test_payload_prefetch_requires_explicit_byte_budget() -> None: + with pytest.raises(ValueError, match="payload_prefetch_max_bytes is required"): + _make_stage(payload_prefetch_enabled=True) + + +def test_payload_prefetch_enabled_requires_bool() -> None: + with pytest.raises(TypeError, match="payload_prefetch_enabled must be a bool"): + _make_stage(payload_prefetch_enabled="true") # type: ignore[arg-type] + + +@pytest.mark.parametrize("field", ["payload_resolve_max_batch_bytes", "payload_prefetch_max_bytes"]) +def test_payload_byte_limits_reject_bool(field: str) -> None: + with pytest.raises(ValueError, match=field): + _make_stage(**{field: True}) + + +# ---------------------------------------------------------------------- +# Stage-level: language mapping (ISO code -> name) +# ---------------------------------------------------------------------- + + +def test_language_resolution_from_task() -> None: + stage = _make_stage() + stage._adapter.transcribe_batch.return_value = [ASRResult(text="hola")] + + task = AudioTask( + data={ + "waveform": np.zeros(_SR, dtype=np.float32), + "sample_rate": _SR, + "source_lang": "es", + } + ) + stage.process_batch([task]) + + items = stage._adapter.transcribe_batch.call_args[0][0] + assert items[0]["language"] == "Spanish" + + +def test_default_language_used_when_task_language_missing() -> None: + stage = _make_stage(default_language="en") + stage._adapter.transcribe_batch.return_value = [ASRResult(text="hello")] + + task = AudioTask( + data={ + "waveform": np.zeros(_SR, dtype=np.float32), + "sample_rate": _SR, + } + ) + stage.process_batch([task]) + + items = stage._adapter.transcribe_batch.call_args[0][0] + assert items[0]["language"] == "English" + + +def test_supported_language_filter_skips_before_adapter_call() -> None: + stage = _make_stage(supported_language_codes=["en"]) + + results = stage.process_batch([_make_task(source_lang="pl")]) + + stage._adapter.transcribe_batch.assert_not_called() + assert results[0].data["qwen3_prediction_s1"] == "" + assert results[0].data["_skip_me"] == "lang_not_supported:pl" + + +def test_reference_text_key_is_passed_to_adapter_items() -> None: + stage = _make_stage(reference_text_key="text") + stage._adapter.transcribe_batch.return_value = [ASRResult(text="hello")] + + task = AudioTask( + data={ + "waveform": np.zeros(_SR, dtype=np.float32), + "sample_rate": _SR, + "source_lang": "en", + "text": "reference transcript", + } + ) + stage.process_batch([task]) + + items = stage._adapter.transcribe_batch.call_args[0][0] + assert items[0]["reference_text"] == "reference transcript" + + +def test_reference_text_key_is_preserved_for_chunked_items() -> None: + stage = _make_stage( + max_inference_duration_s=30.0, + reference_text_key="text", + ) + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="chunk0"), + ASRResult(text="chunk1"), + ] + + task = AudioTask( + data={ + "waveform": np.zeros(_SR * 50, dtype=np.float32), + "sample_rate": _SR, + "source_lang": "en", + "text": "reference transcript", + } + ) + stage.process_batch([task]) + + items = stage._adapter.transcribe_batch.call_args[0][0] + assert [item["reference_text"] for item in items] == ["reference transcript", "reference transcript"] + + +# ---------------------------------------------------------------------- +# Stage-level: I/O contract +# ---------------------------------------------------------------------- + + +def test_inputs_outputs_single_turn() -> None: + stage = ASRStage(adapter_target=_QWEN_ADAPTER_TARGET, model_id="mock/model") + _required, optional_inputs = stage.inputs() + assert "waveform_ref" in optional_inputs + assert "sample_rate" in optional_inputs + + _required, optional_outputs = stage.outputs() + assert "_skip_me" in optional_outputs + assert "pred_text" in optional_outputs + + +def test_inputs_use_waveform_when_payload_ref_is_disabled() -> None: + stage = ASRStage( + adapter_target=_QWEN_ADAPTER_TARGET, + model_id="mock/model", + waveform_ref_key=None, + ) + _required, optional_inputs = stage.inputs() + assert "waveform" in optional_inputs + + +def test_inputs_include_reference_text_key_when_configured() -> None: + stage = ASRStage( + adapter_target=_QWEN_ADAPTER_TARGET, + model_id="mock/model", + reference_text_key="text", + ) + _required, optional_inputs = stage.inputs() + assert "text" in optional_inputs + + +def test_outputs_two_turn_includes_disfluency_key() -> None: + stage = ASRStage( + adapter_target=_QWEN_ADAPTER_TARGET, + model_id="mock/model", + disfluency_text_key="pred_text_secondary", + ) + _required, optional_outputs = stage.outputs() + assert "pred_text_secondary" in optional_outputs + + +# ---------------------------------------------------------------------- +# Stage-level: skip / metrics +# ---------------------------------------------------------------------- + + +def test_skipped_result_sets_skip_key() -> None: + stage = _make_stage() + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="", skipped=True), + ] + results = stage.process_batch([_make_task()]) + assert results[0].data["_skip_me"] == "empty_audio" + + +def test_metrics_account_skipped_utterances() -> None: + stage = _make_stage() + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="text"), + ASRResult(text="", skipped=True), + ] + stage._log_metrics = MagicMock() # type: ignore[method-assign] + + stage.process_batch([_make_task(), _make_task()]) + + metrics = stage._log_metrics.call_args[0][0] + assert metrics["utterances_input"] == 2.0 + assert metrics["utterances_processed"] == 1.0 + assert metrics["utterances_skipped"] == 1.0 + assert metrics["sub_chunks_generated"] == 2.0 + assert metrics["adapter_inference_calls"] == 1.0 + assert metrics["adapter_inference_items"] == 2.0 + + +def test_metrics_count_actual_adapter_inference_calls_after_chunk_splitting() -> None: + stage = _make_stage( + max_inference_duration_s=30.0, + batch_size=1, + ) + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="a")], + [ASRResult(text="b")], + [ASRResult(text="c")], + ] + stage._log_metrics = MagicMock() # type: ignore[method-assign] + + task = AudioTask(data={"waveform": np.zeros(_SR * 75, dtype=np.float32), "sample_rate": _SR}) + stage.process_batch([task]) + + metrics = stage._log_metrics.call_args[0][0] + assert metrics["utterances_input"] == 1.0 + assert metrics["sub_chunks_generated"] == 3.0 + assert metrics["adapter_inference_calls"] == 3.0 + assert metrics["adapter_inference_items"] == 3.0 + assert stage._adapter.transcribe_batch.call_count == 3 + + +def test_metrics_model_alias_skips_already_emitted_keys() -> None: + """Adapter metrics the stage already emits must NOT be re-aliased as ``model_``.""" + stage = _make_stage() + stage._adapter.transcribe_batch.return_value = [ASRResult(text="text")] + stage._adapter.last_metrics = { + "audio_duration_s": 999.0, + "extra_diagnostic_metric": 42.0, + } + stage._log_metrics = MagicMock() # type: ignore[method-assign] + + stage.process_batch([_make_task()]) + + metrics = stage._log_metrics.call_args[0][0] + assert "model_audio_duration_s" not in metrics + assert metrics["model_extra_diagnostic_metric"] == 42.0 + + +# ---------------------------------------------------------------------- +# Stage-level: setup_on_node weight prefetch + setup() adapter construction +# ---------------------------------------------------------------------- + + +@patch("nemo_curator.models.asr.qwen_omni.snapshot_download") +def test_setup_on_node_downloads_weights(mock_download: MagicMock) -> None: + stage = ASRStage(adapter_target=_QWEN_ADAPTER_TARGET, model_id="mock/model") + stage.setup_on_node() + mock_download.assert_called_once_with("mock/model") + + +@patch( + "nemo_curator.models.asr.qwen_omni.snapshot_download", + side_effect=RuntimeError("missing auth"), +) +def test_setup_on_node_raises_by_default(mock_download: MagicMock) -> None: + stage = ASRStage(adapter_target=_QWEN_ADAPTER_TARGET, model_id="mock/model") + with pytest.raises(RuntimeError, match="prefetch_weights failed"): + stage.setup_on_node() + mock_download.assert_called_once_with("mock/model") + + +@patch( + "nemo_curator.models.asr.qwen_omni.snapshot_download", + side_effect=RuntimeError("offline"), +) +def test_setup_on_node_can_warn_and_retry_later(mock_download: MagicMock) -> None: + stage = ASRStage( + adapter_target=_QWEN_ADAPTER_TARGET, + model_id="mock/model", + prefetch_fail_on_error=False, + ) + stage.setup_on_node() + mock_download.assert_called_once_with("mock/model") + + +def test_adapter_target_required() -> None: + with pytest.raises(TypeError): + ASRStage(model_id="mock/model") + + +def test_model_id_required() -> None: + with pytest.raises(TypeError): + ASRStage(adapter_target=_QWEN_ADAPTER_TARGET) + + +def test_max_inference_duration_defaults_to_qwen_window() -> None: + stage = ASRStage( + adapter_target=_QWEN_ADAPTER_TARGET, + model_id="mock/model", + ) + assert stage.max_inference_duration_s == 2400.0 + + +def test_setup_uses_adapter_target_and_kwargs() -> None: + """``setup()`` resolves adapter_target via hydra.utils.get_class and + constructs the adapter with model_id+revision+**adapter_kwargs.""" + adapter_kwargs = { + "prompt_file": "/tmp/ml.md", + "en_prompt_file": "/tmp/en.md", + "followup_prompt_file": "/tmp/followup.md", + "system_prompt_file": "/tmp/system.md", + "max_model_len": 8192, + "max_num_batched_tokens": 49152, + "max_num_seqs": 8, + "gpu_memory_utilization": 0.8, + "tensor_parallel_size": 4, + "max_output_tokens": 384, + "temperature": 0.1, + "top_k": 5, + "prep_workers": 16, + "enable_prefix_caching": False, + "prefix_caching_hash_algo": "sha256", + "limit_mm_per_prompt_audio": 1, + "seed": 999, + } + stage = ASRStage( + adapter_target=_QWEN_ADAPTER_TARGET, + model_id="mock/model", + revision="abc123", + adapter_kwargs=adapter_kwargs, + ) + + fake_adapter = MagicMock() + fake_cls = MagicMock(return_value=fake_adapter) + with patch("hydra.utils.get_class", return_value=fake_cls) as get_class: + stage.setup() + + get_class.assert_called_with(_QWEN_ADAPTER_TARGET) + fake_cls.assert_called_once_with( + model_id="mock/model", + revision="abc123", + **adapter_kwargs, + ) + fake_adapter.setup.assert_called_once_with() + assert stage._adapter is fake_adapter + + +def test_setup_on_node_prefetches_without_constructing_adapter() -> None: + """``setup_on_node()`` uses the adapter classmethod only. + + Adapter construction is reserved for ``setup()``, where stage-level + ``adapter_kwargs`` are forwarded to the worker-local adapter instance. + """ + stage = ASRStage( + adapter_target=_QWEN_ADAPTER_TARGET, + model_id="mock/model", + revision="abc123", + adapter_kwargs={"prompt_file": "/tmp/ml.md"}, + ) + + fake_cls = MagicMock() + fake_cls.prefetch_weights = MagicMock() + with patch("hydra.utils.get_class", return_value=fake_cls): + stage.setup_on_node() + + fake_cls.assert_not_called() + fake_cls.prefetch_weights.assert_called_once_with("mock/model", "abc123") diff --git a/tests/stages/audio/inference/test_batch_policy.py b/tests/stages/audio/inference/test_batch_policy.py new file mode 100644 index 0000000000..092a0da49d --- /dev/null +++ b/tests/stages/audio/inference/test_batch_policy.py @@ -0,0 +1,501 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the generic cost-bucketed batching primitives: ``BatchPolicy`` and ``run_bucketed``.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from nemo_curator.backends.base import BaseStageAdapter +from nemo_curator.models.asr.base import ASRResult +from nemo_curator.stages.audio.inference.asr import ASRStage +from nemo_curator.stages.audio.inference.batch_policy import BatchPolicy, BucketQueueScheduler, run_bucketed +from nemo_curator.tasks import AudioTask + +_QWEN_ADAPTER_TARGET = "nemo_curator.models.asr.qwen_omni.QwenOmniASRAdapter" +_SR = 16000 + + +def _make_stage( + *, + batch_policy: BatchPolicy | None = None, + batch_size: int = 32, + adapter_batch_size: int | None = None, +) -> ASRStage: + stage = ASRStage( + adapter_target=_QWEN_ADAPTER_TARGET, + model_id="mock/qwen-omni", + pred_text_key="qwen3_prediction_s1", + batch_policy=batch_policy, + batch_size=batch_size, + adapter_batch_size=adapter_batch_size, + ) + mock_adapter = MagicMock() + mock_adapter.last_metrics = {} + stage._adapter = mock_adapter + return stage + + +def _make_task(waveform_len: int = _SR) -> AudioTask: + return AudioTask(data={"waveform": np.zeros(waveform_len, dtype=np.float32), "sample_rate": _SR}) + + +# ---------------------------------------------------------------------- +# BatchPolicy: validation + bucket math +# ---------------------------------------------------------------------- + + +def test_batch_policy_invalid_strategy_rejected() -> None: + with pytest.raises(ValueError, match="duration_bucketed"): + BatchPolicy(strategy="token_bucketed") + + +def test_batch_policy_inconsistent_lengths_rejected() -> None: + with pytest.raises(ValueError, match="lengths must match"): + BatchPolicy(buckets_sec=[0, 60, 600], max_items_per_batch_by_bucket=[10, 5]) + + +def test_batch_policy_disabled_allows_placeholder_bucket_config() -> None: + policy = BatchPolicy( + enabled=False, + strategy="placeholder", + buckets_sec=[], + max_items_per_batch_by_bucket=[], + max_audio_sec_per_batch=-1.0, + ) + + assert policy.enabled is False + + +def test_batch_policy_enabled_must_be_bool() -> None: + with pytest.raises(TypeError, match="enabled must be a bool"): + BatchPolicy(enabled="false") # type: ignore[arg-type] + + +def test_batch_policy_prebatching_window_size_validation() -> None: + with pytest.raises(TypeError, match="prebatching_window_size must be an int or None"): + BatchPolicy(prebatching_window_size="8") # type: ignore[arg-type] + + with pytest.raises(ValueError, match="prebatching_window_size must be > 0"): + BatchPolicy(prebatching_window_size=0) + + policy = BatchPolicy( + enabled=False, + strategy="placeholder", + buckets_sec=[], + max_items_per_batch_by_bucket=[], + max_audio_sec_per_batch=-1.0, + prebatching_window_size=0, + ) + assert policy.enabled is False + + +def test_batch_policy_numeric_field_validation() -> None: + with pytest.raises(TypeError, match="flush_interval_ms must be an int"): + BatchPolicy(flush_interval_ms=250.5) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="flush_interval_ms must be >= 0"): + BatchPolicy(flush_interval_ms=-1) + + with pytest.raises(TypeError, match="buckets_sec entry must be numeric"): + BatchPolicy(buckets_sec=[0, "60"], max_items_per_batch_by_bucket=[1, 1]) # type: ignore[list-item] + + with pytest.raises(TypeError, match="max_items_per_batch_by_bucket entry must be an int"): + BatchPolicy(buckets_sec=[0, 60], max_items_per_batch_by_bucket=[1, True]) # type: ignore[list-item] + + with pytest.raises(ValueError, match="bucketed_inference_batch_size has 1 entries"): + BatchPolicy(buckets_sec=[0, 60], max_items_per_batch_by_bucket=[1, 1], bucketed_inference_batch_size=[1]) + + with pytest.raises(TypeError, match="bucketed_inference_batch_size entry must be an int"): + BatchPolicy( + buckets_sec=[0, 60], + max_items_per_batch_by_bucket=[1, 1], + bucketed_inference_batch_size=[1, True], # type: ignore[list-item] + ) + + with pytest.raises(TypeError, match="max_audio_sec_per_batch must be numeric or None"): + BatchPolicy(max_audio_sec_per_batch=True) # type: ignore[arg-type] + + +def test_batch_policy_bucket_for_clamps_above_top_edge() -> None: + """Left-edge semantics: bucket i covers [buckets_sec[i], buckets_sec[i+1]).""" + p = BatchPolicy(buckets_sec=[0, 60, 600], max_items_per_batch_by_bucket=[10, 5, 1]) + assert p.bucket_for(0.0) == 0 # [0, 60) + assert p.bucket_for(30.0) == 0 # [0, 60) + assert p.bucket_for(60.0) == 1 # boundary lands in the bucket that starts at 60 + assert p.bucket_for(599.0) == 1 # [60, 600) + assert p.bucket_for(600.0) == 2 # [600, +inf) + assert p.bucket_for(9999.0) == 2 # clamped into top bucket + + +def test_bucket_queue_scheduler_flushes_on_caps_timer_and_drain() -> None: + policy = BatchPolicy( + buckets_sec=[0, 60], + max_items_per_batch_by_bucket=[2, 2], + max_audio_sec_per_batch=100.0, + flush_interval_ms=50, + ) + scheduler = BucketQueueScheduler(policy) + + assert scheduler.enqueue(0, "short-a", 10.0, now_ms=0.0) == [] + item_cap_batch = scheduler.enqueue(1, "short-b", 20.0, now_ms=10.0) + assert [(batch.items, batch.total_cost, batch.flush_reason) for batch in item_cap_batch] == [ + (["short-a", "short-b"], 30.0, "item_cap") + ] + + assert scheduler.enqueue(2, "long-a", 70.0, now_ms=20.0) == [] + cost_overflow_batch = scheduler.enqueue(3, "long-b", 80.0, now_ms=30.0) + assert [(batch.items, batch.total_cost, batch.flush_reason) for batch in cost_overflow_batch] == [ + (["long-a"], 70.0, "capacity") + ] + assert [(batch.items, batch.flush_reason) for batch in scheduler.flush_all()] == [(["long-b"], "drain")] + + assert scheduler.enqueue(4, "timer-a", 5.0, now_ms=100.0) == [] + assert scheduler.flush_due(now_ms=149.0) == [] + timer_batch = scheduler.flush_due(now_ms=150.0) + assert [(batch.items, batch.flush_reason) for batch in timer_batch] == [(["timer-a"], "timer")] + + +def test_bucket_queue_scheduler_can_disable_timer_checks_for_finite_planning() -> None: + policy = BatchPolicy( + buckets_sec=[0], + max_items_per_batch_by_bucket=[10], + max_audio_sec_per_batch=None, + flush_interval_ms=1, + ) + scheduler = BucketQueueScheduler(policy, enable_timer=False) + + assert scheduler.enqueue(0, "a", 1.0, now_ms=0.0) == [] + assert scheduler.flush_due(now_ms=10.0) == [] + assert [(batch.items, batch.flush_reason) for batch in scheduler.flush_all()] == [(["a"], "drain")] + + +# ---------------------------------------------------------------------- +# run_bucketed: the shared, stage-agnostic dispatch helper +# ---------------------------------------------------------------------- + + +def test_run_bucketed_preserves_input_order_across_buckets() -> None: + """Results realign to input order regardless of internal bucket order.""" + policy = BatchPolicy( + buckets_sec=[0, 30, 1200], + max_items_per_batch_by_bucket=[32, 16, 8], + max_audio_sec_per_batch=None, + ) + # durations: long, short, long, short -> two buckets, interleaved input. + items = [{"d": 600.0, "v": "L0"}, {"d": 5.0, "v": "S1"}, {"d": 700.0, "v": "L2"}, {"d": 10.0, "v": "S3"}] + calls: list[list[str]] = [] + + def run_fn(sub: list[dict]) -> list[str]: + calls.append([it["v"] for it in sub]) + return [it["v"] for it in sub] + + out = run_bucketed(items, run_fn, cost_fn=lambda it: it["d"], policy=policy) + + assert out == ["L0", "S1", "L2", "S3"] + assert len(calls) == 2 # one per occupied bucket + + +def test_run_bucketed_without_policy_runs_single_call() -> None: + items = [{"d": 1.0}, {"d": 2.0}, {"d": 3.0}] + calls = 0 + + def run_fn(sub: list[dict]) -> list[int]: + nonlocal calls + calls += 1 + return list(range(len(sub))) + + out = run_bucketed(items, run_fn, cost_fn=lambda it: it["d"], policy=None) + + assert calls == 1 + assert out == [0, 1, 2] + + +def test_run_bucketed_disabled_policy_runs_single_call() -> None: + items = [{"d": 1.0}, {"d": 120.0}, {"d": 3.0}] + policy = BatchPolicy( + enabled=False, + buckets_sec=[0, 60], + max_items_per_batch_by_bucket=[1, 1], + max_audio_sec_per_batch=None, + ) + calls: list[list[float]] = [] + + def run_fn(sub: list[dict]) -> list[float]: + calls.append([it["d"] for it in sub]) + return [it["d"] for it in sub] + + out = run_bucketed(items, run_fn, cost_fn=lambda it: it["d"], policy=policy) + + assert out == [1.0, 120.0, 3.0] + assert calls == [[1.0, 120.0, 3.0]] + + +def test_run_bucketed_empty_items_short_circuits() -> None: + def run_fn(_sub: list) -> list: + msg = "run_fn must not be called for empty items" + raise AssertionError(msg) + + assert run_bucketed([], run_fn, cost_fn=lambda _it: 0.0) == [] + + +def test_run_bucketed_mismatched_result_count_raises() -> None: + def run_fn(_sub: list) -> list: + return ["only-one"] + + with pytest.raises(RuntimeError, match=r"returned 1 results for 2 items"): + run_bucketed([{"d": 1.0}, {"d": 2.0}], run_fn, cost_fn=lambda it: it["d"]) + + +def test_asr_process_batch_buckets_mixed_duration_items() -> None: + """ASRStage buckets model items inside the backend-provided batch.""" + policy = BatchPolicy( + strategy="duration_bucketed", + buckets_sec=[0, 30, 1200], + max_items_per_batch_by_bucket=[32, 16, 8], + max_audio_sec_per_batch=None, + ) + stage = _make_stage(batch_policy=policy) + long_task = _make_task(_SR * 600) + short_a = _make_task(_SR * 5) + short_b = _make_task(_SR * 10) + + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="long")], + [ASRResult(text="short-a"), ASRResult(text="short-b")], + ] + + results = BaseStageAdapter(stage).process_batch([long_task, short_a, short_b]) + + assert stage._adapter.transcribe_batch.call_count == 2 + durations_by_call = [ + [item["audio_seconds"] for item in call.args[0]] for call in stage._adapter.transcribe_batch.call_args_list + ] + assert durations_by_call == [[600.0], [5.0, 10.0]] + assert results[0].data["qwen3_prediction_s1"] == "long" + assert results[1].data["qwen3_prediction_s1"] == "short-a" + assert results[2].data["qwen3_prediction_s1"] == "short-b" + + +def test_asr_process_batch_buckets_long_row_tails_with_matching_segments() -> None: + """Long-row tails share an adapter call with matching-duration segments.""" + policy = BatchPolicy( + strategy="duration_bucketed", + buckets_sec=[0, 600, 1200, 2400], + max_items_per_batch_by_bucket=[32, 16, 8, 4], + max_audio_sec_per_batch=2400.0, + ) + stage = _make_stage(batch_policy=policy) + long_50m = _make_task(_SR * 3000) + ten_min = _make_task(_SR * 600) + tiny = _make_task(_SR * 5) + + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="long-40m")], + [ASRResult(text="tail"), ASRResult(text="ten")], + [ASRResult(text="tiny")], + ] + + results = BaseStageAdapter(stage).process_batch([long_50m, ten_min, tiny]) + + durations_by_call = [ + [item["audio_seconds"] for item in call.args[0]] for call in stage._adapter.transcribe_batch.call_args_list + ] + assert durations_by_call == [[2400.0], [600.0, 600.0], [5.0]] + assert results[0].data["qwen3_prediction_s1"] == "long-40m tail" + assert results[1].data["qwen3_prediction_s1"] == "ten" + assert results[2].data["qwen3_prediction_s1"] == "tiny" + + +def test_base_adapter_centralizes_asr_chunking_and_bucketing(monkeypatch: pytest.MonkeyPatch) -> None: + policy = BatchPolicy( + strategy="duration_bucketed", + buckets_sec=[0, 600, 1200, 2400], + max_items_per_batch_by_bucket=[32, 16, 8, 4], + max_audio_sec_per_batch=2400.0, + ) + stage = _make_stage(batch_policy=policy) + long_50m = _make_task(_SR * 3000) + ten_min = _make_task(_SR * 600) + tiny = _make_task(_SR * 5) + + chunk_waveform_calls = 0 + original_chunk_waveform = stage._chunk_waveform + + def counting_chunk_waveform(waveform: np.ndarray, sample_rate: int, max_seconds: float) -> list[np.ndarray]: + nonlocal chunk_waveform_calls + chunk_waveform_calls += 1 + return original_chunk_waveform(waveform, sample_rate, max_seconds) + + monkeypatch.setattr(stage, "_chunk_waveform", counting_chunk_waveform) + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="long-40m")], + [ASRResult(text="tail"), ASRResult(text="ten")], + [ASRResult(text="tiny")], + ] + + results = BaseStageAdapter(stage).process_batch([long_50m, ten_min, tiny]) + + durations_by_call = [ + [item["audio_seconds"] for item in call.args[0]] for call in stage._adapter.transcribe_batch.call_args_list + ] + assert durations_by_call == [[2400.0], [600.0, 600.0], [5.0]] + assert chunk_waveform_calls == 3 + assert results[0].data["qwen3_prediction_s1"] == "long-40m tail" + assert results[1].data["qwen3_prediction_s1"] == "ten" + assert results[2].data["qwen3_prediction_s1"] == "tiny" + + +def test_asr_batch_policy_partitions_items_by_bucket() -> None: + policy = BatchPolicy( + strategy="duration_bucketed", + buckets_sec=[0, 30, 1200, 2400], + max_items_per_batch_by_bucket=[32, 16, 8, 4], + max_audio_sec_per_batch=10_000.0, + flush_interval_ms=250, + ) + stage = _make_stage(batch_policy=policy) + short_a = _make_task(_SR * 5) + short_b = _make_task(_SR * 10) + long_a = _make_task(_SR * 600) + + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="long")], + [ASRResult(text="short-a"), ASRResult(text="short-b")], + ] + results = BaseStageAdapter(stage).process_batch([short_a, short_b, long_a]) + + assert stage._adapter.transcribe_batch.call_count == 2 + assert results[0].data["qwen3_prediction_s1"] == "short-a" + assert results[1].data["qwen3_prediction_s1"] == "short-b" + assert results[2].data["qwen3_prediction_s1"] == "long" + + +def test_asr_batch_policy_respects_per_bucket_item_cap() -> None: + policy = BatchPolicy( + strategy="duration_bucketed", + buckets_sec=[0, 60], + max_items_per_batch_by_bucket=[2, 1], + max_audio_sec_per_batch=None, + ) + stage = _make_stage(batch_policy=policy) + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="a"), ASRResult(text="b")], + [ASRResult(text="c")], + ] + + BaseStageAdapter(stage).process_batch([_make_task() for _ in range(3)]) + + assert stage._adapter.transcribe_batch.call_count == 2 + + +def test_asr_batch_policy_respects_audio_sec_cap() -> None: + policy = BatchPolicy( + strategy="duration_bucketed", + buckets_sec=[0, 60], + max_items_per_batch_by_bucket=[100, 100], + max_audio_sec_per_batch=15.0, + ) + stage = _make_stage(batch_policy=policy) + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="a")], + [ASRResult(text="b")], + ] + + BaseStageAdapter(stage).process_batch([_make_task(_SR * 10), _make_task(_SR * 10)]) + + assert stage._adapter.transcribe_batch.call_count == 2 + + +def test_asr_bucketed_inference_batch_size_controls_adapter_call_cap() -> None: + policy = BatchPolicy( + strategy="duration_bucketed", + buckets_sec=[0, 60, 600], + max_items_per_batch_by_bucket=[10, 10, 10], + bucketed_inference_batch_size=[3, 2, 1], + max_audio_sec_per_batch=None, + ) + stage = _make_stage(batch_policy=policy) + stage.batch_size = 99 + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="long-a")], + [ASRResult(text="long-b")], + [ASRResult(text="medium-a"), ASRResult(text="medium-b")], + [ASRResult(text="short-a"), ASRResult(text="short-b"), ASRResult(text="short-c")], + ] + + BaseStageAdapter(stage).process_batch( + [ + _make_task(_SR * 5), + _make_task(_SR * 10), + _make_task(_SR * 15), + _make_task(_SR * 120), + _make_task(_SR * 130), + _make_task(_SR * 700), + _make_task(_SR * 710), + ] + ) + + durations_by_call = [ + [item["audio_seconds"] for item in call.args[0]] for call in stage._adapter.transcribe_batch.call_args_list + ] + assert durations_by_call == [[700.0], [710.0], [120.0, 130.0], [5.0, 10.0, 15.0]] + + +def test_asr_batch_policy_none_runs_single_adapter_call() -> None: + stage = _make_stage(batch_policy=None) + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="a"), + ASRResult(text="b"), + ] + + stage.process_batch([_make_task(), _make_task()]) + + assert stage._adapter.transcribe_batch.call_count == 1 + + +def test_asr_adapter_batch_size_is_independent_from_backend_window() -> None: + stage = _make_stage(batch_policy=None, batch_size=32, adapter_batch_size=1) + stage._adapter.transcribe_batch.side_effect = [ + [ASRResult(text="a")], + [ASRResult(text="b")], + [ASRResult(text="c")], + ] + + stage.process_batch([_make_task(), _make_task(), _make_task()]) + + call_sizes = [len(call.args[0]) for call in stage._adapter.transcribe_batch.call_args_list] + assert call_sizes == [1, 1, 1] + + +def test_asr_batch_policy_disabled_runs_single_adapter_call() -> None: + policy = BatchPolicy( + enabled=False, + buckets_sec=[0, 30, 1200, 2400], + max_items_per_batch_by_bucket=[1, 1, 1, 1], + max_audio_sec_per_batch=None, + ) + stage = _make_stage(batch_policy=policy) + stage._adapter.transcribe_batch.return_value = [ + ASRResult(text="short"), + ASRResult(text="long"), + ] + + stage.process_batch([_make_task(_SR * 5), _make_task(_SR * 600)]) + + assert stage._adapter.transcribe_batch.call_count == 1 diff --git a/tests/stages/audio/inference/test_bucketed_stage.py b/tests/stages/audio/inference/test_bucketed_stage.py new file mode 100644 index 0000000000..e495b93791 --- /dev/null +++ b/tests/stages/audio/inference/test_bucketed_stage.py @@ -0,0 +1,101 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the generic ``BucketedInferenceStage`` base via a minimal non-audio subclass.""" + +from __future__ import annotations + +import numpy as np + +from nemo_curator.stages.audio.inference.batch_policy import BatchPolicy +from nemo_curator.stages.audio.inference.bucketed_stage import BucketedInferenceStage +from nemo_curator.tasks import AudioTask + + +class _SumBucketStage(BucketedInferenceStage): + """Minimal stage: fan ``vals`` into items, x10 each, then sum results back per parent task.""" + + name = "test_sum_bucket" + + def process(self, task: AudioTask) -> AudioTask: + raise NotImplementedError + + def build_items(self, tasks: list[AudioTask]) -> tuple[list[float], list[int]]: + items: list[float] = [] + parent_of: list[int] = [] + for i, t in enumerate(tasks): + for v in t.data["vals"]: + items.append(v) + parent_of.append(i) + return items, parent_of + + def item_cost(self, item: float) -> float: + return float(item) + + def run_inference(self, items: list[float]) -> list[float]: + self.calls.append(list(items)) + return [v * 10 for v in items] + + def assemble( + self, + tasks: list[AudioTask], + items: list[float], + parent_of: list[int], + results: list[float], + ) -> list[AudioTask]: + sums = [0.0 for _ in tasks] + for r, p in zip(results, parent_of, strict=True): + sums[p] += r + for t, s in zip(tasks, sums, strict=True): + t.data["out"] = s + return tasks + + +def test_bucketed_inference_stage_fans_out_buckets_and_reassembles() -> None: + """The base drives build_items -> bucketed dispatch -> assemble, one output per input task.""" + stage = _SumBucketStage() + stage.calls = [] + stage.batch_policy = BatchPolicy( + buckets_sec=[0, 30], + max_items_per_batch_by_bucket=[8, 8], + max_audio_sec_per_batch=None, + ) + t0 = AudioTask(data={"vals": [5.0, 100.0]}) # short + long -> two buckets + t1 = AudioTask(data={"vals": [10.0]}) # short + + out = stage.process_batch([t0, t1]) + + assert out == [t0, t1] + assert t0.data["out"] == (5.0 + 100.0) * 10 + assert t1.data["out"] == 10.0 * 10 + assert len(stage.calls) == 2 # one dispatch per occupied bucket + + +def test_bucketed_inference_stage_empty_batch_short_circuits() -> None: + stage = _SumBucketStage() + stage.calls = [] + assert stage.process_batch([]) == [] + assert stage.calls == [] + + +def test_bucketed_inference_stage_accepts_numpy_task_batch() -> None: + """Ray Data passes ``map_batches`` columns as ndarrays — not Python lists.""" + stage = _SumBucketStage() + stage.calls = [] + t0 = AudioTask(data={"vals": [2.0]}) + batch = np.array([t0], dtype=object) + out = stage.process_batch(batch) + assert out == [t0] + assert t0.data["out"] == 20.0 + assert len(stage.calls) == 1 diff --git a/tests/stages/audio/io/test_audio_file_reader.py b/tests/stages/audio/io/test_audio_file_reader.py new file mode 100644 index 0000000000..5fd5b3d21c --- /dev/null +++ b/tests/stages/audio/io/test_audio_file_reader.py @@ -0,0 +1,152 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +from collections.abc import Callable +from math import isclose +from pathlib import Path + +import pytest +import torch + +from nemo_curator.backends.utils import RayStageSpecKeys +from nemo_curator.stages.audio.io.audio_file_reader import AudioFileReaderStage +from nemo_curator.tasks import AudioTask + + +def test_audio_file_reader_process(audio_task: Callable[..., AudioTask], audio_filepath: Path) -> None: + stage = AudioFileReaderStage(target_sample_rate=16000, target_nchannels=1) + stage.setup_on_node() + task = audio_task( + audio_filepath=str(audio_filepath), + audio_item_id="id_1", + ) + + result = stage.process(task) + + out = result.data + assert out.get("audio_filepath") == str(audio_filepath) + assert out.get("sample_rate") == 16000 + assert out.get("is_mono") is True + assert out["waveform"].shape[0] == 1 + assert out.get("num_samples") == out["waveform"].shape[-1] + assert isclose(out.get("duration"), 60.0) + + +def test_audio_file_reader_process_batch_validates_audio_filepath() -> None: + stage = AudioFileReaderStage(target_sample_rate=16000, target_nchannels=1) + + assert stage.inputs() == ([], ["audio_filepath"]) + with pytest.raises(ValueError, match="failed validation"): + stage.process_batch([AudioTask(data={"duration": 1.0})]) + + +def test_audio_file_reader_skip_on_read_error( + audio_task: Callable[..., AudioTask], + monkeypatch: pytest.MonkeyPatch, +) -> None: + stage = AudioFileReaderStage(target_sample_rate=16000, target_nchannels=1) + monkeypatch.setattr( + stage, + "_load_waveform", + lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("decode lost")), + ) + task = audio_task(audio_filepath="/local/audio/missing.opus") + + result = stage.process(task) + + out = result.data + assert out["_skip_me"] == "audio_read_error" + assert "decode lost" in out["audio_read_error"] + assert out["waveform"].shape == (1, 0) + assert out["sample_rate"] == 16000 + assert out["duration"] == 0.0 + + +def test_audio_file_reader_respects_custom_output_keys( + audio_task: Callable[..., AudioTask], + monkeypatch: pytest.MonkeyPatch, +) -> None: + stage = AudioFileReaderStage( + waveform_key="custom_waveform", + sample_rate_key="custom_sample_rate", + num_samples_key="custom_num_samples", + ) + monkeypatch.setattr(stage, "_load_waveform", lambda *_args, **_kwargs: (torch.zeros(1, 123), 16000)) + task = audio_task(audio_filepath="/local/audio/example.opus") + + result = stage.process(task) + + assert result.data["custom_waveform"].shape == (1, 123) + assert result.data["custom_sample_rate"] == 16000 + assert result.data["custom_num_samples"] == 123 + assert "waveform" not in result.data + assert "sample_rate" not in result.data + assert "num_samples" not in result.data + + +def test_audio_file_reader_worker_specs() -> None: + stage = AudioFileReaderStage(ray_num_workers=2) + + assert stage.num_workers() == 2 + assert stage.ray_stage_spec()[RayStageSpecKeys.IS_ACTOR_STAGE] is True + + stage = AudioFileReaderStage(xenna_num_workers_per_node=1) + assert stage.num_workers() is None + assert stage.xenna_stage_spec() == {"num_workers_per_node": 1} + + +def test_audio_file_reader_rejects_conflicting_xenna_worker_specs() -> None: + with pytest.raises(ValueError, match="set at most one"): + AudioFileReaderStage(xenna_num_workers=2, xenna_num_workers_per_node=1) + + +def test_audio_file_reader_rejects_remote_paths( + audio_task: Callable[..., AudioTask], +) -> None: + stage = AudioFileReaderStage(skip_on_read_error=True) + task = audio_task(audio_filepath="s3://bucket/path/audio.opus") + + with pytest.raises(ValueError, match="only accepts local audio paths"): + stage.process(task) + + +def test_audio_file_reader_uses_ffmpeg_seek_for_segments( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + audio_path = tmp_path / "audio.wav" + audio_path.write_bytes(b"placeholder") + seen_cmd: list[str] = [] + + def fake_run_ffmpeg(cmd: list[str]) -> subprocess.CompletedProcess[bytes]: + seen_cmd[:] = cmd + samples = torch.zeros(16000, dtype=torch.float32).numpy().tobytes() + return subprocess.CompletedProcess(cmd, 0, stdout=samples, stderr=b"") + + stage = AudioFileReaderStage(target_sample_rate=16000, target_nchannels=1) + monkeypatch.setattr(stage, "_run_ffmpeg", fake_run_ffmpeg) + + waveform, sample_rate = stage._load_waveform( + str(audio_path), + segment_start_s=12.5, + segment_duration_s=30.0, + ) + + assert waveform.shape == (1, 16000) + assert sample_rate == 16000 + assert seen_cmd[:6] == ["ffmpeg", "-v", "error", "-ss", "12.5", "-i"] + assert seen_cmd[6] == str(audio_path) + assert "-t" in seen_cmd + assert seen_cmd[seen_cmd.index("-t") + 1] == "30" diff --git a/tests/stages/audio/io/test_nemo_tarred_reader.py b/tests/stages/audio/io/test_nemo_tarred_reader.py new file mode 100644 index 0000000000..433a413a6a --- /dev/null +++ b/tests/stages/audio/io/test_nemo_tarred_reader.py @@ -0,0 +1,252 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import json +import tarfile +from pathlib import Path + +import numpy as np +import pytest +import soundfile as sf + +from nemo_curator.stages.audio.io.nemo_tarred_reader import ( + NemoTarredAudioReader, + NemoTarShardDiscoveryStage, + NemoTarShardReaderStage, + _expand_nemo_path, + _iter_discovery_groups, + _iter_input_cfg_entries, +) +from nemo_curator.tasks import FileGroupTask + + +@pytest.mark.parametrize( + "pattern", + [ + "manifest__OP_5..3_CL_.jsonl", + "audio__OP_10..0_CL_.tar", + ], +) +def test_expand_nemo_path_rejects_reversed_ranges(pattern: str) -> None: + with pytest.raises(ValueError, match="must be ascending"): + _expand_nemo_path(pattern) + + +def test_reader_composite_forwards_duration_filter() -> None: + reader = NemoTarredAudioReader( + yaml_path="data_config.yaml", + duration_key="segment_duration", + max_duration_s=40.0, + ) + + shard_reader = reader.decompose()[1] + assert isinstance(shard_reader, NemoTarShardReaderStage) + assert shard_reader.duration_key == "segment_duration" + assert shard_reader.max_duration_s == 40.0 + + +def test_reader_rejects_non_positive_max_duration() -> None: + with pytest.raises(ValueError, match="max_duration_s"): + NemoTarShardReaderStage(max_duration_s=0) + + +def test_reader_manifest_lookup_accepts_common_path_variants(tmp_path: Path) -> None: + manifest = tmp_path / "manifest.jsonl" + manifest.write_text( + '{"audio_filepath": "/data/shard/audio_0.wav", "duration": 1.0}\n', + encoding="utf-8", + ) + stage = NemoTarShardReaderStage() + + lookup, entry_count = stage._read_manifest(str(manifest)) + + assert entry_count == 1 + assert lookup["/data/shard/audio_0.wav"]["duration"] == 1.0 + assert lookup["data/shard/audio_0.wav"]["duration"] == 1.0 + assert lookup["audio_0.wav"]["duration"] == 1.0 + + +def test_manifest_lookup_disambiguates_shared_basename_by_path_suffix(tmp_path: Path) -> None: + """When two entries share a basename, a member resolves by longest path suffix.""" + manifest = tmp_path / "manifest.jsonl" + manifest.write_text( + '{"audio_filepath": "spk_a/utt.wav", "duration": 1.0}\n{"audio_filepath": "spk_b/utt.wav", "duration": 2.0}\n', + encoding="utf-8", + ) + stage = NemoTarShardReaderStage() + + lookup, _ = stage._read_manifest(str(manifest)) + + # Full path wins outright. + assert lookup.match("spk_a/utt.wav")["duration"] == 1.0 + assert lookup.match("spk_b/utt.wav")["duration"] == 2.0 + # A bare, ambiguous basename resolves to nothing rather than a wrong entry. + assert lookup.match("utt.wav") is None + + +def test_read_manifest_skips_lines_missing_filepath_key(tmp_path: Path) -> None: + manifest = tmp_path / "manifest.jsonl" + manifest.write_text( + '{"audio_filepath": "a.wav", "duration": 1.0}\n' + '{"duration": 2.0}\n' + '{"audio_filepath": "c.wav", "duration": 3.0}\n', + encoding="utf-8", + ) + stage = NemoTarShardReaderStage() + + lookup, entry_count = stage._read_manifest(str(manifest)) + + assert entry_count == 3 + assert lookup["a.wav"]["duration"] == 1.0 + assert lookup["c.wav"]["duration"] == 3.0 + + +def test_read_manifest_skips_invalid_json_lines(tmp_path: Path) -> None: + manifest = tmp_path / "manifest.jsonl" + manifest.write_text( + '{"audio_filepath": "ok.wav", "duration": 1.0}\nnot json\n', + encoding="utf-8", + ) + stage = NemoTarShardReaderStage() + + lookup, entry_count = stage._read_manifest(str(manifest)) + + assert entry_count == 2 + assert lookup["ok.wav"]["duration"] == 1.0 + + +def test_reader_skips_tar_members_when_extractfile_returns_none( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Tar members where extractfile() returns None must be skipped, not raise AttributeError.""" + audio = np.zeros(16000, dtype=np.float32) + wav_buf = io.BytesIO() + sf.write(wav_buf, audio, 16000, format="WAV") + wav_bytes = wav_buf.getvalue() + + tar_path = tmp_path / "shard.tar" + with tarfile.open(tar_path, "w") as tf: + info = tarfile.TarInfo(name="audio_0.wav") + info.size = len(wav_bytes) + tf.addfile(info, io.BytesIO(wav_bytes)) + + manifest_path = tmp_path / "manifest.jsonl" + manifest_path.write_text( + json.dumps({"audio_filepath": "audio_0.wav", "duration": 1.0}) + "\n", + encoding="utf-8", + ) + + def _extractfile_returns_none(_self: tarfile.TarFile, _member: tarfile.TarInfo) -> None: + return None + + monkeypatch.setattr(tarfile.TarFile, "extractfile", _extractfile_returns_none) + + stage = NemoTarShardReaderStage() + task = FileGroupTask( + dataset_name="test", + data=[str(manifest_path), str(tar_path)], + reader_config={"corpus": "test", "shard_key": "test_shard"}, + ) + task.task_id = "test_shard" + + results = stage.process(task) + + assert results == [] + + +def test_iter_discovery_groups_rejects_empty_yaml() -> None: + with pytest.raises(ValueError, match="empty"): + _iter_discovery_groups(None, "bad.yaml") + + +@pytest.mark.parametrize("config", [{"corpus": "x"}, "scalar", 42]) +def test_iter_discovery_groups_rejects_non_list_root(config: object) -> None: + with pytest.raises(TypeError, match="must be a list"): + _iter_discovery_groups(config, "bad.yaml") + + +def test_iter_discovery_groups_skips_non_mapping_entries() -> None: + groups = _iter_discovery_groups([{"input_cfg": []}, "skip-me", {"input_cfg": []}], "ok.yaml") + assert len(groups) == 2 + + +def test_iter_input_cfg_entries_skips_non_list_and_non_mapping() -> None: + assert _iter_input_cfg_entries({"input_cfg": "bad"}, "ok.yaml") == [] + assert _iter_input_cfg_entries({"input_cfg": ["bad", {"corpus": "c"}]}, "ok.yaml") == [{"corpus": "c"}] + + +def test_discovery_process_raises_on_empty_yaml(tmp_path: Path) -> None: + yaml_path = tmp_path / "empty.yaml" + yaml_path.write_text("", encoding="utf-8") + stage = NemoTarShardDiscoveryStage(yaml_path=str(yaml_path)) + + with pytest.raises(ValueError, match="empty"): + stage.process(None) # type: ignore[arg-type] + + +def test_discovery_process_raises_on_scalar_yaml_root(tmp_path: Path) -> None: + yaml_path = tmp_path / "scalar.yaml" + yaml_path.write_text("just_a_string\n", encoding="utf-8") + stage = NemoTarShardDiscoveryStage(yaml_path=str(yaml_path)) + + with pytest.raises(TypeError, match="must be a list"): + stage.process(None) # type: ignore[arg-type] + + +def test_discovery_skips_corpus_missing_required_paths(tmp_path: Path) -> None: + yaml_path = tmp_path / "data.yaml" + yaml_path.write_text( + """ +- input_cfg: + - corpus: broken + type: nemo_tarred + manifest_filepath: /data/broken/manifest_0.jsonl + - corpus: good + type: nemo_tarred + manifest_filepath: /data/good/manifest_0.jsonl + tarred_audio_filepaths: /data/good/audio_0.tar +""".lstrip(), + encoding="utf-8", + ) + stage = NemoTarShardDiscoveryStage(yaml_path=str(yaml_path)) + + tasks = stage.process(None) # type: ignore[arg-type] + + assert [task.task_id for task in tasks] == ["good/manifest_0"] + assert tasks[0].data == ["/data/good/manifest_0.jsonl", "/data/good/audio_0.tar"] + + +def test_discovery_skips_manifest_path_that_cannot_map_to_corpus(tmp_path: Path) -> None: + yaml_path = tmp_path / "data.yaml" + yaml_path.write_text( + """ +- input_cfg: + - corpus: good + type: nemo_tarred + manifest_filepath: /data/other/manifest_0.jsonl + tarred_audio_filepaths: /data/other/audio_0.tar + - corpus: good + type: nemo_tarred + manifest_filepath: /data/good/manifest_1.jsonl + tarred_audio_filepaths: /data/good/audio_1.tar +""".lstrip(), + encoding="utf-8", + ) + stage = NemoTarShardDiscoveryStage(yaml_path=str(yaml_path)) + + tasks = stage.process(None) # type: ignore[arg-type] + + assert [task.task_id for task in tasks] == ["good/manifest_1"] diff --git a/tests/stages/audio/io/test_sharded_manifest_writer.py b/tests/stages/audio/io/test_sharded_manifest_writer.py new file mode 100644 index 0000000000..5ccdb041dd --- /dev/null +++ b/tests/stages/audio/io/test_sharded_manifest_writer.py @@ -0,0 +1,166 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path + +import numpy as np + +from nemo_curator.stages.audio.io.sharded_manifest_writer import ShardedManifestWriterStage +from nemo_curator.tasks import AudioTask + + +def test_writer_drops_waveform_and_writes_final_manifest_on_completion(tmp_path: Path) -> None: + final_manifest = tmp_path / "output.jsonl" + stage = ShardedManifestWriterStage( + output_dir=str(tmp_path), + final_manifest_path=str(final_manifest), + write_perf_stats=False, + ) + stage.setup_on_node() + + task = AudioTask( + dataset_name="test", + data={ + "audio_filepath": "utt-1.wav", + "duration": 1.0, + "waveform": np.zeros(4, dtype=np.float32), + "qwen3_prediction_s1": "hello", + }, + _metadata={"_shard_key": "corpus/manifest_0", "_shard_total": 1}, + ) + task.task_id = "utt-1" + + result = stage.process(task) + + shard_path = tmp_path / "corpus" / "manifest_0.jsonl" + assert result.data == [str(shard_path)] + shard_row = json.loads(shard_path.read_text(encoding="utf-8").strip()) + assert "waveform" not in shard_row + assert shard_row["qwen3_prediction_s1"] == "hello" + assert json.loads(final_manifest.read_text(encoding="utf-8").strip()) == shard_row + + stage.teardown() + + final_row = json.loads(final_manifest.read_text(encoding="utf-8").strip()) + assert shard_row == final_row + + +def test_writer_appends_each_completed_shard_to_final_manifest_once(tmp_path: Path) -> None: + final_manifest = tmp_path / "output.jsonl" + stage = ShardedManifestWriterStage( + output_dir=str(tmp_path), + final_manifest_path=str(final_manifest), + write_perf_stats=False, + ) + stage.setup_on_node() + + first = AudioTask( + dataset_name="test", + data={"audio_filepath": "utt-1.wav", "duration": 1.0}, + _metadata={"_shard_key": "corpus/manifest_0", "_shard_total": 1}, + ) + first.task_id = "utt-1" + second = AudioTask( + dataset_name="test", + data={"audio_filepath": "utt-2.wav", "duration": 1.0}, + _metadata={"_shard_key": "corpus/manifest_1", "_shard_total": 1}, + ) + second.task_id = "utt-2" + + stage.process(first) + stage.process(second) + + rows = [json.loads(line) for line in final_manifest.read_text(encoding="utf-8").splitlines()] + assert [row["audio_filepath"] for row in rows] == ["utt-1.wav", "utt-2.wav"] + + stage.teardown() + + rows = [json.loads(line) for line in final_manifest.read_text(encoding="utf-8").splitlines()] + assert [row["audio_filepath"] for row in rows] == ["utt-1.wav", "utt-2.wav"] + + +def test_writer_rebuilds_final_manifest_from_completed_shards_on_teardown(tmp_path: Path) -> None: + final_manifest = tmp_path / "output.jsonl" + final_manifest.write_text('{"audio_filepath": "stale.wav"}\n', encoding="utf-8") + shard_path = tmp_path / "corpus" / "manifest_0.jsonl" + shard_path.parent.mkdir(parents=True) + shard_path.write_text('{"audio_filepath": "fresh.wav"}\n', encoding="utf-8") + done_path = tmp_path / "corpus" / "manifest_0.jsonl.done" + done_path.write_text("1\n", encoding="utf-8") + stage = ShardedManifestWriterStage( + output_dir=str(tmp_path), + final_manifest_path=str(final_manifest), + write_perf_stats=False, + ) + + stage.setup_on_node() + stage.teardown() + + assert final_manifest.read_text(encoding="utf-8") == '{"audio_filepath": "fresh.wav"}\n' + + +def test_writer_perf_summary_splits_invocations_and_items(tmp_path: Path) -> None: + stage = ShardedManifestWriterStage(output_dir=str(tmp_path), write_perf_stats=True) + stage.setup_on_node() + tasks = [ + AudioTask( + dataset_name="test", + data={"audio_filepath": f"utt-{i}.wav", "duration": 1.0}, + _metadata={"_shard_key": "corpus/manifest_0", "_shard_total": 2}, + ) + for i in range(2) + ] + for i, task in enumerate(tasks): + task.task_id = f"utt-{i}" + + stage.process_batch(tasks) + stage.teardown() + + writer_summary = json.loads((tmp_path / "perf_summary.json").read_text(encoding="utf-8"))["stages"][ + "sharded_manifest_writer" + ] + assert writer_summary["total_items_processed"] == 2.0 + assert writer_summary["invocation_count"] == 1.0 + assert writer_summary["custom_metrics_sum"]["writer_process_calls"] == 1.0 + assert writer_summary["custom_metrics_sum"]["writer_invocation_count"] == 1.0 + assert writer_summary["custom_metrics_sum"]["writer_items_processed"] == 2.0 + + +def test_writer_preserves_final_manifest_when_done_markers_exist(tmp_path: Path) -> None: + final_manifest = tmp_path / "output.jsonl" + final_manifest.write_text('{"audio_filepath": "old.wav"}\n', encoding="utf-8") + done_path = tmp_path / "corpus" / "manifest_0.jsonl.done" + done_path.parent.mkdir(parents=True) + done_path.write_text("1\n", encoding="utf-8") + stage = ShardedManifestWriterStage( + output_dir=str(tmp_path), + final_manifest_path=str(final_manifest), + write_perf_stats=False, + ) + + stage.setup_on_node() + + assert final_manifest.read_text(encoding="utf-8") == '{"audio_filepath": "old.wav"}\n' + + +def test_writer_teardown_does_not_overwrite_existing_perf_summary_without_new_tasks(tmp_path: Path) -> None: + perf_summary = tmp_path / "perf_summary.json" + perf_summary.write_text('{"existing": true}\n', encoding="utf-8") + stage = ShardedManifestWriterStage(output_dir=str(tmp_path), write_perf_stats=True) + + stage.setup_on_node() + stage.teardown() + + assert json.loads(perf_summary.read_text(encoding="utf-8")) == {"existing": True} diff --git a/tests/stages/audio/metrics/test_perf_summary.py b/tests/stages/audio/metrics/test_perf_summary.py new file mode 100644 index 0000000000..1b5514bae5 --- /dev/null +++ b/tests/stages/audio/metrics/test_perf_summary.py @@ -0,0 +1,360 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the identity-driven perf summary and per-actor (GPU/CPU) scheduling breakdown.""" + +from __future__ import annotations + +from nemo_curator.stages.audio.metrics.performance import ( + AudioPerformanceSummary, + serialize_stage_perf, +) +from nemo_curator.utils.performance_utils import StagePerfStats + + +def _perf( + *, + addr: str = "", + actor_id: str = "", + idle: float = 0.0, + items: int = 32, + audio_s: float = 100.0, +) -> StagePerfStats: + """A GPU-stage record keyed by physical address ``:`` (blank addr -> CPU stage).""" + node_id = "" + gpu_indices: list[int] = [] + if addr: + host, _, idx_part = addr.rpartition(":") + node_id = host + gpu_indices = [int(x) for x in idx_part.split(",") if x.strip()] + return StagePerfStats( + stage_name="QwenOmni_inference", + process_time=1.0, + actor_idle_time=idle, + num_items_processed=items, + custom_metrics={"audio_duration_s": audio_s, "utterances_input": float(items)}, + actor_id=actor_id, + node_id=node_id, + physical_address=addr, + gpu_indices=gpu_indices, + ) + + +# ---------------------------------------------------------------------- +# Serialization + fingerprint +# ---------------------------------------------------------------------- + + +def test_serialize_stage_perf_carries_identity_when_present() -> None: + [entry] = serialize_stage_perf( + [ + StagePerfStats( + stage_name="QwenOmni_inference", + process_time=1.0, + actor_id="S:actor-ab", + node_id="node-1", + gpu_id="node-1:2", + physical_address="10.0.0.5:2", + ) + ] + ) + assert entry["physical_address"] == "10.0.0.5:2" # canonical address + assert entry["gpu_id"] == "node-1:2" # legacy label still carried + assert entry["actor_id"] == "S:actor-ab" + assert entry["node_id"] == "node-1" + + +def test_serialize_stage_perf_omits_blank_identity() -> None: + [entry] = serialize_stage_perf( + [StagePerfStats(stage_name="QwenOmni_inference", process_time=1.0)] + ) # no identity resolved (CPU / non-Ray) + assert "physical_address" not in entry + assert "gpu_id" not in entry + assert "actor_id" not in entry + assert "node_id" not in entry + + +def test_fingerprint_distinguishes_actors_with_equal_timings() -> None: + """Two records byte-identical except for identity must NOT dedup to one.""" + summary = AudioPerformanceSummary(duration_key="duration") + a = _perf(addr="10.0.0.5:0", actor_id="S:actor-a") + b = _perf(addr="10.0.0.5:1", actor_id="S:actor-b") + summary.record_stage_perf([a, b]) + stage = summary.build_stage_summaries()["QwenOmni_inference"] + assert stage["invocation_count"] == 2.0 # both counted, not collapsed by dedup + + +def test_stage_summary_exposes_adapter_inference_call_count() -> None: + summary = AudioPerformanceSummary(duration_key="duration") + summary.record_stage_perf( + [ + StagePerfStats( + stage_name="QwenOmni_inference", + process_time=10.0, + num_items_processed=3, + custom_metrics={ + "audio_duration_s": 120.0, + "adapter_inference_calls": 7.0, + "adapter_inference_items": 9.0, + }, + ) + ] + ) + + stage = summary.build_stage_summaries()["QwenOmni_inference"] + assert stage["invocation_count"] == 1.0 + assert stage["adapter_inference_call_count"] == 7.0 + assert stage["adapter_inference_items"] == 9.0 + assert stage["avg_adapter_inference_batch_size"] == 9.0 / 7.0 + assert stage["avg_audio_s_per_adapter_inference_call"] == 120.0 / 7.0 + assert stage["adapter_inference_calls_per_stage_invocation"] == 7.0 + + +# ---------------------------------------------------------------------- +# Per-GPU scheduling breakdown + topology +# ---------------------------------------------------------------------- + + +def test_per_actor_carries_physical_address_and_topology() -> None: + """A tensor-parallel actor on 2 GPUs is one address but counts as 2 devices.""" + summary = AudioPerformanceSummary() + summary.record_stage_perf( + [ + StagePerfStats( + stage_name="QwenOmni_inference", + process_time=1.0, + actor_idle_time=0.1, + num_items_processed=32, + custom_metrics={"audio_duration_s": 100.0, "utterances_input": 32.0}, + actor_id="S:actor-a", + node_id="node-0", + gpu_id="node-0:0", + physical_address="10.244.181.136:0,1", + pod_ip="10.244.181.136", + hostname="worker-0", + gpu_indices=[0, 1], + ), + ] + ) + stage = summary.build_stage_summaries()["QwenOmni_inference"] + # Topology: 1 per-actor address, but 2 distinct physical devices. + assert stage["gpu_addresses"] == ["10.244.181.136:0,1"] + assert stage["gpu_count"] == 2.0 + + per_actor = stage["per_actor"]["S:actor-a"] # keyed by actor_id + assert per_actor["physical_address"] == "10.244.181.136:0,1" # canonical GPU id, as a field + assert per_actor["node_id"] == "node-0" + assert per_actor["pod_ip"] == "10.244.181.136" + assert per_actor["hostname"] == "worker-0" + assert per_actor["gpu_indices"] == [0, 1] + + +def test_per_actor_scheduling_breakdown_and_topology() -> None: + summary = AudioPerformanceSummary(duration_key="duration") + # actor-a on GPU 0 (two invocations); actor-b on GPU 1 (two invocations). + records = [ + _perf(addr="10.0.0.5:0", actor_id="S:actor-a", idle=0.10, items=32, audio_s=100.0), + _perf(addr="10.0.0.5:0", actor_id="S:actor-a", idle=0.30, items=32, audio_s=120.0), + _perf(addr="10.0.0.5:1", actor_id="S:actor-b", idle=0.05, items=16, audio_s=50.0), + _perf(addr="10.0.0.5:1", actor_id="S:actor-b", idle=0.20, items=16, audio_s=70.0), + ] + summary.record_stage_perf(records) + stage = summary.build_stage_summaries()["QwenOmni_inference"] + + # Topology + assert stage["gpu_addresses"] == ["10.0.0.5:0", "10.0.0.5:1"] + assert stage["gpu_count"] == 2.0 + assert stage["actor_count"] == 2.0 + + per_actor = stage["per_actor"] + assert set(per_actor) == {"S:actor-a", "S:actor-b"} + + a = per_actor["S:actor-a"] + assert a["physical_address"] == "10.0.0.5:0" + assert a["node_id"] == "10.0.0.5" + assert a["items_processed"] == 64.0 # 32 + 32 + assert a["audio_hours_in"] == (100.0 + 120.0) / 3600.0 + assert "batch_size_p50" in a + assert "queue_wait_s_p50" in a + assert "queue_wait_s_p95" in a + + b = per_actor["S:actor-b"] + assert b["physical_address"] == "10.0.0.5:1" + assert b["items_processed"] == 32.0 # 16 + 16 + assert b["audio_hours_in"] == (50.0 + 70.0) / 3600.0 + + +def test_per_actor_gpus_block_is_keyed_per_physical_device() -> None: + """A tensor-parallel actor reports ONE address but a nested per-device (``:``) GPU map.""" + summary = AudioPerformanceSummary() + # Actor on 2 GPUs; the sampler namespaces each device's util by UUID (``gpu_util_pct::``). + summary.record_stage_perf( + [ + StagePerfStats( + stage_name="QwenOmni_inference", + process_time=1.0, + num_items_processed=32, + custom_metrics={ + "audio_duration_s": 100.0, + "gpu_util_pct::aaa": 90.0, + "gpu_mem_used_pct::aaa": 70.0, + "gpu_util_pct::bbb": 40.0, + "gpu_mem_used_pct::bbb": 30.0, + }, + actor_id="QwenOmni_inference:actor-a", + node_id="node-0", + physical_address="10.0.0.5:0,1", + gpu_indices=[0, 1], + gpu_uuids=["GPU-aaa", "GPU-bbb"], + ), + ] + ) + stage = summary.build_stage_summaries()["QwenOmni_inference"] + gpus = stage["per_actor"]["QwenOmni_inference:actor-a"]["gpus"] + # One entry per physical device, keyed by : -- not averaged across the actor. + assert set(gpus) == {"10.0.0.5:0", "10.0.0.5:1"} + assert gpus["10.0.0.5:0"]["gpu_index"] == 0 + assert gpus["10.0.0.5:0"]["gpu_uuid"] == "GPU-aaa" + assert gpus["10.0.0.5:0"]["gpu_util_pct_p50"] == 90.0 + assert gpus["10.0.0.5:0"]["gpu_mem_used_pct_p50"] == 70.0 + assert gpus["10.0.0.5:1"]["gpu_index"] == 1 + assert gpus["10.0.0.5:1"]["gpu_util_pct_p50"] == 40.0 + # The per-GPU util is NOT summed into the stage's scalar custom-metric totals. + assert "gpu_util_pct::aaa" not in stage.get("custom_metrics_sum", {}) + + +def test_pipeline_throughput_rollup_unions_gpu_addresses() -> None: + summary = AudioPerformanceSummary(duration_key="duration") + summary.record_stage_perf( + [ + _perf(addr="10.0.0.5:0", actor_id="S:actor-a", audio_s=3600.0), + _perf(addr="10.0.0.6:0", actor_id="S:actor-b", audio_s=3600.0), + ] + ) + # total_audio_seconds is normally driven by record_task; set it directly here. + summary._total_audio_seconds = 7200.0 # 2 audio-hours + out = summary.build_summary(wall_time_s=3600.0) # 1 wall-hour + + pt = out["pipeline_throughput"] + assert pt["gpu_addresses"] == ["10.0.0.5:0", "10.0.0.6:0"] + assert pt["gpu_count"] == 2.0 + assert pt["audio_hours_per_wallclock_hour"] == 2.0 # 2 audio-h / 1 wall-h + + +def test_rows_in_prefers_reader_manifest_entries_over_discovery_input_task() -> None: + summary = AudioPerformanceSummary(duration_key="duration") + summary.record_stage_perf( + [ + StagePerfStats( + stage_name="nemo_tar_shard_discovery", + process_time=0.1, + custom_metrics={"input_tasks": 1.0, "shards_emitted": 8.0}, + ), + StagePerfStats( + stage_name="nemo_tar_shard_reader", + process_time=1.0, + custom_metrics={"manifest_entries": 123.0, "output_utterances": 100.0, "audio_duration_s": 3600.0}, + ), + ] + ) + + out = summary.build_summary(wall_time_s=10.0) + + assert out["rows_in"] == 123.0 + assert out["input_hours"] == 1.0 + + +# ---------------------------------------------------------------------- +# Graceful absence when identity is unresolved (CPU / non-Ray) +# ---------------------------------------------------------------------- + + +def test_no_identity_emits_no_per_actor_or_topology() -> None: + summary = AudioPerformanceSummary(duration_key="duration") + summary.record_stage_perf([_perf(), _perf(idle=0.5)]) # blank address/actor + stage = summary.build_stage_summaries()["QwenOmni_inference"] + assert "per_actor" not in stage + assert "gpu_addresses" not in stage + assert "gpu_count" not in stage + assert "actor_count" not in stage + + out = summary.build_summary(wall_time_s=10.0) + assert "gpu_addresses" not in out.get("pipeline_throughput", {}) + + +# ---------------------------------------------------------------------- +# Mixed pipeline: GPU stages and CPU stages coexist in one summary +# ---------------------------------------------------------------------- + + +def _cpu_perf(stage_name: str, actor_id: str) -> StagePerfStats: + """A CPU-stage record: actor + node resolved, but no GPU (no ``physical_address`` / ``gpu_id``).""" + return StagePerfStats( + stage_name=stage_name, + process_time=0.5, + num_items_processed=64, + custom_metrics={"writer_process_calls": 2.0}, + actor_id=actor_id, + node_id="node-0", + ) + + +def test_cpu_stage_gets_per_actor_but_no_gpu_fields() -> None: + """A CPU stage gets actor_count + per_actor, but no gpu_addresses / gpu_count / GPU fields.""" + summary = AudioPerformanceSummary(duration_key="duration") + summary.record_stage_perf( + [ + _cpu_perf("ShardedManifestWriter", "ShardedManifestWriter:actor-cpu01"), + ] + ) + stage = summary.build_stage_summaries()["ShardedManifestWriter"] + assert stage["actor_count"] == 1.0 + assert stage["total_items_processed"] == 64.0 + assert "gpu_addresses" not in stage + assert "gpu_count" not in stage + + per_actor = stage["per_actor"]["ShardedManifestWriter:actor-cpu01"] + assert per_actor["items_processed"] == 64.0 + assert per_actor["node_id"] == "node-0" + assert "physical_address" not in per_actor # CPU actor: no GPU + assert "gpu_indices" not in per_actor + + +def test_mixed_gpu_and_cpu_stages_in_one_pipeline() -> None: + """GPU and CPU stages coexist: both get per_actor, only the GPU stage gets topology; rollup unions GPU addresses.""" + summary = AudioPerformanceSummary(duration_key="duration") + summary.record_stage_perf( + [ + _perf(addr="10.0.0.5:0", actor_id="QwenOmni_inference:actor-a", items=32, audio_s=100.0), + _cpu_perf("ShardedManifestWriter", "ShardedManifestWriter:actor-cpu01"), + ] + ) + stages = summary.build_stage_summaries() + + gpu_stage = stages["QwenOmni_inference"] + assert gpu_stage["gpu_addresses"] == ["10.0.0.5:0"] + assert gpu_stage["actor_count"] == 1.0 + assert set(gpu_stage["per_actor"]) == {"QwenOmni_inference:actor-a"} + assert gpu_stage["per_actor"]["QwenOmni_inference:actor-a"]["physical_address"] == "10.0.0.5:0" + + cpu_stage = stages["ShardedManifestWriter"] + assert cpu_stage["actor_count"] == 1.0 + assert set(cpu_stage["per_actor"]) == {"ShardedManifestWriter:actor-cpu01"} # CPU gets per_actor too + assert "gpu_addresses" not in cpu_stage + + # Pipeline rollup unions ONLY the GPU addresses (the CPU stage contributes none). + summary._total_audio_seconds = 100.0 # exercise the throughput branch + pt = summary.build_summary(wall_time_s=50.0)["pipeline_throughput"] + assert pt["gpu_addresses"] == ["10.0.0.5:0"] + assert pt["gpu_count"] == 1.0 diff --git a/tests/stages/audio/test_common.py b/tests/stages/audio/test_common.py index 4572cda1bf..e0b3e21aec 100644 --- a/tests/stages/audio/test_common.py +++ b/tests/stages/audio/test_common.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ruff: noqa: ANN202 """Tests for common audio stages: GetAudioDurationStage, PreserveByValueStage, ManifestReaderStage, ManifestReader, and ManifestWriterStage.""" @@ -48,6 +49,12 @@ def _make_file_group_task(paths: list[str]) -> FileGroupTask: return FileGroupTask(dataset_name="test", data=paths) +def _audio_task_with_id(task_id: str, **kwargs) -> AudioTask: + task = AudioTask(**kwargs) + task.task_id = task_id + return task + + # --------------------------------------------------------------------------- # PreserveByValueStage # --------------------------------------------------------------------------- @@ -172,6 +179,29 @@ def test_reads_single_manifest(self, tmp_path: Path) -> None: assert all(isinstance(r, AudioTask) for r in result) assert result[0].data["audio_filepath"] == "a.wav" assert result[1].data["audio_filepath"] == "b.wav" + assert "_shard_total" not in result[0]._metadata + assert "_shard_key" not in result[0]._metadata + + def test_uses_storage_options_when_opening_manifest(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + manifest = tmp_path / "input.jsonl" + manifest.write_text(json.dumps({"audio_filepath": "a.wav", "segments": []})) + seen_kwargs: list[dict] = [] + + class _LocalFS: + def open(self, path: str, mode: str, encoding: str | None = None): + return open(path, mode, encoding=encoding) + + def fake_url_to_fs(path: str, **kwargs): + seen_kwargs.append(kwargs) + return _LocalFS(), path + + monkeypatch.setattr("nemo_curator.stages.audio.common.url_to_fs", fake_url_to_fs) + + stage = ManifestReaderStage(storage_options={"profile": "private"}) + result = stage.process(_make_file_group_task([str(manifest)])) + + assert len(result) == 1 + assert seen_kwargs == [{"profile": "private"}] def test_worker_defaults(self) -> None: stage = ManifestReaderStage() @@ -190,6 +220,8 @@ def test_reads_multiple_manifests(self, tmp_path: Path) -> None: assert len(result) == 2 paths = [r.data["audio_filepath"] for r in result] assert paths == ["a.wav", "b.wav"] + assert all("_shard_total" not in r._metadata for r in result) + assert all("_shard_key" not in r._metadata for r in result) def test_one_audio_entry_per_line(self, tmp_path: Path) -> None: entries = [{"audio_filepath": f"{i}.wav", "segments": []} for i in range(5)] @@ -259,6 +291,8 @@ def test_duplicate_manifests_for_repeat(self, tmp_path: Path) -> None: assert len(result) == 3 assert all(r.data["audio_filepath"] == "a.wav" for r in result) + assert all("_shard_total" not in r._metadata for r in result) + assert all("_shard_key" not in r._metadata for r in result) class TestManifestReaderDirectory: @@ -370,7 +404,7 @@ class TestManifestWriterStage: def test_writes_entry_to_jsonl(self, tmp_path: Path) -> None: out = tmp_path / "output.jsonl" - writer = ManifestWriterStage(output_path=str(out)) + writer = ManifestWriterStage(output_path=str(out), write_perf_stats=False) writer.setup_on_node() writer.setup() @@ -399,7 +433,7 @@ def test_returns_audio_task(self, tmp_path: Path) -> None: def test_propagates_metadata_and_stage_perf(self, tmp_path: Path) -> None: out = tmp_path / "output.jsonl" - writer = ManifestWriterStage(output_path=str(out)) + writer = ManifestWriterStage(output_path=str(out), write_perf_stats=False) writer.setup_on_node() writer.setup() @@ -430,6 +464,61 @@ def test_appends_across_multiple_process_calls(self, tmp_path: Path) -> None: assert len(lines) == 3 assert [json.loads(line)["entry"] for line in lines] == [1, 2, 3] + def test_process_batch_drops_waveform_and_array_like_keys(self, tmp_path: Path) -> None: + out = tmp_path / "output.jsonl" + writer = ManifestWriterStage( + output_path=str(out), + write_perf_stats=False, + drop_manifest_keys=("waveform",), + drop_array_like_values=True, + ) + writer.setup_on_node() + writer.setup() + + returned = writer.process_batch( + [ + _audio_task_with_id( + "t1", + data={ + "audio_filepath": "a.wav", + "duration": 1.0, + "waveform": torch.zeros(1, 16000), + "embedding": np.zeros(4, dtype=np.float32), + "text": "hello", + }, + ), + _audio_task_with_id("t2", data={"audio_filepath": "b.wav", "duration": 2.0}), + ] + ) + + rows = [json.loads(line) for line in out.read_text().splitlines()] + assert [row["audio_filepath"] for row in rows] == ["a.wav", "b.wav"] + assert "waveform" not in rows[0] + assert "embedding" not in rows[0] + assert rows[0]["text"] == "hello" + assert [task.task_id for task in returned] == ["", ""] + + def test_writes_perf_summary_during_process_batch(self, tmp_path: Path) -> None: + out = tmp_path / "output.jsonl" + writer = ManifestWriterStage(output_path=str(out), write_perf_stats=True) + writer.setup_on_node() + writer.setup() + + writer.process_batch( + [ + _audio_task_with_id("t1", data={"audio_filepath": "a.wav", "duration": 1.0}), + _audio_task_with_id("t2", data={"audio_filepath": "b.wav", "duration": 2.0}), + ] + ) + + summary = json.loads((tmp_path / "perf_summary.json").read_text(encoding="utf-8")) + assert summary["total_utterances"] == 2 + assert summary["total_audio_seconds"] == 3.0 + writer_summary = summary["stages"]["manifest_writer"] + assert writer_summary["total_items_processed"] == 2.0 + assert writer_summary["invocation_count"] == 1.0 + assert writer_summary["custom_metrics_sum"]["writer_items_processed"] == 2.0 + def test_setup_truncates_existing_file(self, tmp_path: Path) -> None: out = tmp_path / "output.jsonl" out.write_text('{"old": "data"}\n') diff --git a/tests/stages/audio/test_model_input_segmentation.py b/tests/stages/audio/test_model_input_segmentation.py new file mode 100644 index 0000000000..a0e6ccd91c --- /dev/null +++ b/tests/stages/audio/test_model_input_segmentation.py @@ -0,0 +1,99 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemo_curator.stages.audio.model_input_segmentation import ( + duration_to_num_samples, + plan_audio_segments, + resolve_max_model_input_duration, +) + + +def test_resolve_max_model_input_duration_rejects_non_positive_values() -> None: + with pytest.raises(ValueError, match="max_inference_duration_s must be > 0"): + resolve_max_model_input_duration(max_duration_s=0, owner="test") + + +def test_duration_to_num_samples_rejects_invalid_sample_rate() -> None: + with pytest.raises(ValueError, match="sample_rate must be > 0"): + duration_to_num_samples(10.0, 0) + + +def test_plan_audio_segments_rejects_invalid_sample_rate() -> None: + with pytest.raises(ValueError, match="sample_rate must be > 0"): + plan_audio_segments(num_samples=100, sample_rate=0, max_duration_s=10.0, owner="test") + + +def test_plan_audio_segments_keeps_zero_sample_inputs_representable() -> None: + segments = plan_audio_segments(num_samples=0, sample_rate=16000, max_duration_s=30.0, owner="test") + + assert len(segments) == 1 + assert segments[0].index == 0 + assert segments[0].count == 1 + assert segments[0].start_sample == 0 + assert segments[0].stop_sample == 0 + assert segments[0].duration_s == 0.0 + + +def test_plan_audio_segments_handles_non_divisible_final_segment() -> None: + segments = plan_audio_segments(num_samples=95, sample_rate=10, max_duration_s=3.0, owner="test") + + assert [(segment.index, segment.count, segment.start_sample, segment.stop_sample) for segment in segments] == [ + (0, 4, 0, 30), + (1, 4, 30, 60), + (2, 4, 60, 90), + (3, 4, 90, 95), + ] + assert [segment.duration_s for segment in segments] == [3.0, 3.0, 3.0, 0.5] + + +def test_plan_audio_segments_exact_boundary_has_no_empty_tail() -> None: + segments = plan_audio_segments(num_samples=60, sample_rate=10, max_duration_s=3.0, owner="test") + + assert [(segment.index, segment.count, segment.start_sample, segment.stop_sample) for segment in segments] == [ + (0, 2, 0, 30), + (1, 2, 30, 60), + ] + assert [segment.duration_s for segment in segments] == [3.0, 3.0] + + +def test_plan_audio_segments_qwen_2400s_boundary_at_16khz() -> None: + sample_rate = 16000 + max_duration_s = 2400.0 + boundary_samples = int(sample_rate * max_duration_s) + + exact = plan_audio_segments( + num_samples=boundary_samples, + sample_rate=sample_rate, + max_duration_s=max_duration_s, + owner="ASRStage", + ) + just_over = plan_audio_segments( + num_samples=boundary_samples + 1, + sample_rate=sample_rate, + max_duration_s=max_duration_s, + owner="ASRStage", + ) + + assert [(segment.index, segment.count, segment.start_sample, segment.stop_sample) for segment in exact] == [ + (0, 1, 0, boundary_samples), + ] + assert exact[0].duration_s == max_duration_s + assert [(segment.index, segment.count, segment.start_sample, segment.stop_sample) for segment in just_over] == [ + (0, 2, 0, boundary_samples), + (1, 2, boundary_samples, boundary_samples + 1), + ] + assert just_over[0].duration_s == max_duration_s + assert just_over[1].duration_s == 1.0 / sample_rate diff --git a/tests/stages/test_payload_lifecycle.py b/tests/stages/test_payload_lifecycle.py new file mode 100644 index 0000000000..e71ad63966 --- /dev/null +++ b/tests/stages/test_payload_lifecycle.py @@ -0,0 +1,644 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import time +from collections.abc import Callable +from types import SimpleNamespace + +import pytest +import torch + +from nemo_curator.stages import payload_lifecycle as lifecycle +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.payload_lifecycle import ( + AudioPayloadMaterializeStage, + PayloadRef, + PayloadReleaseStage, + _PayloadAdmissionState, + _PayloadStoreState, + task_payload_refs, +) +from nemo_curator.tasks import AudioTask + + +class _RemoteMethod: + def __init__(self, fn: Callable[..., object]) -> None: + self._fn = fn + + def remote(self, *args: object, **kwargs: object) -> object: + return self._fn(*args, **kwargs) + + +class _FakeReader: + def __init__(self) -> None: + self.calls = 0 + + def process(self, task: AudioTask) -> AudioTask: + self.calls += 1 + samples = int(float(task.data["duration"]) * 16_000) + task.data["waveform"] = torch.zeros(1, samples, dtype=torch.float32) + task.data["sample_rate"] = 16_000 + task.data["num_samples"] = samples + return task + + def setup(self, *_args: object, **_kwargs: object) -> None: + return None + + def setup_on_node(self, *_args: object, **_kwargs: object) -> None: + return None + + def teardown(self) -> None: + return None + + +class _FakeSkipReader(_FakeReader): + def process(self, task: AudioTask) -> AudioTask: + self.calls += 1 + task.data["waveform"] = torch.empty(1, 0, dtype=torch.float32) + task.data["sample_rate"] = 16_000 + task.data["num_samples"] = 0 + task.data["duration"] = 0.0 + task.data["_skip_me"] = "audio_read_error" + task.data["audio_read_error"] = "RuntimeError: decode lost" + return task + + +class _RequiresTextStage(ProcessingStage[AudioTask, AudioTask]): + name = "RequiresTextStage" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], ["text"] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def process(self, task: AudioTask) -> AudioTask: + return task + + +@pytest.fixture(autouse=True) +def _fake_ray_get(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(lifecycle, "_ray_get", lambda obj: obj) + + +class _FakeAdmission: + def __init__(self, budget: int) -> None: + self.budget = budget + self.used = 0 + self.acquire_calls: list[tuple[str, int]] = [] + self.acquire_ttls: list[float] = [] + self.resize_calls: list[tuple[str, int]] = [] + self.resize_ttls: list[float] = [] + self.heartbeat_ttls: list[float] = [] + self.release_calls: list[tuple[str, int | None]] = [] + self.register_node = _RemoteMethod(lambda *_args: None) + self.try_acquire = _RemoteMethod(self._try_acquire) + self.resize = _RemoteMethod(self._resize) + self.release = _RemoteMethod(self._release) + self.heartbeat = _RemoteMethod(self._heartbeat) + self.snapshot = _RemoteMethod(self._snapshot) + + def _try_acquire(self, _node_id: str, owner_id: str, amount: int, _ttl: float) -> bool: + self.acquire_calls.append((owner_id, amount)) + self.acquire_ttls.append(_ttl) + if amount > self.budget or self.used + amount > self.budget: + return False + self.used += amount + return True + + def _resize(self, _node_id: str, owner_id: str, amount: int, _ttl: float) -> bool: + self.resize_calls.append((owner_id, amount)) + self.resize_ttls.append(_ttl) + if amount > self.budget: + return False + self.used = amount + return True + + def _release(self, _node_id: str, owner_id: str, amount: int | None = None) -> None: + self.release_calls.append((owner_id, amount)) + if amount is None: + self.used = 0 + else: + self.used = max(0, self.used - amount) + + def _heartbeat(self, _node_id: str, _owner_id: str, ttl: float) -> bool: + self.heartbeat_ttls.append(ttl) + return True + + def _snapshot(self) -> dict[str, int]: + return {"cluster_used": self.used, "cluster_budget": self.budget} + + +class _FakeStore: + def __init__(self) -> None: + self.payloads: dict[str, torch.Tensor] = {} + self.put_ttls: list[float] = [] + self.put = _RemoteMethod(self._put) + self.get = _RemoteMethod(self._get) + self.pin = _RemoteMethod(self._pin) + self.release = _RemoteMethod(self._release) + + def _put(self, payload_id: str, waveform: torch.Tensor, _amount: int, _ttl: float) -> None: + self.put_ttls.append(_ttl) + self.payloads[payload_id] = waveform + + def _get(self, payload_id: str, _ttl: float) -> torch.Tensor: + return self.payloads[payload_id] + + def _pin(self, payload_id: str, _ttl: float) -> bool: + return payload_id in self.payloads + + def _release(self, payload_id: str) -> int: + waveform = self.payloads.pop(payload_id, None) + if waveform is None: + return 0 + return int(waveform.element_size() * waveform.nelement()) + + +def _stage_with_fakes( + *, budget: int = 1_000_000 +) -> tuple[AudioPayloadMaterializeStage, _FakeReader, _FakeAdmission, _FakeStore]: + stage = AudioPayloadMaterializeStage( + max_node_payload_bytes=budget, + admission_poll_interval_s=0.001, + ) + reader = _FakeReader() + admission = _FakeAdmission(budget) + store = _FakeStore() + stage._reader = reader + stage._node_id = "node-1" + stage._node_budget_bytes = budget + stage._store_actor_name = "payload-store-node-1" + stage._admission = admission + stage._store = store + return stage, reader, admission, store + + +def test_audio_payload_materialize_constructs_reader_with_configured_keys() -> None: + stage = AudioPayloadMaterializeStage( + waveform_key="custom_waveform", + sample_rate_key="custom_sample_rate", + num_samples_key="custom_num_samples", + ) + stage._node_id = "node-1" + stage._admission = object() + stage._store = object() + + stage._ensure_ready() + + assert stage._reader.waveform_key == "custom_waveform" + assert stage._reader.sample_rate_key == "custom_sample_rate" + assert stage._reader.num_samples_key == "custom_num_samples" + + +def test_audio_payload_materialize_requires_duration_before_decode() -> None: + stage, reader, admission, _store = _stage_with_fakes() + + with pytest.raises(ValueError, match="requires 'duration'"): + stage.process(AudioTask(data={"audio_filepath": "s3://bucket/audio.wav"})) + + assert reader.calls == 0 + assert admission.acquire_calls == [] + + +def test_audio_payload_materialize_process_batch_validates_required_inputs() -> None: + stage, reader, admission, _store = _stage_with_fakes() + + with pytest.raises(ValueError, match="failed validation"): + stage.process_batch([AudioTask(data={"duration": 0.5})]) + + assert reader.calls == 0 + assert admission.acquire_calls == [] + + +def test_audio_payload_materialize_stores_waveform_by_ref_and_removes_payload() -> None: + stage, reader, admission, store = _stage_with_fakes() + task = AudioTask(data={"audio_filepath": "s3://bucket/audio.wav", "duration": 0.5}) + + output = stage.process(task) + + assert reader.calls == 1 + assert "waveform" not in output.data + payload_ref = output.data["waveform_ref"] + assert isinstance(payload_ref, PayloadRef) + assert payload_ref.amount_bytes == 32_000 + assert payload_ref.payload_id in store.payloads + assert admission.acquire_calls == [(payload_ref.payload_id, 32_000)] + assert admission.acquire_ttls == [stage.lease_ttl_s] + assert admission.resize_calls == [] + assert admission.heartbeat_ttls == [stage.materialized_lease_ttl_s] + assert store.put_ttls == [stage.materialized_lease_ttl_s] + assert output.data["_curator_payload_estimated_bytes"] == 32_000 + assert output.data["_curator_payload_bytes"] == 32_000 + assert "_curator_payload_producer_node_id" not in output.data + + +def test_audio_payload_materialize_passes_reader_skip_without_payload_ref() -> None: + stage, _reader, admission, store = _stage_with_fakes() + stage.skip_on_read_error = True + stage._reader = _FakeSkipReader() + task = AudioTask(data={"audio_filepath": "/local/audio.wav", "duration": 0.5}) + + output = stage.process(task) + + assert output.data["_skip_me"] == "audio_read_error" + assert output.data["audio_read_error"] == "RuntimeError: decode lost" + assert output.data["num_samples"] == 0 + assert "waveform" not in output.data + assert "waveform_ref" not in output.data + assert store.payloads == {} + assert admission.acquire_calls + assert admission.release_calls + + +def test_audio_payload_ref_carries_ray_namespace() -> None: + stage, _reader, _admission, _store = _stage_with_fakes() + stage._actor_namespace = "payload-ns" + task = AudioTask(data={"audio_filepath": "s3://bucket/audio.wav", "duration": 0.5}) + + payload_ref = stage.process(task).data["waveform_ref"] + + assert payload_ref.actor_namespace == "payload-ns" + + +def test_payload_actor_creation_is_detached_and_namespaced(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, object] = {} + + class _FakeRemoteActor: + def options(self, **options: object) -> "_FakeRemoteActor": + captured["options"] = options + return self + + def remote(self, **kwargs: object) -> str: + captured["kwargs"] = kwargs + return "actor-handle" + + fake_ray = SimpleNamespace( + get_actor=lambda _name, **_kwargs: (_ for _ in ()).throw(ValueError("missing actor")), + remote=lambda _actor_cls: _FakeRemoteActor(), + ) + monkeypatch.setitem(sys.modules, "ray", fake_ray) + + actor = lifecycle._get_named_actor_or_create(object, "payload-admission", namespace="payload-ns", value=123) + + assert actor == "actor-handle" + assert captured["options"] == { + "name": "payload-admission", + "get_if_exists": True, + "lifetime": "detached", + "namespace": "payload-ns", + } + assert captured["kwargs"] == {"value": 123} + + +def test_audio_payload_materialize_cleanup_kills_run_scoped_actors(monkeypatch: pytest.MonkeyPatch) -> None: + killed: list[tuple[str, str | None]] = [] + monkeypatch.setattr(lifecycle, "_active_ray_node_ids", lambda: ["node-a", "node/b"]) + monkeypatch.setattr(lifecycle, "_current_ray_namespace", lambda: "payload-ns") + monkeypatch.setattr( + lifecycle, "_kill_named_actor", lambda name, namespace=None: killed.append((name, namespace)) or True + ) + + stage = AudioPayloadMaterializeStage( + admission_actor_name="admission", + store_actor_prefix="store", + run_id="run/id", + ) + + stage.cleanup_run_resources() + + assert killed == [ + ("admission_run_id", "payload-ns"), + ("store_run_id_node-a", "payload-ns"), + ("store_run_id_node_b", "payload-ns"), + ] + + +def test_audio_payload_materialize_releases_tokens_when_actual_size_exceeds_budget() -> None: + stage, _reader, admission, store = _stage_with_fakes(budget=16_000) + task = AudioTask(data={"audio_filepath": "s3://bucket/audio.wav", "duration": 0.5}) + task.data["duration"] = 0.5 + + # The initial estimate is 32k, but the fake reader returns exactly 32k, so + # force the estimate small enough that resize must reject the actual payload. + stage.sample_width_bytes = 1 + + with pytest.raises(RuntimeError, match="Insufficient payload memory budget"): + stage.process(task) + + assert store.payloads == {} + assert admission.release_calls + + +def test_payload_release_stage_drops_ref_and_payload_metadata(monkeypatch: pytest.MonkeyPatch) -> None: + released: list[str] = [] + ref = PayloadRef( + payload_id="payload-1", + owner_node_id="node-1", + store_actor_name="store", + admission_actor_name="admission", + amount_bytes=123, + sample_rate=16_000, + num_samples=42, + ) + monkeypatch.setattr(lifecycle, "release_payload_ref", lambda payload_ref: released.append(payload_ref.payload_id)) + task = AudioTask( + data={ + "waveform_ref": ref, + "waveform": torch.zeros(1, 42), + "_curator_payload_estimated_bytes": 123, + "_curator_payload_bytes": 123, + "_curator_payload_producer_node_id": "node-a", + } + ) + + output = PayloadReleaseStage().process(task) + + assert released == ["payload-1"] + assert "waveform_ref" not in output.data + assert "waveform" not in output.data + assert not any(key.startswith("_curator_payload_") for key in output.data) + + +def test_payload_release_stage_preserves_data_attr_access_for_downstream_validation( + monkeypatch: pytest.MonkeyPatch, +) -> None: + released: list[str] = [] + ref = PayloadRef( + payload_id="payload-1", + owner_node_id="node-1", + store_actor_name="store", + admission_actor_name="admission", + amount_bytes=123, + sample_rate=16_000, + num_samples=42, + ) + monkeypatch.setattr(lifecycle, "release_payload_ref", lambda payload_ref: released.append(payload_ref.payload_id)) + task = AudioTask(data={"text": "keep me", "waveform_ref": ref, "nested": {"refs": [ref]}}) + data_before_release = task.data + + output = PayloadReleaseStage().process(task) + + assert released == ["payload-1"] + assert output.data is data_before_release + assert output.data.text == "keep me" + assert "waveform_ref" not in output.data + assert output.data["nested"] == {"refs": []} + assert _RequiresTextStage().validate_input(output) + + +def test_payload_release_stage_noops_for_rows_without_payload_refs(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(lifecycle, "release_payload_ref", lambda _payload_ref: pytest.fail("unexpected release")) + task = AudioTask( + data={ + "audio_filepath": "/local/audio.wav", + "_skip_me": "read_error", + "_curator_payload_estimated_bytes": 123, + "_curator_payload_bytes": 123, + } + ) + + output = PayloadReleaseStage().process(task) + + assert output.data == {"audio_filepath": "/local/audio.wav", "_skip_me": "read_error"} + + +def test_payload_release_stage_releases_all_nested_payload_refs(monkeypatch: pytest.MonkeyPatch) -> None: + released: list[str] = [] + ref_a = PayloadRef( + payload_id="payload-a", + owner_node_id="node-1", + store_actor_name="store", + admission_actor_name="admission", + amount_bytes=123, + sample_rate=16_000, + num_samples=42, + ) + ref_b = PayloadRef( + payload_id="payload-b", + owner_node_id="node-1", + store_actor_name="store", + admission_actor_name="admission", + amount_bytes=456, + sample_rate=16_000, + num_samples=84, + ) + monkeypatch.setattr(lifecycle, "release_payload_ref", lambda payload_ref: released.append(payload_ref.payload_id)) + task = AudioTask(data={"waveform_ref": ref_a, "extra_refs": [ref_b, ref_a]}) + + output = PayloadReleaseStage().process(task) + + assert sorted(released) == ["payload-a", "payload-b"] + assert "waveform_ref" not in output.data + assert output.data["extra_refs"] == [] + assert task_payload_refs(output) == [] + + +def test_payload_lease_keeper_heartbeats_until_stopped(monkeypatch: pytest.MonkeyPatch) -> None: + heartbeats: list[str] = [] + ref = PayloadRef( + payload_id="payload-a", + owner_node_id="node-1", + store_actor_name="store", + admission_actor_name="admission", + amount_bytes=123, + sample_rate=16_000, + num_samples=42, + lease_ttl_s=0.2, + ) + monkeypatch.setattr( + lifecycle, + "heartbeat_payload_refs_batched", + lambda refs: heartbeats.extend(ref.payload_id for ref in refs), + ) + + keeper = lifecycle._PayloadLeaseKeeper([ref], interval_s=0.01) + keeper.start() + time.sleep(0.035) + keeper.stop() + count_after_stop = len(heartbeats) + time.sleep(0.03) + + assert count_after_stop >= 2 + assert len(heartbeats) == count_after_stop + + +def test_payload_admission_resize_and_release() -> None: + admission = _PayloadAdmissionState(default_node_budget_bytes=100) + admission.register_node("node-a", 100) + + assert admission.try_acquire("node-a", "payload-1", 40) + assert not admission.resize("node-a", "payload-1", 120) + assert admission.resize("node-a", "payload-1", 80) + snapshot = admission.snapshot() + assert snapshot["node_used"]["node-a"] == 80 + + admission.release("node-a", "payload-1") + assert admission.snapshot()["node_used"]["node-a"] == 0 + + +def test_payload_admission_heartbeat_many_preserves_request_order() -> None: + admission = _PayloadAdmissionState(default_node_budget_bytes=100) + admission.register_node("node-a", 100) + assert admission.try_acquire("node-a", "payload-1", 40) + reap_calls = 0 + original_reap = admission._reap_expired + + def count_reap() -> None: + nonlocal reap_calls + reap_calls += 1 + original_reap() + + admission._reap_expired = count_reap # type: ignore[method-assign] + + assert admission.heartbeat_many( + [ + ("node-a", "payload-1", 5.0), + ("node-a", "missing", 5.0), + ] + ) == [True, False] + assert reap_calls == 1 + + +def test_payload_store_pin_and_get_many_preserve_request_order() -> None: + store = _PayloadStoreState() + payload_a = torch.tensor([1.0]) + payload_b = torch.tensor([2.0]) + store.put("payload-a", payload_a, 4) + store.put("payload-b", payload_b, 4) + reap_calls = 0 + original_reap = store._reap_expired + + def count_reap() -> None: + nonlocal reap_calls + reap_calls += 1 + original_reap() + + store._reap_expired = count_reap # type: ignore[method-assign] + + assert store.pin_many([("payload-b", 5.0), ("missing", 5.0)]) == [True, False] + assert reap_calls == 1 + assert store.get_many([("payload-b", 5.0), ("payload-a", 5.0)]) == [payload_b, payload_a] + assert reap_calls == 2 + + +def test_payload_admission_enforces_cluster_budget() -> None: + admission = _PayloadAdmissionState(default_node_budget_bytes=100, default_cluster_budget_bytes=150) + admission.register_node("node-a", 100) + admission.register_node("node-b", 100) + + assert admission.try_acquire("node-a", "payload-1", 100) + assert not admission.try_acquire("node-b", "payload-2", 100) + assert admission.try_acquire("node-b", "payload-3", 50) + + snapshot = admission.snapshot() + assert snapshot["cluster_budget"] == 150 + assert snapshot["cluster_used"] == 150 + + +def test_payload_materialize_rejects_invalid_byte_limit_string() -> None: + with pytest.raises(ValueError, match="max_node_payload_bytes"): + AudioPayloadMaterializeStage(max_node_payload_bytes="definitely-not-bytes")._ensure_ready() + + +def test_payload_materialize_fails_fast_when_single_row_exceeds_cluster_budget() -> None: + stage = AudioPayloadMaterializeStage(max_node_payload_bytes=1_000, max_cluster_payload_bytes=10) + stage._node_budget_bytes = 1_000 + stage._cluster_budget_bytes = 10 + + with pytest.raises(RuntimeError, match="cluster payload budget"): + stage._acquire("payload-1", 20) + + +def test_payload_materialize_times_out_when_admission_budget_never_frees() -> None: + stage, _reader, admission, _store = _stage_with_fakes(budget=1_000) + admission.budget = 0 + stage.admission_poll_interval_s = 0.0001 + stage.admission_wait_timeout_s = 0.001 + + with pytest.raises(RuntimeError, match="Timed out waiting for payload admission") as exc_info: + stage._acquire("payload-1", 100) + + assert "cluster_used" in str(exc_info.value) + assert admission.acquire_ttls + assert set(admission.acquire_ttls) == {stage.lease_ttl_s} + + +def test_task_payload_refs_finds_nested_refs() -> None: + ref = PayloadRef( + payload_id="payload-1", + owner_node_id="node-1", + store_actor_name="store", + admission_actor_name="admission", + amount_bytes=123, + sample_rate=16_000, + num_samples=42, + ) + task = AudioTask(data={"nested": {"payloads": [ref]}}) + + assert task_payload_refs(task) == [ref] + + +def test_payload_admission_reaps_expired_leases(monkeypatch: pytest.MonkeyPatch) -> None: + now = [0.0] + monkeypatch.setattr(lifecycle.time, "monotonic", lambda: now[0]) + admission = _PayloadAdmissionState(default_node_budget_bytes=100) + admission.register_node("node-a", 100) + + assert admission.try_acquire("node-a", "payload-1", 100, lease_ttl_s=1.0) + now[0] = 1.1 + assert admission.try_acquire("node-a", "payload-2", 100, lease_ttl_s=1.0) + + snapshot = admission.snapshot() + assert snapshot["node_used"]["node-a"] == 100 + assert snapshot["lease_count"] == 1 + + +def test_audio_payload_materialize_rejects_non_positive_materialized_lease() -> None: + with pytest.raises(ValueError, match="materialized_lease_ttl_s must be positive"): + AudioPayloadMaterializeStage(materialized_lease_ttl_s=0.0) + + +def test_payload_admission_explicit_release_lease_survives_without_heartbeat(monkeypatch: pytest.MonkeyPatch) -> None: + now = [0.0] + monkeypatch.setattr(lifecycle.time, "monotonic", lambda: now[0]) + admission = _PayloadAdmissionState(default_node_budget_bytes=100) + admission.register_node("node-a", 100) + + assert admission.try_acquire("node-a", "payload-1", 100, lease_ttl_s=0.0) + now[0] = 10_000.0 + + snapshot = admission.snapshot() + assert snapshot["node_used"]["node-a"] == 100 + assert snapshot["lease_count"] == 1 + assert admission.heartbeat("node-a", "payload-1", lease_ttl_s=1.0) + now[0] = 20_000.0 + assert admission.snapshot()["lease_count"] == 1 + + +def test_payload_store_explicit_release_payload_survives_without_heartbeat(monkeypatch: pytest.MonkeyPatch) -> None: + now = [0.0] + monkeypatch.setattr(lifecycle.time, "monotonic", lambda: now[0]) + store = _PayloadStoreState(default_lease_ttl_s=1.0) + waveform = torch.zeros(1, 10) + + store.put("payload-1", waveform, 40, lease_ttl_s=0.0) + now[0] = 10_000.0 + + assert store.snapshot()["payload_count"] == 1 + assert store.pin("payload-1", lease_ttl_s=1.0) + now[0] = 20_000.0 + assert store.get("payload-1", lease_ttl_s=1.0) is waveform + assert store.release("payload-1") == 40 diff --git a/tests/stages/text/io/reader/test_jsonl.py b/tests/stages/text/io/reader/test_jsonl.py index e783e48aa3..023b9a207c 100644 --- a/tests/stages/text/io/reader/test_jsonl.py +++ b/tests/stages/text/io/reader/test_jsonl.py @@ -191,5 +191,5 @@ def test_jsonl_reader_with_blocksize_limit(tmp_path: Path, caplog: pytest.LogCap # Since the storage size is larger than 10 million bytes, the FilePartitioningStage should warn file_partitioning_stage = stage.decompose()[0] with caplog.at_level("WARNING"): - file_partitioning_stage.process(EmptyTask) + file_partitioning_stage.process(EmptyTask()) assert "File group task has exceeded the storage limit per partition" in caplog.text diff --git a/tests/stages/text/io/reader/test_parquet.py b/tests/stages/text/io/reader/test_parquet.py index 4bb93e72bf..1b63ccb6c5 100644 --- a/tests/stages/text/io/reader/test_parquet.py +++ b/tests/stages/text/io/reader/test_parquet.py @@ -290,5 +290,5 @@ def test_parquet_reader_with_blocksize_limit(tmp_path: Path, caplog: pytest.LogC # Since the storage size is larger than 10_000 bytes, the FilePartitioningStage should warn file_partitioning_stage = stage.decompose()[0] with caplog.at_level("WARNING"): - file_partitioning_stage.process(EmptyTask) + file_partitioning_stage.process(EmptyTask()) assert "File group task has exceeded the storage limit per partition" in caplog.text diff --git a/tests/tasks/test_utils.py b/tests/tasks/test_utils.py index f2ede9f27b..bfd39f4d75 100644 --- a/tests/tasks/test_utils.py +++ b/tests/tasks/test_utils.py @@ -28,6 +28,27 @@ def make_dummy_task(stage_name: str, process_time: float, custom: float = 0.0) - class TestTaskPerfUtils: """Test cases for TaskPerfUtils class.""" + def test_collect_stage_metrics_ignores_identity_strings(self) -> None: + """Identity labels (actor_id/node_id/gpu_id) are excluded from items() so float() coercion won't raise.""" + perf = StagePerfStats( + stage_name="StageGpu", + process_time=1.5, + num_items_processed=4, + custom_metrics={"io": 2.0}, + actor_id="StageGpu:actor-deadbeef", + node_id="node-2", + gpu_id="node-2:1", + ) + task = EmptyTask(dataset_name="test", data=None, _stage_perf=[perf]) + + metrics = TaskPerfUtils.collect_stage_metrics([task]) # must not raise on float("node-2:1") + + assert "actor_id" not in metrics["StageGpu"] + assert "node_id" not in metrics["StageGpu"] + assert "gpu_id" not in metrics["StageGpu"] + assert np.allclose(metrics["StageGpu"]["process_time"], np.array([1.5])) + assert np.allclose(metrics["StageGpu"]["custom.io"], np.array([2.0])) + def test_collect_stage_metrics_from_workflow_result(self) -> None: """Test collecting stage metrics from WorkflowRunResult.""" workflow_result = WorkflowRunResult(workflow_name="unit") diff --git a/tests/utils/test_gpu_sampler.py b/tests/utils/test_gpu_sampler.py new file mode 100644 index 0000000000..e73b4e2e2a --- /dev/null +++ b/tests/utils/test_gpu_sampler.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.utils.gpu_sampler import GpuUtilSampler, norm_uuid + + +def test_norm_uuid_is_public_normalizer() -> None: + assert norm_uuid("GPU-ABCDEF") == "abcdef" + assert norm_uuid(b"GPU-1234") == "1234" + + +def test_gpu_sampler_reports_inactive_diagnostics_without_nvml() -> None: + sampler = GpuUtilSampler(gpu_uuids=("GPU-abc",)) + + sampler.start() + diagnostics = sampler.diagnostics() + + assert diagnostics["gpu_sampler_active"] == 0.0 + assert diagnostics["gpu_sampler_target_uuid_count"] == 1.0 + assert diagnostics["gpu_sampler_handle_count"] == 0.0 + assert diagnostics["gpu_sampler_sample_all_visible"] == 1.0 diff --git a/tutorials/audio/README.md b/tutorials/audio/README.md index 97487338c7..1f1c217a77 100644 --- a/tutorials/audio/README.md +++ b/tutorials/audio/README.md @@ -115,6 +115,8 @@ Audio pipelines can appear stuck for legitimate reasons. Before killing a run: | **Setup** | [Installation](https://docs.nvidia.com/nemo/curator/latest/get-started/installation.html) · [Configuration](https://docs.nvidia.com/nemo/curator/latest/get-started/configuration.html) | | **Concepts** | [Architecture](https://docs.nvidia.com/nemo/curator/latest/about/concepts/index.html) · [Data Loading](https://docs.nvidia.com/nemo/curator/latest/about/concepts/text/data-loading-concepts.html) | | **Advanced** | [Custom Pipelines](https://docs.nvidia.com/nemo/curator/latest/reference/index.html) · [Execution Backends](https://docs.nvidia.com/nemo/curator/latest/reference/infrastructure/execution-backends.html) · [NeMo ASR Integration](https://docs.nvidia.com/nemo/curator/latest/about/key-features.html) | +| **Developer guide** | [Audio stage internals](../../nemo_curator/stages/audio/README.md) · raw Qwen payload lifecycle, local/windowed ASR adapter batching, perf-summary fields | +| **Qwen assets** | [Qwen-Omni prompt templates](../../examples/audio/qwen_omni_inprocess/prompts/) | ## Known Issues diff --git a/uv.lock b/uv.lock index 4c0bfd823f..8ec583f04e 100644 --- a/uv.lock +++ b/uv.lock @@ -5200,6 +5200,32 @@ audio-cuda12 = [ { name = "transformers" }, { name = "whisperx", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, ] +audio-qwen = [ + { name = "accelerate" }, + { name = "cuda-python", version = "12.9.4", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "cuda-python", version = "12.9.5", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64') or sys_platform != 'linux'" }, + { name = "gpustat" }, + { name = "librosa" }, + { name = "nemo-text-processing", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, + { name = "nemo-toolkit", extra = ["asr"], marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, + { name = "nvidia-cudnn-cu12" }, + { name = "nvidia-ml-py" }, + { name = "onnx" }, + { name = "onnxruntime-gpu", marker = "platform_machine == 'x86_64'" }, + { name = "opencc-python-reimplemented" }, + { name = "pyannote-audio", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, + { name = "pydub" }, + { name = "qwen-omni-utils" }, + { name = "scipy" }, + { name = "silero-vad" }, + { name = "soundfile" }, + { name = "torchaudio", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64') or sys_platform != 'linux'" }, + { name = "torchaudio", version = "2.10.0+cu129", source = { registry = "https://download.pytorch.org/whl/cu129" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, + { name = "transformers" }, + { name = "vllm", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, + { name = "whisperx", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, +] cuda12 = [ { name = "gpustat" }, { name = "nvidia-ml-py" }, @@ -5499,6 +5525,7 @@ requires-dist = [ { name = "nemo-curator", extras = ["audio-common"], marker = "extra == 'audio-cpu'" }, { name = "nemo-curator", extras = ["audio-common"], marker = "extra == 'audio-cuda12'" }, { name = "nemo-curator", extras = ["audio-cuda12"], marker = "extra == 'all'" }, + { name = "nemo-curator", extras = ["audio-cuda12"], marker = "extra == 'audio-qwen'" }, { name = "nemo-curator", extras = ["cuda12"], marker = "extra == 'audio-cuda12'" }, { name = "nemo-curator", extras = ["cuda12"], marker = "extra == 'image-cuda12'" }, { name = "nemo-curator", extras = ["cuda12"], marker = "extra == 'inference-server'" }, @@ -5536,6 +5563,7 @@ requires-dist = [ { name = "nemo-curator", extras = ["translation-segmentation"], marker = "extra == 'translation-all'" }, { name = "nemo-curator", extras = ["video-cpu"], marker = "extra == 'video-cuda12'" }, { name = "nemo-curator", extras = ["video-cuda12"], marker = "extra == 'all'" }, + { name = "nemo-curator", extras = ["vllm"], marker = "extra == 'audio-qwen'" }, { name = "nemo-curator", extras = ["vllm"], marker = "extra == 'inference-server'" }, { name = "nemo-curator", extras = ["vllm"], marker = "extra == 'interleaved-cuda12'" }, { name = "nemo-curator", extras = ["vllm"], marker = "extra == 'text-cuda12'" }, @@ -5568,6 +5596,7 @@ requires-dist = [ { name = "pylibraft-cu12", marker = "extra == 'deduplication-cuda12'", specifier = "==25.10.*" }, { name = "pynvvideocodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'video-cuda12'", specifier = "==2.0.2" }, { name = "pypdfium2", marker = "extra == 'interleaved-cpu'" }, + { name = "qwen-omni-utils", marker = "extra == 'audio-qwen'" }, { name = "raft-dask-cu12", marker = "extra == 'deduplication-cuda12'", specifier = "==25.10.*" }, { name = "rapidsmpf-cu12", marker = "extra == 'deduplication-cuda12'", specifier = "==25.10.*" }, { name = "ray", extras = ["data", "default"], specifier = ">=2.55.1" }, @@ -5613,7 +5642,7 @@ requires-dist = [ { name = "warcio", marker = "extra == 'text-cpu'" }, { name = "whisperx", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'audio-common'", specifier = ">=3.8.4" }, ] -provides-extras = ["cuda12", "vllm", "inference-server", "deduplication-cuda12", "audio-common", "audio-cpu", "audio-cuda12", "image-cpu", "image-cuda12", "translation-common", "translation-metrics", "translation-segmentation", "translation-aws", "translation-google", "translation-nmt", "translation-all", "text-cpu", "text-cuda12", "video-cpu", "video-cuda12", "math-cpu", "math-cuda12", "interleaved-cpu", "interleaved-cuda12", "sdg-cpu", "sdg-cuda12", "all"] +provides-extras = ["cuda12", "vllm", "inference-server", "deduplication-cuda12", "audio-common", "audio-cpu", "audio-cuda12", "audio-qwen", "image-cpu", "image-cuda12", "translation-common", "translation-metrics", "translation-segmentation", "translation-aws", "translation-google", "translation-nmt", "translation-all", "text-cpu", "text-cuda12", "video-cpu", "video-cuda12", "math-cpu", "math-cuda12", "interleaved-cpu", "interleaved-cuda12", "sdg-cpu", "sdg-cuda12", "all"] [package.metadata.requires-dev] build = [ @@ -8691,6 +8720,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/5d/d963412914a2f778e4594c5164dfe69bc53435877bfcd1a0db25e67cf320/quack_kernels-0.4.0-py3-none-any.whl", hash = "sha256:c7ef1d3ee317adbc363b02e69a0a26110a8fcf5e07d8ada2cf7a1b4828b5539f", size = 250771, upload-time = "2026-04-27T15:29:07.227Z" }, ] +[[package]] +name = "qwen-omni-utils" +version = "0.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "av" }, + { name = "librosa" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a2/27/81a51b10ef6d1a4f158498089530eb4307689d75adf6bd34f6632eda3099/qwen_omni_utils-0.0.9.tar.gz", hash = "sha256:c598269c7069afb4d154f8ea523972e8d8794a978fffed89cb4d87c326f97447", size = 8632, upload-time = "2026-02-10T09:45:30.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/a8/dfc07e5b9005decbd2f08ac564250775fea9526473ae1bc90fbf45e8036f/qwen_omni_utils-0.0.9-py3-none-any.whl", hash = "sha256:f111db07af669c83333411c5177131e18e831fe666d6a55a1af263952ada8939", size = 9657, upload-time = "2026-02-10T09:45:28.888Z" }, +] + [[package]] name = "raft-dask-cu12" version = "25.10.0" From 3a23f6274aabc745e6cdd7d083a21e5b490ee875 Mon Sep 17 00:00:00 2001 From: Mohammad Aaftab Date: Mon, 29 Jun 2026 18:06:02 +0530 Subject: [PATCH 2/2] Minimize non-audio PR diffs Signed-off-by: Mohammad Aaftab --- nemo_curator/backends/base.py | 98 +++++++++--- nemo_curator/backends/perf_identity.py | 11 -- nemo_curator/backends/ray_data/adapter.py | 61 ++++---- nemo_curator/backends/ray_data/executor.py | 27 ++-- nemo_curator/backends/ray_data/utils.py | 7 +- nemo_curator/backends/utils.py | 61 +++++--- nemo_curator/backends/xenna/adapter.py | 13 +- nemo_curator/backends/xenna/executor.py | 97 ++++++------ nemo_curator/models/vllm_model.py | 66 ++++++--- nemo_curator/pipeline/payload_refs.py | 29 +--- nemo_curator/pipeline/pipeline.py | 8 +- nemo_curator/utils/performance_utils.py | 35 +---- tests/backends/ray_data/test_utils.py | 163 +------------------- tests/backends/test_task_id_postprocess.py | 10 +- tests/backends/test_utils.py | 9 +- tests/backends/test_xenna_executor.py | 14 ++ tests/backends/xenna/__init__.py | 1 - tests/backends/xenna/test_executor.py | 164 --------------------- tests/pipelines/test_pipelines.py | 5 +- 19 files changed, 315 insertions(+), 564 deletions(-) delete mode 100644 tests/backends/xenna/__init__.py delete mode 100644 tests/backends/xenna/test_executor.py diff --git a/nemo_curator/backends/base.py b/nemo_curator/backends/base.py index c54b10037f..684edd1466 100644 --- a/nemo_curator/backends/base.py +++ b/nemo_curator/backends/base.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import time import uuid from abc import ABC, abstractmethod @@ -24,17 +22,19 @@ from nemo_curator.backends.perf_identity import apply_worker_perf_identity, read_worker_metadata_identity from nemo_curator.core.utils import ignore_ray_head_node +from nemo_curator.tasks import Task from nemo_curator.tasks.task_terminals import preserve_dropped_terminal_tasks from nemo_curator.utils.performance_utils import StagePerfStats, StageTimer if TYPE_CHECKING: from nemo_curator.stages.base import ProcessingStage - from nemo_curator.tasks import Task @dataclass class NodeInfo: - """Generic node information for setup_on_node calls across backends.""" + """Generic node information for setup_on_node calls across backends. + Simplified to match Xenna's structure. + """ node_id: str = "" @@ -42,13 +42,13 @@ class NodeInfo: @dataclass class WorkerMetadata: """Generic worker metadata for setup_on_node calls across backends. - - Backends stamp ``actor_id``/``node_id``/``gpu_id`` at setup; perf records - copy them verbatim (see ``backends/perf_identity.py``). + Simplified to match Xenna's structure. The allocation field can contain + backend-specific allocation information. Backends may also stamp performance + identity fields at worker setup. """ worker_id: str = "" - allocation: Any = None # Backend-specific allocation info (Xenna) + allocation: Any = None # Backend-specific allocation info actor_id: str = "" node_id: str = "" gpu_id: str = "" @@ -67,10 +67,10 @@ def __init__(self, config: dict[str, Any] | None = None, ignore_head_node: bool self.ignore_head_node = ignore_head_node or ignore_ray_head_node() @abstractmethod - def execute(self, stages: list[ProcessingStage], initial_tasks: list[Task] | None = None) -> None: + def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | None = None) -> None: """Execute the pipeline.""" - def _cleanup_stage_run_resources(self, stages: list[ProcessingStage]) -> None: + def _cleanup_stage_run_resources(self, stages: list["ProcessingStage"]) -> None: """Release run-scoped resources created by pipeline helper stages. Some helpers intentionally create named Ray actors so payload handles can @@ -127,7 +127,7 @@ def _attach_pipeline_hardware_perf(tasks: list[Task], perf_stats: StagePerfStats for task in tasks: task.add_stage_perf(perf_stats) - def _publish_external_perf(self, stages: list[ProcessingStage], perf_stats: StagePerfStats | None) -> None: + def _publish_external_perf(self, stages: list["ProcessingStage"], perf_stats: StagePerfStats | None) -> None: """Publish a run-level perf record to the terminal artifact writer when one exists.""" if perf_stats is None: return @@ -145,11 +145,11 @@ def _publish_external_perf(self, stages: list[ProcessingStage], perf_stats: Stag class BaseStageAdapter: """Adapts ProcessingStage to an execution backend, if needed.""" - def __init__(self, stage: ProcessingStage): + def __init__(self, stage: "ProcessingStage"): self.stage = stage @staticmethod - def _stage_resource_expectation_metrics(stage: ProcessingStage) -> dict[str, float]: + def _stage_resource_expectation_metrics(stage: "ProcessingStage") -> dict[str, float]: """Return non-summing resource expectations attached by wrapper stages.""" metrics: dict[str, float] = {} for attr_name, metric_name in ( @@ -174,11 +174,21 @@ def _cache_perf_identity(self) -> None: self._perf_identity = read_worker_metadata_identity(str(self.stage.name), worker_metadata) def process_batch(self, tasks: list[Task]) -> list[Task]: - """Process a batch of tasks, timing and stamping perf stats on outputs.""" + """Process a batch of tasks. + + Args: + tasks (list[Task]): List of tasks to process + + Returns: + list[Task]: List of processed tasks + """ + # Lazy initialize timer if needed if not hasattr(self, "_timer") or self._timer is None: self._timer = StageTimer(self.stage) + # Calculate input data size for timer input_size = sum(task.num_items for task in tasks) + # Initialize performance timer for this batch self._timer.reinit(input_size) tracks_payload_refs = bool(getattr(self.stage, "_curator_tracks_payload_refs", False)) input_payload_refs = self._collect_payload_refs(tasks) if tracks_payload_refs else {} @@ -187,6 +197,7 @@ def process_batch(self, tasks: list[Task]) -> list[Task]: window_start = time.time() if extended_metrics else 0.0 try: with self._timer.time_process(input_size): + # Use the batch processing logic results = self.stage.process_batch(tasks) except Exception: self._release_payload_refs(input_payload_refs.values()) @@ -294,10 +305,47 @@ def _release_payload_refs(payload_refs: object) -> None: release_payload_ref(payload_ref) def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Task | None]) -> list[Task]: - """Assign a deterministic ``task_id`` to every emitted task.""" + """Assign a deterministic ``task_id`` to every emitted task. + + This is the single place task ids are assigned — it runs for every + stage on every backend (all backend adapters subclass this), so it + makes no difference whether a stage defines ``process`` or overrides + ``process_batch``. ``task_id`` is the task's id path (parents + own segment); ids are + re-derived at each stage boundary so the same object passing through + N stages gets N ids. + + The input→output mapping decides each output's PARENT; whether the + stage is a source decides each output's SEGMENT (content id vs index) + — the two are independent. ``None`` outputs (Curator's "return None to + filter") are NOT removed before the length check — keeping them in + place preserves positional alignment for filter stages — and are then + dropped from the returned list. + + - single input → every output is its child (fan-out): ``parent_`` + - ``len(output) == len(input)`` → positional 1:1: each ``parent_i_``; + a ``None`` slot just means input ``i`` was filtered. + - any other (ambiguous) cardinality across a batch → a random ``uuid`` + prefixed with ``"r"`` (e.g. ``"r3f9a…"``), so ``task_id`` is never + empty even when a derived id is not possible. The ``"r"`` prefix flags + the id as non-deterministic / ancestry-not-tracked (see + ``Task.task_id`` docstring). + + ``seg`` is the output's content id (``Task.get_deterministic_id()``) + for a source stage when available, else the positional index — so a + source partition keeps a stable id across reorderings regardless of + whether the source is 1→N or N→N. + + Note: a stage that BOTH filters and fans out within a single batch + (returning a flat list rather than a per-input slot) cannot be mapped + positionally; if its length happens to equal the input length the 1:1 + assumption may misattribute parents. That combination is unsupported + until per-slot sentinels (NoneTask/FailedTask) land in a later PR. + """ is_source = getattr(self.stage, "is_source_stage", False) if len(input_tasks) == 1: + # Fan-out (incl. a source reading from EmptyTask): every non-None + # output is a child of the single input. parent_id = input_tasks[0].task_id out: list[Task] = [t for t in output_tasks if t is not None] for i, task in enumerate(out): @@ -306,6 +354,8 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas return out if len(output_tasks) == len(input_tasks): + # Positional 1:1. None is kept above so a filtered slot still lines + # up with its own parent; drop the None slots from the result. out = [] for parent, task in zip(input_tasks, output_tasks, strict=True): if task is None: @@ -315,17 +365,31 @@ def _post_process_task_ids(self, input_tasks: list[Task], output_tasks: list[Tas out.append(task) return out + # Ambiguous cardinality across a batch: a derived id is not possible. Use a + # random "r"-prefixed uuid so task_id is non-empty but clearly flagged + # non-deterministic. out = [t for t in output_tasks if t is not None] for task in out: task.task_id = "r" + uuid.uuid4().hex return out def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None: - """Setup the stage on a node (node/worker info may be absent on some backends).""" + """Setup the stage on a node. + + Args: + node_info (NodeInfo, optional): Information about the node + worker_metadata (WorkerMetadata, optional): Information about the worker + """ + # Call the underlying stage's setup_on_node method + # Some backends may provide node/worker info, others may not self.stage.setup_on_node(node_info, worker_metadata) def setup(self, worker_metadata: WorkerMetadata | None = None) -> None: - """Setup the stage once per actor.""" + """Setup the stage once per actor. + + Args: + worker_metadata (WorkerMetadata, optional): Information about the worker + """ self._worker_metadata = worker_metadata if bool(getattr(self.stage, "extended_performance_metrics", False)): self._cache_perf_identity() diff --git a/nemo_curator/backends/perf_identity.py b/nemo_curator/backends/perf_identity.py index 88d1ceeadb..4251ccbd4f 100644 --- a/nemo_curator/backends/perf_identity.py +++ b/nemo_curator/backends/perf_identity.py @@ -332,17 +332,6 @@ def _ray_gpu_assignment(requires_gpu: bool) -> tuple[tuple[int, ...], tuple[str, return _gpu_assignment_from_tokens(env_tokens) -def _ray_gpu_indices(requires_gpu: bool) -> tuple[int, ...]: - return _ray_gpu_assignment(requires_gpu)[0] - - -def _ray_gpu_label(node_label: str, requires_gpu: bool) -> str: - gpu_indices = _ray_gpu_indices(requires_gpu) - if gpu_indices: - return _format_gpu_label(node_label, gpu_indices[0]) - return "" - - def build_ray_perf_identity( stage_name: str, *, diff --git a/nemo_curator/backends/ray_data/adapter.py b/nemo_curator/backends/ray_data/adapter.py index 6ac52f8e1e..43c6ec63e8 100644 --- a/nemo_curator/backends/ray_data/adapter.py +++ b/nemo_curator/backends/ray_data/adapter.py @@ -108,34 +108,15 @@ def process_dataset(self, dataset: Dataset) -> Dataset: Returns: Dataset: Processed Ray Data dataset """ - is_actor_stage_ = self.stage.ray_stage_spec().get(RayStageSpecKeys.IS_ACTOR_STAGE, is_actor_stage(self.stage)) - - map_batches_fn, concurrency_kwargs = self._map_batches_fn_and_kwargs( - is_actor_stage=is_actor_stage_, - ) - - # Calculate concurrency based on available resources - logger.info(f"{self.stage.__class__.__name__} {is_actor_stage_=} with {concurrency_kwargs=}") - - processed_dataset = dataset.map_batches(map_batches_fn, batch_size=self.batch_size, **concurrency_kwargs) # type: ignore[reportArgumentType] - - if self.stage.ray_stage_spec().get(RayStageSpecKeys.IS_FANOUT_STAGE, False): - processed_dataset = processed_dataset.repartition(target_num_rows_per_block=1) - - return processed_dataset - - def _map_batches_fn_and_kwargs( - self, - *, - is_actor_stage: bool, - ) -> tuple[Any, dict[str, Any]]: ray_stage_spec = self.stage.ray_stage_spec() - if is_actor_stage: + stage_is_actor = ray_stage_spec.get(RayStageSpecKeys.IS_ACTOR_STAGE, is_actor_stage(self.stage)) + + if stage_is_actor: map_batches_fn = create_actor_from_stage(self.stage) - concurrency_kwargs = {"compute": get_actor_compute_strategy_for_stage(self.stage)} + map_batches_kwargs = {"compute": get_actor_compute_strategy_for_stage(self.stage)} else: map_batches_fn = create_task_from_stage(self.stage) - concurrency_kwargs = {} + map_batches_kwargs = {} actor_pool_sizing_keys = get_configured_actor_pool_sizing_keys(ray_stage_spec) if actor_pool_sizing_keys: @@ -146,15 +127,18 @@ def _map_batches_fn_and_kwargs( num_workers = self.stage.num_workers() if num_workers is not None and num_workers > 0: - concurrency_kwargs["compute"] = TaskPoolStrategy(size=num_workers) + map_batches_kwargs["compute"] = TaskPoolStrategy(size=num_workers) - max_calls = ray_stage_spec.get(RayStageSpecKeys.MAX_CALLS_PER_WORKER, None) + max_calls = ray_stage_spec.get(RayStageSpecKeys.MAX_CALLS_PER_WORKER) if max_calls is not None: - concurrency_kwargs["max_calls"] = max_calls + map_batches_kwargs["max_calls"] = max_calls - concurrency_kwargs.update(self._build_resource_kwargs(ray_stage_spec)) + map_batches_kwargs.update(self._build_resource_kwargs(ray_stage_spec)) + # Per-stage ray_remote_args (e.g. runtime_env with different pip versions per stage). ray_remote_args = copy.deepcopy(ray_stage_spec.get(RayStageSpecKeys.RAY_REMOTE_ARGS) or {}) + # If the stage declares runtime_env, forward it directly to Ray so Ray creates and + # caches an isolated virtualenv for this stage's workers. if self.stage.runtime_env: ray_remote_args["runtime_env"] = self.stage.runtime_env @@ -166,13 +150,20 @@ def _map_batches_fn_and_kwargs( ) raise ValueError(msg) - concurrency_kwargs.update(ray_remote_args) - return map_batches_fn, concurrency_kwargs + map_batches_kwargs.update(ray_remote_args) + + # Let Ray Data apply the selected compute strategy and resource requirements. + logger.info(f"{self.stage.__class__.__name__} stage_is_actor={stage_is_actor} with {map_batches_kwargs=}") + + processed_dataset = dataset.map_batches(map_batches_fn, batch_size=self.batch_size, **map_batches_kwargs) # type: ignore[reportArgumentType] + + if ray_stage_spec.get(RayStageSpecKeys.IS_FANOUT_STAGE, False): + processed_dataset = processed_dataset.repartition(target_num_rows_per_block=1) + + return processed_dataset -def create_actor_from_stage( - stage: ProcessingStage, -) -> type[RayDataStageAdapter]: +def create_actor_from_stage(stage: ProcessingStage) -> type[RayDataStageAdapter]: """Create a StageProcessor class with the proper stage name for display.""" class RayDataStageActorAdapter(RayDataStageAdapter): @@ -203,9 +194,7 @@ def __call__(self, batch: dict[str, Any]) -> dict[str, Any]: return RayDataStageActorAdapter -def create_task_from_stage( - stage: ProcessingStage, -) -> Callable[[dict[str, Any]], dict[str, Any]]: +def create_task_from_stage(stage: ProcessingStage) -> Callable[[dict[str, Any]], dict[str, Any]]: """Create a named Ray Data stage adapter function. This creates a standalone function that wraps the stage processing logic diff --git a/nemo_curator/backends/ray_data/executor.py b/nemo_curator/backends/ray_data/executor.py index d7657dd3cf..182a55b874 100644 --- a/nemo_curator/backends/ray_data/executor.py +++ b/nemo_curator/backends/ray_data/executor.py @@ -18,9 +18,7 @@ from loguru import logger from ray.data import DataContext, Dataset -from nemo_curator.backends.base import ( - BaseExecutor, -) +from nemo_curator.backends.base import BaseExecutor from nemo_curator.backends.utils import execute_setup_on_node, register_loguru_serializer from nemo_curator.tasks import EmptyTask, Task @@ -34,13 +32,22 @@ class RayDataExecutor(BaseExecutor): """Ray Data-based executor for pipeline execution. This executor: - 1. Executes setup on all nodes for all stages + 1. Executes setup on Ray nodes for all stages 2. Converts initial tasks to Ray Data dataset 3. Applies each stage as a Ray Data transformation (as a task or actor in map_batches) 4. Returns final results as a list of tasks """ def __init__(self, config: dict[str, Any] | None = None, ignore_head_node: bool = False): + """Initialize the executor. + + Args: + config (dict[str, Any], optional): Configuration dictionary. + ignore_head_node (bool, optional): Whether to skip the Ray head node for + ``setup_on_node``. Ray Data controls ``map_batches`` task/actor placement + through Ray's scheduler; this flag does not cap actor-pool size or force + Ray Data workers away from the head node. + """ super().__init__(config, ignore_head_node) def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | None = None) -> list[Task]: @@ -87,7 +94,12 @@ def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | N # TODO: add pipeline level config for verbosity logger.info(f"Processing stage {i + 1}/{len(stages)}: {stage}") logger.info(f" CPU cores: {stage.resources.cpus}, GPU ratio: {stage.resources.gpus}") - current_dataset = self._process_stage_dataset(stage, current_dataset) + + # Create adapter for this stage + adapter = RayDataStageAdapter(stage) + + # Apply stage transformation + current_dataset = adapter.process_dataset(current_dataset) except Exception as e: logger.error(f"Error during pipeline execution: {e}") raise @@ -110,11 +122,6 @@ def execute(self, stages: list["ProcessingStage"], initial_tasks: list[Task] | N ray.shutdown() return output_tasks - def _process_stage_dataset(self, stage: "ProcessingStage", dataset: Dataset) -> Dataset: - """Process one stage as a Ray Data transform.""" - adapter = RayDataStageAdapter(stage) - return adapter.process_dataset(dataset) - def _tasks_to_dataset(self, tasks: list[Task]) -> Dataset: """Convert list of tasks to Ray Data dataset. diff --git a/nemo_curator/backends/ray_data/utils.py b/nemo_curator/backends/ray_data/utils.py index 2b3df48ff5..5c092736a0 100644 --- a/nemo_curator/backends/ray_data/utils.py +++ b/nemo_curator/backends/ray_data/utils.py @@ -52,7 +52,12 @@ def get_configured_actor_pool_sizing_keys(ray_stage_spec: Mapping[str, object]) def get_actor_compute_strategy_for_stage(stage: ProcessingStage) -> ActorPoolStrategy: - """Get the Ray Data actor-pool compute strategy for a processing stage.""" + """Get the Ray Data actor-pool compute strategy for a processing stage. + + Explicit stage ``num_workers`` requests a fixed-size actor pool. Otherwise, + actor stages use Ray Data's autoscaling pool and can optionally override + min/max/initial workers through ``ray_stage_spec``. + """ num_workers = stage.num_workers() if num_workers is not None and num_workers > 0: actor_pool_sizing_keys = get_configured_actor_pool_sizing_keys(stage.ray_stage_spec()) diff --git a/nemo_curator/backends/utils.py b/nemo_curator/backends/utils.py index 2e91f39041..fde76ac57b 100644 --- a/nemo_curator/backends/utils.py +++ b/nemo_curator/backends/utils.py @@ -39,11 +39,13 @@ def _logger_custom_serializer( def _logger_custom_deserializer( _: None, ) -> "loguru.Logger": + # Initialize a default logger return logger def register_loguru_serializer() -> None: - """Register a no-op (de)serializer for loguru (not serializable in general).""" + """Initialize a new local Ray cluster or connects to an existing one.""" + # Turn off serization for loguru. This is needed as loguru is not serializable in general. ray.util.register_serializer( logger.__class__, serializer=_logger_custom_serializer, @@ -52,15 +54,29 @@ def register_loguru_serializer() -> None: def merge_executor_configs(base_config: dict | None, override_config: dict | None) -> dict: - """Recursively deep-merge two executor configs (override wins, inputs untouched). + """ + Recursively merge two executor configs with deep merging of nested dicts. Args: base_config: Base configuration dictionary - override_config: Configuration merged on top of base_config + override_config: Configuration to merge on top of base_config Returns: - Merged config with nested dicts merged recursively + Merged configuration dictionary with all nested dicts recursively merged + + Notes: + - Recursively merges all nested dictionaries + - Non-dict values in override_config will overwrite base_config + - Handles None values gracefully + - Does not modify original inputs (uses deep copy) + + Examples: + >>> base = {"runtime_env": {"env_vars": {"A": "1", "B": "2"}}} + >>> override = {"runtime_env": {"env_vars": {"B": "3", "C": "4"}}} + >>> merge_executor_configs(base, override) + {"runtime_env": {"env_vars": {"A": "1", "B": "3", "C": "4"}}} """ + # Handle None cases if base_config is None and override_config is None: return {} if base_config is None: @@ -68,15 +84,20 @@ def merge_executor_configs(base_config: dict | None, override_config: dict | Non if override_config is None: return deepcopy(base_config) + # Deep copy to avoid modifying originals merged_config = deepcopy(base_config) + # Recursively merge each key from override_config for key, value in override_config.items(): if isinstance(value, dict): if key not in merged_config or not isinstance(merged_config[key], dict): + # If key doesn't exist or isn't a dict, just use the override value merged_config[key] = deepcopy(value) else: + # Recursively merge nested dicts merged_config[key] = merge_executor_configs(merged_config[key], value) else: + # For non-dict values, overwrite merged_config[key] = value return merged_config @@ -141,9 +162,10 @@ def get_available_cpu_gpu_resources( """Get available CPU and GPU resources from Ray.""" if init_and_shutdown: ray.init(ignore_reinit_error=True) - time.sleep(0.2) # ray.available_resources() can lag - # Curator assumes the whole cluster is free (one pipeline at a time), so - # available resources should match total resources. + time.sleep(0.2) # ray.available_resources() returns might have a lag + # available resources can be different from total resources, however curator assumes + # entire cluster is available for use and only one pipeline is being run at a time. + # therefore available resources should match total resources. available_resources = ray.available_resources() available_cpus = available_resources.get("CPU", 0) available_gpus = available_resources.get("GPU", 0) @@ -166,10 +188,12 @@ def get_available_cpu_gpu_resources( def check_total_gpu_capacity(gpus_needed: int, *, ignore_head_node: bool = False) -> None: - """Raise if the cluster lacks enough GPUs for aggregate demand. + """Raise if the cluster doesn't have enough GPUs to satisfy aggregate demand. - Coarse pre-check: Ray's placement-group scheduler can hang on ``pg.ready()`` - when demand exceeds capacity, so fail fast with the actual numbers. + Intended as a coarse pre-check before submitting placement groups: Ray's + PG scheduler can hang indefinitely on ``pg.ready()`` when demand exceeds + capacity, so a fast, explicit error with the actual numbers is friendlier + than waiting on a timeout. """ _, available_gpus = get_available_cpu_gpu_resources(ignore_head_node=ignore_head_node) available = int(available_gpus) @@ -180,11 +204,13 @@ def check_total_gpu_capacity(gpus_needed: int, *, ignore_head_node: bool = False @ray.remote def _setup_stage_on_node(stage: ProcessingStage) -> None: - """Run ``setup_on_node`` for a stage as a Ray task. + """Ray remote function to execute setup_on_node for a stage. - Force vLLM's spawn method: it auto-sets spawn only inside Ray actors, not - tasks, so without this fork would hit "Cannot re-initialize CUDA in forked - subprocess". + This runs as a Ray remote task (not an actor). + vLLM's auto-detection only forces the spawn multiprocessing method inside Ray actors, + not in Ray tasks. Without this override, vLLM defaults to fork in tasks and hits + RuntimeError: Cannot re-initialize CUDA in forked subprocess. + We explicitly set the environment variable to spawn to prevent this. """ os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") node_id = ray.get_runtime_context().get_node_id() @@ -194,9 +220,10 @@ def _setup_stage_on_node(stage: ProcessingStage) -> None: def execute_setup_on_node(stages: list[ProcessingStage], ignore_head_node: bool = False) -> None: """Execute ``setup_on_node`` for every stage on every alive Ray node. - All ``(stage, node)`` tasks are submitted up front and awaited with one - ``ray.get``, so wall-clock time is bounded by the slowest stage (matters when - setup is heavy: model downloads, weight loads). + All ``(stage, node)`` setup tasks are submitted up front and awaited with a single + ``ray.get``, so total wall-clock time is bounded by the slowest stage rather than + the sum of per-stage times — important when setup is heavy (model downloads, weight + loads) and stages don't contend for the same resources. """ head_node_id = get_head_node_id() if ignore_head_node else None for node in ray.nodes(): diff --git a/nemo_curator/backends/xenna/adapter.py b/nemo_curator/backends/xenna/adapter.py index ab9cce5a77..037620621a 100644 --- a/nemo_curator/backends/xenna/adapter.py +++ b/nemo_curator/backends/xenna/adapter.py @@ -20,11 +20,7 @@ from cosmos_xenna.pipelines.private.resources import Resources as XennaResources from cosmos_xenna.pipelines.private.resources import WorkerMetadata as XennaWorkerMetadata -from nemo_curator.backends.base import ( - BaseStageAdapter, - NodeInfo, - WorkerMetadata, -) +from nemo_curator.backends.base import BaseStageAdapter, NodeInfo, WorkerMetadata from nemo_curator.backends.perf_identity import build_xenna_perf_identity, stamp_worker_metadata from nemo_curator.stages.base import ProcessingStage from nemo_curator.tasks import Task @@ -81,7 +77,7 @@ def required_resources(self) -> XennaResources: def stage_batch_size(self) -> int: """Get the batch size for this stage.""" batch_size = self.processing_stage.batch_size - return 1 if batch_size is None else int(batch_size) + return batch_size if batch_size is not None else 1 @property def env_info(self) -> pipelines_v1.RuntimeEnv | None: @@ -101,6 +97,7 @@ def process_data(self, tasks: list[Task]) -> list[Task] | None: Returns: List of processed tasks or None """ + # Use the base stage's monitoring capability return self.process_batch(tasks) def setup_on_node(self, node_info: XennaNodeInfo, worker_metadata: XennaWorkerMetadata) -> None: @@ -116,7 +113,7 @@ def setup_on_node(self, node_info: XennaNodeInfo, worker_metadata: XennaWorkerMe requires_gpu = bool(getattr(getattr(self.processing_stage, "resources", None), "requires_gpu", False)) generic_worker_metadata = WorkerMetadata( worker_id=worker_metadata.worker_id, - allocation=worker_metadata.allocation, + allocation=worker_metadata.allocation, # Keep the original allocation object ) if bool(getattr(self.processing_stage, "extended_performance_metrics", False)): identity = build_xenna_perf_identity( @@ -140,7 +137,7 @@ def setup(self, worker_metadata: XennaWorkerMetadata) -> None: requires_gpu = bool(getattr(getattr(self.processing_stage, "resources", None), "requires_gpu", False)) generic_worker_metadata = WorkerMetadata( worker_id=worker_metadata.worker_id, - allocation=worker_metadata.allocation, + allocation=worker_metadata.allocation, # Keep the original allocation object ) if bool(getattr(self.processing_stage, "extended_performance_metrics", False)): identity = build_xenna_perf_identity( diff --git a/nemo_curator/backends/xenna/executor.py b/nemo_curator/backends/xenna/executor.py index b9a0808aca..096eb9e375 100644 --- a/nemo_curator/backends/xenna/executor.py +++ b/nemo_curator/backends/xenna/executor.py @@ -19,13 +19,9 @@ from cosmos_xenna.utils.verbosity import VerbosityLevel from loguru import logger -from nemo_curator.backends.base import ( - BaseExecutor, -) +from nemo_curator.backends.base import BaseExecutor from nemo_curator.backends.utils import register_loguru_serializer -from nemo_curator.backends.xenna.adapter import ( - create_named_xenna_stage_adapter, -) +from nemo_curator.backends.xenna.adapter import create_named_xenna_stage_adapter from nemo_curator.stages.base import ProcessingStage from nemo_curator.tasks import EmptyTask, Task @@ -78,22 +74,49 @@ def execute(self, stages: list[ProcessingStage], initial_tasks: list[Task] | Non Returns: list[Task]: List of output tasks from the pipeline """ - initial_tasks = initial_tasks if initial_tasks else [EmptyTask()] - return self._run_xenna_pipeline(stages, initial_tasks) - - def _run_xenna_pipeline( - self, - stages: list[ProcessingStage], - initial_tasks: list[Any], - ) -> list[Any]: - if not stages: - return initial_tasks - # Convert stages to Xenna stage specs stage_specs = [] + # Initialize with initial tasks if provided, otherwise start with EmptyTask + initial_tasks = initial_tasks if initial_tasks else [EmptyTask()] + for stage in stages: - stage_specs.append(self._build_stage_spec(stage)) + # Get stage configuration + stage_config = stage.xenna_stage_spec() + if "num_workers" in stage_config: + msg = f"Stage {stage.name} sets num_workers in xenna_stage_spec(). Use num_workers() instead." + raise ValueError(msg) + + num_workers = stage.num_workers() + num_workers_per_node = stage_config.get("num_workers_per_node") + if num_workers is not None and num_workers_per_node is not None: + msg = ( + f"Stage {stage.name} sets both num_workers() and " + "xenna_stage_spec()['num_workers_per_node']. Use only one worker sizing option." + ) + raise ValueError(msg) + + # Create Xenna stage adapter with the original stage's name + xenna_stage = create_named_xenna_stage_adapter( + stage=stage, + ) + + # Create stage spec with configuration from stage + stage_spec = pipelines_v1.StageSpec( + stage=xenna_stage, + num_workers=num_workers, + num_workers_per_node=num_workers_per_node, + num_setup_attempts_python=stage_config.get("num_setup_attempts_python"), + num_run_attempts_python=stage_config.get("num_run_attempts_python"), + ignore_failures=stage_config.get("ignore_failures"), + reset_workers_on_failure=stage_config.get("reset_workers_on_failure"), + slots_per_actor=stage_config.get("slots_per_actor"), + worker_max_lifetime_m=stage_config.get("worker_max_lifetime_m"), + worker_restart_interval_m=stage_config.get("worker_restart_interval_m"), + max_setup_failure_percentage=stage_config.get("max_setup_failure_percentage"), + ) + + stage_specs.append(stage_spec) # Determine execution mode exec_mode = pipelines_v1.ExecutionMode.STREAMING @@ -163,44 +186,6 @@ def _run_xenna_pipeline( ray.shutdown() return results if results else [] - def _build_stage_spec(self, stage: ProcessingStage) -> pipelines_v1.StageSpec: - """Create a Xenna StageSpec from a Curator stage.""" - stage_config = stage.xenna_stage_spec() - num_workers, num_workers_per_node = self._resolve_stage_worker_sizing(stage, stage_config) - xenna_stage = create_named_xenna_stage_adapter(stage=stage) - - return pipelines_v1.StageSpec( - stage=xenna_stage, - num_workers=num_workers, - num_workers_per_node=num_workers_per_node, - num_setup_attempts_python=stage_config.get("num_setup_attempts_python"), - num_run_attempts_python=stage_config.get("num_run_attempts_python"), - ignore_failures=stage_config.get("ignore_failures"), - reset_workers_on_failure=stage_config.get("reset_workers_on_failure"), - slots_per_actor=stage_config.get("slots_per_actor"), - worker_max_lifetime_m=stage_config.get("worker_max_lifetime_m"), - worker_restart_interval_m=stage_config.get("worker_restart_interval_m"), - max_setup_failure_percentage=stage_config.get("max_setup_failure_percentage"), - ) - - @staticmethod - def _resolve_stage_worker_sizing( - stage: ProcessingStage, stage_config: dict[str, Any] - ) -> tuple[int | None, int | None]: - """Resolve Xenna worker sizing with the main-branch contract.""" - if "num_workers" in stage_config: - msg = f"Stage {stage.name} sets num_workers in xenna_stage_spec(). Use num_workers() instead." - raise ValueError(msg) - num_workers = stage.num_workers() - num_workers_per_node = stage_config.get("num_workers_per_node") - if num_workers is not None and num_workers_per_node is not None: - msg = ( - f"Stage {stage.name} sets both num_workers() and " - "xenna_stage_spec()['num_workers_per_node']. Use only one worker sizing option." - ) - raise ValueError(msg) - return num_workers, num_workers_per_node - def _get_pipeline_config(self, key: str) -> Any: # noqa: ANN401 """Get configuration value with fallback to defaults.""" return self.config.get(key, self._default_pipeline_config.get(key)) diff --git a/nemo_curator/models/vllm_model.py b/nemo_curator/models/vllm_model.py index b28029840b..9b4f27a83f 100644 --- a/nemo_curator/models/vllm_model.py +++ b/nemo_curator/models/vllm_model.py @@ -12,18 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""vLLM model wrappers. - -- :class:`VLLMBase` - shared engine management (creation, generation, GPU - cleanup). ``_generate`` accepts text prompts *or* multimodal prompt dicts - so audio/vision adapters reuse the same plumbing. -- :class:`VLLMModel` - generic text-generation :class:`ModelInterface`. -""" - -from __future__ import annotations - import gc -import time from typing import Any import torch @@ -64,9 +53,7 @@ def _init_engine(self, model_kwargs: dict[str, Any], sampling_kwargs: dict[str, Args forward to ``vllm.LLM`` / ``vllm.SamplingParams``. Constructor exceptions propagate unchanged, matching the public main behavior. """ - start_time = time.perf_counter() self._llm = LLM(**model_kwargs) - logger.info("vLLM engine loaded in {:.3f}s", time.perf_counter() - start_time) self._sampling_params = SamplingParams(**sampling_kwargs) def _generate(self, prompts: list, *, use_tqdm: bool = False) -> list: @@ -120,15 +107,23 @@ def __init__( # noqa: PLR0913 max_tokens: int | None = None, cache_dir: str | None = None, ): - """Initialize the vLLM model wrapper. + """ + Initialize the vLLM model wrapper. Args: - model: Model identifier (e.g., "microsoft/phi-4"). - max_model_len: Context length; auto-detected from HF AutoConfig - when ``None``. - tensor_parallel_size: TP GPU count; auto-detected when ``None``. - min_p: Min-p sampling (Qwen3 only). - cache_dir: Model weight cache directory. + model: Model identifier (e.g., "microsoft/phi-4") + max_model_len: Maximum model context length. If not specified, + will be auto-detected from HuggingFace AutoConfig. + tensor_parallel_size: Number of GPUs for tensor parallelism. + If not specified, auto-detects available GPUs. + max_num_batched_tokens: Maximum tokens per batch. Defaults to + 4096. + temperature: Sampling temperature. Defaults to 0.7. + top_p: Top-p sampling parameter. Defaults to 0.8. + top_k: Top-k sampling parameter. Defaults to 20. + min_p: Min-p sampling parameter (for Qwen3). Defaults to 0.0. + max_tokens: Maximum tokens to generate. Defaults to None. + cache_dir: Cache directory for model weights. Defaults to None. """ self.model = model self.max_model_len = max_model_len @@ -140,6 +135,8 @@ def __init__( # noqa: PLR0913 self.min_p = min_p self.max_tokens = max_tokens self.cache_dir = cache_dir + self._llm: LLM | None = None + self._sampling_params: SamplingParams | None = None self._final_max_model_len: int | None = None self._is_qwen3: bool = False @@ -154,13 +151,16 @@ def setup(self) -> None: msg = "vLLM is required for VLLMModel. Please install it: pip install vllm" raise ImportError(msg) + # Fetch max_model_len from user param or auto-detect from HuggingFace AutoConfig if self.max_model_len is not None: final_max_model_len = self.max_model_len else: final_max_model_len = get_max_model_len_from_config(self.model, cache_dir=self.cache_dir) + # Set tensor_parallel_size as user param or auto-detect from GPU count final_tp_size = self.tensor_parallel_size if self.tensor_parallel_size is not None else get_gpu_count() + # Set max_num_batched_tokens as user param or use default final_max_batched = self.max_num_batched_tokens llm_kwargs: dict[str, Any] = { @@ -218,15 +218,33 @@ def generate( self, prompts: list[str], ) -> list[str]: - """Generate text from prompt strings (or chat message dicts). + """ + Generate text from prompts. + + Args: + prompts: List of prompt strings or list of message dicts + (for chat template). - Raises ``RuntimeError`` if the model is not set up or generation fails. + Returns: + List of generated text strings. + + Raises: + RuntimeError: If the model is not set up or generation fails. """ if self._llm is None or self._sampling_params is None: msg = "Model not initialized. Call setup() first." raise RuntimeError(msg) - outputs = self._generate(prompts) - return [out.outputs[0].text if out.outputs else "" for out in outputs] + + try: + outputs = self._llm.generate( + prompts, + sampling_params=self._sampling_params, + use_tqdm=False, + ) + return [out.outputs[0].text if out.outputs else "" for out in outputs] + except (RuntimeError, ValueError, TypeError) as e: + msg = f"Error generating text: {e}" + raise RuntimeError(msg) from e def get_tokenizer(self) -> Any: # noqa: ANN401 """Get the tokenizer from the LLM instance.""" diff --git a/nemo_curator/pipeline/payload_refs.py b/nemo_curator/pipeline/payload_refs.py index 5ca762e269..00299a82ce 100644 --- a/nemo_curator/pipeline/payload_refs.py +++ b/nemo_curator/pipeline/payload_refs.py @@ -46,36 +46,11 @@ def _get_named_actor(name: str, namespace: str | None = None) -> Any: return ray.get_actor(name) -def resolve_payload_ref(payload_ref: PayloadRef) -> Any: - heartbeat_payload_ref(payload_ref) - store = _get_named_actor(payload_ref.store_actor_name, payload_ref.actor_namespace) - return _ray_get(store.get.remote(payload_ref.payload_id, payload_ref.lease_ttl_s)) - - -def heartbeat_payload_ref(payload_ref: PayloadRef) -> None: - admission = _get_named_actor(payload_ref.admission_actor_name, payload_ref.actor_namespace) - if not _ray_get( - admission.heartbeat.remote( - payload_ref.owner_node_id, - payload_ref.payload_id, - payload_ref.lease_ttl_s, - ) - ): - raise KeyError( - f"Payload admission lease {payload_ref.payload_id} is no longer present in " - f"{payload_ref.admission_actor_name}" - ) - store = _get_named_actor(payload_ref.store_actor_name, payload_ref.actor_namespace) - if not _ray_get(store.pin.remote(payload_ref.payload_id, payload_ref.lease_ttl_s)): - raise KeyError(f"Payload {payload_ref.payload_id} is no longer present in {payload_ref.store_actor_name}") - - def heartbeat_payload_refs_batched(payload_refs: Sequence[PayloadRef]) -> None: """Refresh payload leases with one RPC per admission/store actor. - The singular :func:`heartbeat_payload_ref` contract remains unchanged for - existing callers. This opt-in batched path is used by payload-aware stages - that know their actors provide ``heartbeat_many`` and ``pin_many``. + Payload-aware stages use this path when their actors provide + ``heartbeat_many`` and ``pin_many``. """ refs = _unique_payload_refs(payload_refs) if not refs: diff --git a/nemo_curator/pipeline/pipeline.py b/nemo_curator/pipeline/pipeline.py index 90e9059699..dc116e5fec 100644 --- a/nemo_curator/pipeline/pipeline.py +++ b/nemo_curator/pipeline/pipeline.py @@ -25,7 +25,11 @@ def assign_root_task_ids(initial_tasks: list[Task]) -> list[Task]: """Assign root ``task_id``s to user-provided initial tasks. - Every non-sentinel task is rooted under ``"0"`` exactly as on main. + Every task in a run descends from the implicit root ``"0"`` (the id of + :class:`EmptyTask`). User-provided initial tasks are its direct + children, so they get ``"0_0"``, ``"0_1"``, … ``EmptyTask`` instances + are skipped (already ``"0"``). All downstream ``task_id`` assignment + happens in ``BaseStageAdapter``. NOTE: we deliberately use the positional index here, NOT ``get_deterministic_id()``, even for content-bearing tasks like @@ -237,8 +241,10 @@ def _decompose_stages( sub_stages = stage.decompose_and_apply_with() if isinstance(stage, CompositeStage) else [stage] if len(sub_stages) > 1: + # This was a composite stage logger.info(f"Decomposing composite stage: {stage.name}") + # Validate that decomposed stages are not composite for sub_stage in sub_stages: if isinstance(sub_stage, CompositeStage) and len(sub_stage.decompose()) > 1: msg = ( diff --git a/nemo_curator/utils/performance_utils.py b/nemo_curator/utils/performance_utils.py index 7e815c487f..ac55611361 100644 --- a/nemo_curator/utils/performance_utils.py +++ b/nemo_curator/utils/performance_utils.py @@ -17,7 +17,7 @@ import contextlib import statistics import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import attrs from loguru import logger @@ -28,23 +28,6 @@ from nemo_curator.stages.base import ProcessingStage -#: Identity fields on ``StagePerfStats``: best-effort string labels for the -#: actor/node/GPU that produced the record. They are metadata, not numeric -#: metrics, so they MUST be excluded from ``items()`` -- downstream collection -#: calls ``float()`` on every yielded value and would crash on a string. -_IDENTITY_FIELDS = ( - "invocation_id", - "actor_id", - "node_id", - "gpu_id", - "physical_address", - "pod_ip", - "hostname", - "gpu_indices", - "gpu_uuids", -) - - @attrs.define class StagePerfStats: """Statistics for tracking stage performance metrics. @@ -72,7 +55,7 @@ class StagePerfStats: input_data_size_mb: float = 0.0 num_items_processed: int = 0 custom_metrics: dict[str, float] = attrs.field(factory=dict) - # identity (metadata, never a numeric metric -- see _IDENTITY_FIELDS) + # identity metadata invocation_id: str = "" actor_id: str = "" node_id: str = "" @@ -153,20 +136,13 @@ def to_dict(self) -> dict[str, float | int]: "custom_metrics": dict(self.custom_metrics), } - def to_extended_dict(self) -> dict[str, Any]: - """Convert to the complete observability schema, including identity.""" - return attrs.asdict(self) - def items(self) -> list[tuple[str, float | int]]: """Returns (metric_name, metric_value) pairs custom_metrics are flattened into the format (custom., metric_value) """ res = self.to_dict() res.pop("stage_name", None) - # Identity fields are string metadata; downstream collectors call float() - # on every yielded value, so they MUST be dropped here. - for identity_field in _IDENTITY_FIELDS: - res.pop(identity_field, None) + # Extract and drop the raw custom_metrics dict from the flattened output custom_metrics = res.pop("custom_metrics", {}) # Flatten custom_metrics with a stable prefix for key, value in custom_metrics.items(): @@ -175,7 +151,9 @@ def items(self) -> list[tuple[str, float | int]]: class StageTimer: - """Tracks processing time and other metrics per process_data call.""" + """Tracker for stage performance stats. + Tracks processing time and other metrics at a per process_data call level. + """ def __init__(self, stage: ProcessingStage) -> None: """Initialize the stage timer. @@ -199,6 +177,7 @@ def _reset(self) -> None: def reinit(self, stage_input_size: int = 1) -> None: """Reinitialize the stage timer. Args: + stage: The stage to reinitialize the timer for. stage_input_size: The size of the stage input. """ self._reset() diff --git a/tests/backends/ray_data/test_utils.py b/tests/backends/ray_data/test_utils.py index 1fc184705a..66903d7abd 100644 --- a/tests/backends/ray_data/test_utils.py +++ b/tests/backends/ray_data/test_utils.py @@ -12,106 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable from unittest.mock import MagicMock, Mock, patch import numpy as np import pytest from ray.data import ActorPoolStrategy -from nemo_curator.backends.ray_data.adapter import RayDataStageAdapter -from nemo_curator.backends.ray_data.executor import RayDataExecutor from nemo_curator.backends.ray_data.utils import ( coerce_batch_tasks, get_actor_compute_strategy_for_stage, ) from nemo_curator.backends.utils import RayStageSpecKeys, get_available_cpu_gpu_resources -from nemo_curator.stages.audio.inference.batch_policy import BatchPolicy -from nemo_curator.stages.base import ProcessingStage -from nemo_curator.stages.resources import Resources -from nemo_curator.tasks import AudioTask from tests.backends.test_utils import reset_head_node_cache # noqa: F401 -class _PreplannedEchoStage(ProcessingStage[AudioTask, AudioTask]): - name = "preplanned_echo" - resources = Resources(cpus=1.0) - batch_size = 99 - - def process(self, task: AudioTask) -> AudioTask: - return task - - def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: - for task in tasks: - task.data["seen_batch_size"] = len(tasks) - return tasks - - -class _CentralizedPlanningStage(ProcessingStage[AudioTask, AudioTask]): - name = "centralized_planning" - resources = Resources(cpus=1.0) - batch_size = 2 - - def __init__(self) -> None: - self.batch_policy = BatchPolicy( - buckets_sec=[0, 30, 1200], - max_items_per_batch_by_bucket=[2, 1, 1], - max_audio_sec_per_batch=None, - ) - - def process(self, task: AudioTask) -> AudioTask: - return task - - def build_prebucketed_tasks(self, tasks: list[AudioTask]) -> list[AudioTask]: - return list(tasks) - - def scheduler_task_cost(self, task: AudioTask) -> float: - return float(task.data["duration"]) - - def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: - for task in tasks: - task.data["processed_batch_size"] = len(tasks) - return tasks - - def assemble_prebucketed_task_results( - self, - tasks: list[AudioTask], - _processed_tasks: list[AudioTask], - ) -> list[AudioTask]: - return list(tasks) - - -class _FakeDataset: - def __init__(self, sample_items: list[object] | None = None) -> None: - self.repartition_calls: list[tuple[tuple, dict]] = [] - self.map_batches_calls: list[tuple[object, dict]] = [] - self.sample_output: dict | None = None - self.sample_items = sample_items - - def repartition(self, *args, **kwargs) -> "_FakeDataset": - self.repartition_calls.append((args, kwargs)) - return self - - def map_batches(self, fn: Callable[[dict[str, object]], dict[str, object]], **kwargs) -> "_FakeDataset": - self.map_batches_calls.append((fn, kwargs)) - sample_items = self.sample_items - if sample_items is None: - first = AudioTask(data={"duration": 5.0}) - second = AudioTask(data={"duration": 600.0}) - sample_items = [first, second] - self.sample_output = fn({"item": sample_items}) - return self - - class TestGetAvailableCpuGpuResources: # TODO: Move this to tests/backends/test_utils.py """Test class for utility functions in ray_data backend.""" def test_get_available_cpu_gpu_resources_conftest(self, shared_ray_client: None): """Test get_available_cpu_gpu_resources function.""" + # Test with Ray resources from conftest.py cpus, gpus = get_available_cpu_gpu_resources() assert cpus == 11 - # GPU count depends on local hardware and whether GPU tests are selected. + # GPU count depends on whether GPU tests are running in this session + # and on how many GPUs the test host exposes to the fixture. assert gpus in [0.0, 1.0, 2.0] @pytest.mark.usefixtures("reset_head_node_cache") @@ -119,7 +44,8 @@ def test_get_resources_with_ignore_head_node( self, shared_ray_client: None, ): - """ignore_head_node=True skips the head node; running on the head node, resources are 0.""" + """Test get_available_cpu_gpu_resources with ignore_head_node=True to skip head node. + Since this test is run with the head node, the resources should be 0.""" cpus_without_head, gpus_without_head = get_available_cpu_gpu_resources(ignore_head_node=True) assert cpus_without_head == 0 assert gpus_without_head == 0 @@ -174,7 +100,7 @@ def test_actor_compute_strategy( ray_stage_spec: dict[str, object], expected: ActorPoolStrategy, expected_warning: str | None, - ) -> None: + ): mock_stage = Mock(num_workers=lambda: num_workers, ray_stage_spec=lambda: ray_stage_spec) mock_stage.name = "stage" @@ -187,7 +113,7 @@ def test_actor_compute_strategy( mock_warning.assert_called_once() assert expected_warning in mock_warning.call_args.args[0] - def test_actor_compute_strategy_rejects_invalid_sizing(self) -> None: + def test_actor_compute_strategy_rejects_invalid_sizing(self): mock_stage = Mock( num_workers=lambda: None, ray_stage_spec=lambda: { @@ -212,78 +138,3 @@ def test_coerce_batch_tasks_empty(self) -> None: assert coerce_batch_tasks([]) == [] assert coerce_batch_tasks(np.array([], dtype=object)) == [] assert coerce_batch_tasks(None) == [] - - -def test_ray_data_adapter_passes_backend_batch_to_stage_process_batch() -> None: - dataset = _FakeDataset() - adapter = RayDataStageAdapter(_CentralizedPlanningStage()) - - out = adapter.process_dataset(dataset) - - assert out is dataset - assert dataset.repartition_calls == [] - assert len(dataset.map_batches_calls) == 1 - assert dataset.map_batches_calls[0][1]["batch_size"] == 2 - assert dataset.sample_output is not None - processed_durations = [task.data["duration"] for task in dataset.sample_output["item"]] - processed_batch_sizes = [task.data["processed_batch_size"] for task in dataset.sample_output["item"]] - assert processed_durations == [5.0, 600.0] - assert processed_batch_sizes == [2, 2] - - -def test_ray_data_executor_keeps_centralized_stage_in_ray_data( - monkeypatch: pytest.MonkeyPatch, -) -> None: - executor = RayDataExecutor(ignore_head_node=True) - stage = _CentralizedPlanningStage() - input_dataset = object() - output_dataset = object() - calls: dict[str, object] = {} - - class FakeRayDataStageAdapter: - def __init__(self, stage_arg: ProcessingStage) -> None: - calls["stage"] = stage_arg - - def process_dataset(self, dataset_arg: object) -> object: - calls["dataset"] = dataset_arg - return output_dataset - - monkeypatch.setattr( - executor, - "_dataset_to_tasks", - Mock(side_effect=AssertionError("centralized stages should not materialize Ray Data datasets")), - ) - monkeypatch.setattr( - executor, - "_tasks_to_dataset", - Mock(side_effect=AssertionError("centralized stages should not rebuild Ray Data datasets")), - ) - monkeypatch.setattr("nemo_curator.backends.ray_data.executor.RayDataStageAdapter", FakeRayDataStageAdapter) - - out = executor._process_stage_dataset(stage, input_dataset) - - assert out is output_dataset - assert calls == {"stage": stage, "dataset": input_dataset} - - -def test_ray_data_executor_keeps_noncentral_stage_in_ray_data(monkeypatch: pytest.MonkeyPatch) -> None: - executor = RayDataExecutor(ignore_head_node=True) - stage = _PreplannedEchoStage() - input_dataset = object() - output_dataset = object() - calls: dict[str, object] = {} - - class FakeRayDataStageAdapter: - def __init__(self, stage_arg: ProcessingStage) -> None: - calls["stage"] = stage_arg - - def process_dataset(self, dataset_arg: object) -> object: - calls["dataset"] = dataset_arg - return output_dataset - - monkeypatch.setattr("nemo_curator.backends.ray_data.executor.RayDataStageAdapter", FakeRayDataStageAdapter) - - out = executor._process_stage_dataset(stage, input_dataset) - - assert out is output_dataset - assert calls == {"stage": stage, "dataset": input_dataset} diff --git a/tests/backends/test_task_id_postprocess.py b/tests/backends/test_task_id_postprocess.py index d685982a5b..6c9a37bbf8 100644 --- a/tests/backends/test_task_id_postprocess.py +++ b/tests/backends/test_task_id_postprocess.py @@ -18,8 +18,8 @@ end-to-end against real backends in tests/backends/test_integration.py (``test_task_ids``). This file keeps only the cases that are awkward or impossible to trigger through a real pipeline: filter-``None`` positional -alignment, the ambiguous-cardinality ``"r"``-uuid fallback, preservation of - framework overwrite semantics, and source content-id selection.""" +alignment, the ambiguous-cardinality ``"r"``-uuid fallback, in-place +re-derivation, and source content-id vs. positional-index selection.""" from dataclasses import dataclass @@ -105,6 +105,8 @@ def test_filter_stage_keeps_positional_alignment(self) -> None: assert c2.task_id == "0_2_0" # child of p2, not p1 def test_in_place_return_is_reassigned(self) -> None: + # A 1:1 stage that returns its input unchanged still gets a fresh + # segment appended (ids are re-derived at each stage boundary). t = _task("0_5") out = _assign([t], [t]) assert out == [t] @@ -201,8 +203,8 @@ def test_dropped_generic_terminal_row_is_preserved_as_tombstone(self) -> None: class TestSourceStage: def test_uses_content_id_rooted_at_input(self) -> None: - # FileGroupTask.get_deterministic_id() hashes its files; output with no - # content ids are rooted at the framework EmptyTask id. + # FileGroupTask.get_deterministic_id() hashes its files; the source + # output is rooted at the EmptyTask input id "0" → "0_". empty = EmptyTask(dataset_name="empty", data=None) a = FileGroupTask(dataset_name="d", data=["a.parquet"]) b = FileGroupTask(dataset_name="d", data=["b.parquet"]) diff --git a/tests/backends/test_utils.py b/tests/backends/test_utils.py index 8bde59fc32..71aa8ce4b7 100644 --- a/tests/backends/test_utils.py +++ b/tests/backends/test_utils.py @@ -65,9 +65,11 @@ def test_merge_nested_dicts(self): result = merge_executor_configs(base, override) + # Check that nested dicts are merged assert result["runtime_env"]["env_vars"]["A"] == "1" assert result["runtime_env"]["env_vars"]["B"] == "3" assert result["runtime_env"]["env_vars"]["C"] == "4" + # Check that other keys are preserved assert result["runtime_env"]["pip"] == ["package1"] assert result["runtime_env"]["working_dir"] == "." assert result["other_config"] == "value1" @@ -128,9 +130,11 @@ def setup_on_node( stage1 = MockStage1() stage2 = MockStage1().with_(name="mock_stage_2", resources=Resources(cpus=0.5, gpus=0.0)) + # Test execute_setup_on_node([stage1, stage2]) - # Verify NodeInfo / WorkerMetadata were passed via the per-call files. + # Check the files written to the temp directory + # Verify that NodeInfo and WorkerMetadata were passed correctly for stage_name in ["mock_stage_1", "mock_stage_2"]: stage_files = list(tmp_path.glob(f"{stage_name}_*.txt")) assert len(stage_files) == len(ray.nodes()), ( @@ -149,7 +153,7 @@ def setup_on_node( f"Expected node IDs to be the same as the Ray nodes, got {node_ids}" ) - # Log records starting with "Executing setup on node" and ending with "for 2 stages". + # Check that there are exactly two log records that start with "Executing setup on node" and end with "for 2 stages" matching_logs = [ record.message for record in caplog.records @@ -224,6 +228,7 @@ def setup_on_node( stage = MockStage1() + # Test with ignore_head_node=True execute_setup_on_node([stage], ignore_head_node=True) # Verify the cache variable is set directly (not using the lazy function) diff --git a/tests/backends/test_xenna_executor.py b/tests/backends/test_xenna_executor.py index 89885e52d3..4d06ba656a 100644 --- a/tests/backends/test_xenna_executor.py +++ b/tests/backends/test_xenna_executor.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +from cosmos_xenna.utils.verbosity import VerbosityLevel from nemo_curator.backends.xenna import executor as xenna_executor from nemo_curator.backends.xenna.executor import XennaExecutor @@ -43,6 +44,19 @@ def process(self, task: EmptyTask) -> EmptyTask: return task +def test_xenna_verbosity_none_uses_default() -> None: + executor = XennaExecutor(config={"actor_pool_verbosity_level": None}) + + assert executor._get_verbosity_config("actor_pool_verbosity_level") is VerbosityLevel.INFO + + +def test_xenna_verbosity_bad_string_has_helpful_error() -> None: + executor = XennaExecutor(config={"actor_pool_verbosity_level": "loud"}) + + with pytest.raises(ValueError, match="Invalid Xenna verbosity config actor_pool_verbosity_level='loud'"): + executor._get_verbosity_config("actor_pool_verbosity_level") + + def test_xenna_executor_uses_stage_num_workers_when_xenna_spec_has_no_worker_sizing( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/tests/backends/xenna/__init__.py b/tests/backends/xenna/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/tests/backends/xenna/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/backends/xenna/test_executor.py b/tests/backends/xenna/test_executor.py deleted file mode 100644 index 96c9ee6450..0000000000 --- a/tests/backends/xenna/test_executor.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -import pytest -from cosmos_xenna.utils.verbosity import VerbosityLevel - -from nemo_curator.backends.xenna.executor import XennaExecutor -from nemo_curator.stages.audio.common import ManifestWriterStage -from nemo_curator.stages.audio.inference.asr.stage import ASRStage -from nemo_curator.stages.audio.inference.batch_policy import BatchPolicy -from nemo_curator.stages.audio.io.audio_file_reader import AudioFileReaderStage -from nemo_curator.stages.base import ProcessingStage -from nemo_curator.stages.resources import Resources -from nemo_curator.tasks import AudioTask, Task - - -class _PassthroughStage(ProcessingStage[AudioTask, AudioTask]): - name = "passthrough" - resources = Resources(cpus=1.0) - batch_size = 2 - - def process(self, task: AudioTask) -> AudioTask: - return task - - -class _CentralizedStage(_PassthroughStage): - name = "centralized" - - def __init__(self) -> None: - self.batch_policy = BatchPolicy( - buckets_sec=[0, 30, 1200], - max_items_per_batch_by_bucket=[2, 1, 1], - max_audio_sec_per_batch=None, - ) - - def build_prebucketed_tasks(self, tasks: list[AudioTask]) -> list[AudioTask]: - return list(tasks) - - def scheduler_task_cost(self, task: AudioTask) -> float: - return float(task.data.get("duration", 0.0)) - - def assemble_prebucketed_task_results( - self, - _tasks: list[AudioTask], - processed_tasks: list[AudioTask], - ) -> list[AudioTask]: - return processed_tasks - - -class _WorkerSizedStage(_PassthroughStage): - name = "worker_sized" - - def __init__(self, workers: int | None = None, stage_spec: dict[str, Any] | None = None) -> None: - self._workers = workers - self._stage_spec = stage_spec or {} - - def num_workers(self) -> int | None: - return self._workers - - def xenna_stage_spec(self) -> dict[str, Any]: - return dict(self._stage_spec) - - -def test_xenna_executor_keeps_centralized_stage_inside_one_pipeline(monkeypatch) -> None: # noqa: ANN001 - executor = XennaExecutor() - stages: list[ProcessingStage[Any, Any]] = [ - _PassthroughStage(), - _CentralizedStage(), - _PassthroughStage(), - ] - initial_tasks = [AudioTask(data={"duration": 5.0})] - calls: list[tuple[list[ProcessingStage[Any, Any]], list[Task]]] = [] - - def fake_run_xenna_pipeline( - stages_arg: list[ProcessingStage[Any, Any]], - initial_tasks_arg: list[Task], - ) -> list[Task]: - calls.append((stages_arg, initial_tasks_arg)) - return initial_tasks_arg - - monkeypatch.setattr(executor, "_run_xenna_pipeline", fake_run_xenna_pipeline) - - out = executor.execute(stages, initial_tasks) - - assert out == initial_tasks - assert calls == [(stages, initial_tasks)] - - -def test_xenna_verbosity_none_uses_default() -> None: - executor = XennaExecutor(config={"actor_pool_verbosity_level": None}) - - assert executor._get_verbosity_config("actor_pool_verbosity_level") is VerbosityLevel.INFO - - -def test_xenna_verbosity_bad_string_has_helpful_error() -> None: - executor = XennaExecutor(config={"actor_pool_verbosity_level": "loud"}) - - with pytest.raises(ValueError, match="Invalid Xenna verbosity config actor_pool_verbosity_level='loud'"): - executor._get_verbosity_config("actor_pool_verbosity_level") - - -def test_xenna_stage_spec_falls_back_to_stage_num_workers() -> None: - stage_spec = XennaExecutor()._build_stage_spec(_WorkerSizedStage(workers=3)) - - assert stage_spec.num_workers == 3 - assert stage_spec.num_workers_per_node is None - - -def test_real_audio_stages_use_main_worker_sizing_contract(tmp_path) -> None: # noqa: ANN001 - executor = XennaExecutor() - - asr_spec = executor._build_stage_spec( - ASRStage( - adapter_target="tests.fake.Adapter", - model_id="fake-model", - xenna_num_workers=2, - ) - ) - reader_spec = executor._build_stage_spec(AudioFileReaderStage(xenna_num_workers=3)) - writer_spec = executor._build_stage_spec(ManifestWriterStage(output_path=str(tmp_path / "out.jsonl"))) - - assert asr_spec.num_workers == 2 - assert asr_spec.num_workers_per_node is None - assert reader_spec.num_workers == 3 - assert reader_spec.num_workers_per_node is None - assert writer_spec.num_workers == 1 - assert writer_spec.num_workers_per_node is None - - -def test_xenna_stage_spec_num_workers_is_rejected() -> None: - with pytest.raises(ValueError, match=r"Use num_workers\(\) instead"): - XennaExecutor()._build_stage_spec(_WorkerSizedStage(stage_spec={"num_workers": 4})) - - -def test_xenna_num_workers_per_node_is_rejected_with_stage_num_workers() -> None: - with pytest.raises(ValueError, match=r"num_workers\(\).*num_workers_per_node"): - XennaExecutor()._build_stage_spec(_WorkerSizedStage(workers=3, stage_spec={"num_workers_per_node": 2})) - - -def test_xenna_num_workers_per_node_is_rejected_with_legacy_num_workers() -> None: - stage = _WorkerSizedStage(stage_spec={"num_workers": 4, "num_workers_per_node": 2}) - - with pytest.raises(ValueError, match=r"Use num_workers\(\) instead"): - XennaExecutor()._build_stage_spec(stage) - - -def test_xenna_rejects_conflicting_cluster_worker_counts() -> None: - stage = _WorkerSizedStage(workers=3, stage_spec={"num_workers": 4}) - - with pytest.raises(ValueError, match=r"Use num_workers\(\) instead"): - XennaExecutor()._build_stage_spec(stage) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 66194ed148..33a3e9ded4 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -183,7 +183,8 @@ def test_single_stage_composite_preserves_main_behavior(self) -> None: class TestRootTaskIds: - """``assign_root_task_ids`` follows the framework-owned main contract.""" + """``assign_root_task_ids`` roots user-provided initial tasks under the + implicit ``EmptyTask`` root id ``"0"``.""" def test_empty_task_id_is_zero(self) -> None: assert EmptyTask().task_id == "0" @@ -199,11 +200,13 @@ def test_rewrites_existing_internal_task_ids(self) -> None: def test_roots_user_tasks_at_zero(self) -> None: tasks = [_SimpleTask(dataset_name="d", data=[1]) for _ in range(3)] assign_root_task_ids(tasks) + # User-provided initial tasks are children of root "0", by position. assert [t.task_id for t in tasks] == ["0_0", "0_1", "0_2"] def test_skips_empty_tasks(self) -> None: et = EmptyTask(dataset_name="d", data=None) real = _SimpleTask(dataset_name="d", data=[1]) assign_root_task_ids([et, real]) + # EmptyTask stays "0"; the real task is rooted by its position. assert et.task_id == "0" assert real.task_id == "0_1"