Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional, Sequence, Union
from typing import Any, List, Optional, Sequence, Tuple, Union

import lhotse
import numpy as np
import torch
from lhotse import CutSet, RecordingSet
from lhotse.cut import Cut
from lhotse.dataset import (
ClippingTransform,
Compress,
CutConcatenate,
DynamicBucketingSampler,
DynamicCutSampler,
IterableDatasetWrapper,
LowpassUsingResampling,
ReverbWithImpulseResponse,
RoundRobinSampler,
ZipSampler,
Expand Down Expand Up @@ -176,6 +180,27 @@ class LhotseDataLoadingConfig:
# f. Padding to a minimum duration. Examples shorter than this will be padded, others are unaffected.
pad_min_duration: Optional[float] = None
pad_direction: str = "right" # "right" | "left" | "both" | "random"
# g. Bandwidth limitation via back-and-forth resampling
lowpass_enabled: bool = False
lowpass_frequencies_interval: Tuple[float, float] = (3500.0, 8000.0)
lowpass_prob: float = 0.5
# h. Lossy compression augmentation (opus, mp3, vorbis, gsm)
# implemented via soundfile, so compression level is specified via number in [0.0, 1.0]
# 0.0 denotes the highest bitrate and 1.0 denotes the lowest bitrate for a given codec
# overall, parameters mirror lhotse interface
compression_enabled: bool = False
compression_prob: float = 0.5
compression_level_interval: Tuple[float, float] = (0.8, 0.99)
compression_codecs: Tuple[str] = ("opus",)
compression_codec_weights: Optional[List[float]] = None
compression_enable_for_custom_fields: bool = False
# i. Clipping/saturation augmentation
clipping_enabled: bool = False
clipping_gain_db: Tuple[float, float] = (0.0, 24.0)
clipping_normalize: bool = True
clipping_oversampling: Optional[int] = 2
clipping_prob_hard: float = 0.5
clipping_prob: float = 0.5

# 5. Other Lhotse options.
text_field: str = "text" # key to read the transcript from
Expand Down Expand Up @@ -622,6 +647,31 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No
if config.concatenate_merge_supervisions:
sampler = sampler.map(_merge_supervisions)

if config.lowpass_enabled:
if lhotse.get_current_resampling_backend() != "libsox":
logging.warning(
"Lowpass augmentation works best with libsox backend. Consider setting resamping backend in Lhotse to libsox."
)
sampler = sampler.map(
LowpassUsingResampling(
frequencies_interval=OmegaConf.to_container(config.lowpass_frequencies_interval),
p=config.lowpass_prob,
seed=config.shard_seed,
)
)

if config.clipping_enabled:
sampler = sampler.map(
ClippingTransform(
gain_db=OmegaConf.to_container(config.clipping_gain_db),
normalize=config.clipping_normalize,
p=config.clipping_prob,
p_hard=config.clipping_prob_hard,
oversampling=config.clipping_oversampling,
seed=config.shard_seed,
)
)

if config.rir_enabled:
sampler = sampler.map(
ReverbWithImpulseResponse(
Expand All @@ -631,6 +681,22 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No
)
)

if config.compression_enabled:
sampler = sampler.map(
Compress(
codecs=OmegaConf.to_container(config.compression_codecs),
p=config.compression_prob,
compression_level=OmegaConf.to_container(config.compression_level_interval),
codec_weights=(
OmegaConf.to_container(config.compression_codec_weights)
if config.compression_codec_weights
else config.compression_codec_weights
),
compress_custom_fields=config.compression_enable_for_custom_fields,
seed=config.shard_seed,
)
)

return sampler, use_iterable_dataset


Expand Down
192 changes: 191 additions & 1 deletion tests/collections/common/test_lhotse_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from io import BytesIO
from itertools import islice
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import lhotse
import numpy as np
Expand Down Expand Up @@ -2135,6 +2135,196 @@
assert len(b) == 2


@pytest.mark.parametrize("clipping_prob_hard", [0.0, 1.0])
@pytest.mark.parametrize("clipping_oversampling", [None, 1, 2])
def test_dataloader_with_clipping_lhotse_jsonl(
    cutset_path: Path, clipping_prob_hard: float, clipping_oversampling: Optional[int]
):
    """Check that enabling clipping leaves the expected transforms on the first two cuts.

    When ``clipping_oversampling`` is set, the trailing transforms are expected to be
    ``Resample, Clipping, Resample``; when it is ``None`` only a trailing ``Clipping``
    is expected. Audio of every cut in the batch is loaded at the end so that the
    transform chain is actually executed.
    """
    from lhotse.augmentation import Clipping, Resample

    config = OmegaConf.create(
        {
            "cuts_path": str(cutset_path),
            "clipping_enabled": True,
            "clipping_gain_db": (0.0, 24.0),
            "clipping_prob": 1.0,
            "clipping_prob_hard": clipping_prob_hard,
            "clipping_oversampling": clipping_oversampling,
            "batch_size": 2,
            "seed": 0,
            "shard_seed": 0,
        }
    )
    dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())

    batch = next(iter(dl))
    assert isinstance(batch, CutSet)
    assert len(batch) == 2

    # Expected types of the last transforms, listed from the oldest to the most recent one.
    if clipping_oversampling is not None:
        expected_tail = (Resample, Clipping, Resample)
    else:
        expected_tail = (Clipping,)

    for idx in (0, 1):
        cut = batch[idx]
        assert isinstance(cut, MonoCut)
        # Negative offsets walk the tail of the transform list: -len, ..., -1.
        for offset, transform_type in enumerate(expected_tail, start=-len(expected_tail)):
            assert isinstance(cut.recording.transforms[offset], transform_type)

    # Loading the audio makes sure the recorded transforms can be applied.
    for cut in batch:
        cut.load_audio()


def test_dataloader_with_compression_lossy_lhotse_jsonl(cutset_path: Path):
    """Check that compression with opus/mp3/vorbis leaves ``Compress`` as the last transform.

    Only the first two cuts of the batch are inspected; audio of all four cuts is
    loaded afterwards so that the transform chain is actually executed.
    """
    from lhotse.augmentation import Compress

    config = OmegaConf.create(
        {
            "cuts_path": str(cutset_path),
            "compression_enabled": True,
            "compression_codecs": ["opus", "mp3", "vorbis"],
            "compression_prob": 1.0,
            "batch_size": 4,
            "seed": 0,
            "shard_seed": 0,
        }
    )
    dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())

    batch = next(iter(dl))
    assert isinstance(batch, CutSet)
    assert len(batch) == 4

    for idx in (0, 1):
        cut = batch[idx]
        assert isinstance(cut, MonoCut)
        assert isinstance(cut.recording.transforms[-1], Compress)

    # Loading the audio makes sure the recorded transforms can be applied.
    for cut in batch:
        cut.load_audio()


def test_dataloader_with_compression_gsm_lhotse_jsonl(cutset_path: Path):
    """Check that GSM compression leaves ``Resample, Compress, Resample`` as the trailing transforms.

    Only the first cut of the batch is inspected; audio of all four cuts is loaded
    afterwards so that the transform chain is actually executed.
    """
    from lhotse.augmentation import Compress, Resample

    config = OmegaConf.create(
        {
            "cuts_path": str(cutset_path),
            "compression_enabled": True,
            "compression_codecs": ["gsm"],
            "compression_prob": 1.0,
            "batch_size": 4,
            "seed": 0,
            "shard_seed": 0,
        }
    )
    dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())

    batch = next(iter(dl))
    assert isinstance(batch, CutSet)
    assert len(batch) == 4

    cut = batch[0]
    assert isinstance(cut, MonoCut)
    # Index explicitly (rather than slicing) so a too-short transform list still fails loudly.
    for offset, transform_type in zip((-3, -2, -1), (Resample, Compress, Resample)):
        assert isinstance(cut.recording.transforms[offset], transform_type)

    # Loading the audio makes sure the recorded transforms can be applied.
    for cut in batch:
        cut.load_audio()


def test_dataloader_with_lowpass_using_resampling_lhotse_jsonl(cutset_path: Path):
    """Check that the lowpass augmentation leaves two trailing ``Resample`` transforms.

    Only the first two cuts of the batch are inspected; audio of all four cuts is
    loaded afterwards so that the transform chain is actually executed.
    """
    from lhotse.augmentation import Resample

    config = OmegaConf.create(
        {
            "cuts_path": str(cutset_path),
            "lowpass_enabled": True,
            "lowpass_frequencies_interval": [3500.0, 4000.0],
            "lowpass_prob": 1.0,
            "batch_size": 4,
            "seed": 0,
            "shard_seed": 0,
        }
    )
    dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())

    batch = next(iter(dl))
    assert isinstance(batch, CutSet)
    assert len(batch) == 4

    for idx in (0, 1):
        cut = batch[idx]
        assert isinstance(cut, MonoCut)
        # Index explicitly (rather than slicing) so a too-short transform list still fails loudly.
        for offset in (-2, -1):
            assert isinstance(cut.recording.transforms[offset], Resample)

    # Loading the audio makes sure the recorded transforms can be applied.
    for cut in batch:
        cut.load_audio()


def test_dataloader_with_multiple_augmentations_lhotse_jsonl(cutset_path: Path):
    """Check that noise mixing, RIR, lowpass and GSM compression can be enabled together.

    With noise mixing enabled the first cut of the batch is expected to be a ``MixedCut``;
    each of its tracks is expected to end with ``Resample, Compress, Resample`` transforms.
    Audio of every cut in the batch is loaded at the end so that the whole augmentation
    chain is actually executed.
    """
    from lhotse.augmentation import Compress, Resample, ReverbWithImpulseResponse

    config = OmegaConf.create(
        {
            "cuts_path": str(cutset_path),
            "noise_path": str(cutset_path),
            "noise_mix_prob": 1.0,
            "noise_snr": [-5.0, 5.0],
            "rir_enabled": True,
            "rir_prob": 1.0,
            "lowpass_enabled": True,
            "lowpass_frequencies_interval": [3500.0, 4000.0],
            "lowpass_prob": 1.0,
            "compression_enabled": True,
            "compression_codecs": ["gsm"],
            "compression_prob": 1.0,
            "batch_size": 4,
            "seed": 0,
            "shard_seed": 0,
        }
    )
    dl = get_lhotse_dataloader_from_config(
        config=config,
        global_rank=0,
        world_size=1,
        dataset=Identity(),
    )
    batch = next(iter(dl))
    assert isinstance(batch, CutSet)
    assert len(batch) == 4
    cut = batch[0]
    assert isinstance(cut, MixedCut)
    for track in cut.tracks:
        assert isinstance(track.cut.recording.transforms[-3], Resample)
        assert isinstance(track.cut.recording.transforms[-2], Compress)
        assert isinstance(track.cut.recording.transforms[-1], Resample)
    # Only the side effect of loading matters here (it runs the transform chain), so the
    # result is intentionally not bound to a name.
    for cut in batch:
        cut.load_audio()


def test_dataloader_2d_bucketing(nemo_tarred_manifest_path_multi: tuple[str, str], en_es_tokenizer):
json_mft, tar_mft = nemo_tarred_manifest_path_multi
config = OmegaConf.create(
Expand Down
Loading