Skip to content

Commit 69dacf4

Browse files
Added horovod partition test for CombinedDataset with sampling_sizes
1 parent 90da898 commit 69dacf4

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

tests/test_Dataset.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,49 @@ def __init__(self, config):
357357
assert set(data_out) == set(range(num_seqs))
358358

359359

360+
def test_horovod_partition_combined_dataset_sampling():
361+
num_seqs = 10
362+
sampling_size = 12
363+
dummy_data = [{"data": numpy.array([i])} for i in range(num_seqs)]
364+
from returnn.datasets.meta import CombinedDataset
365+
dataset = MapDatasetWrapper(FromListDataset(data_list=dummy_data))
366+
combined_dataset = CombinedDataset(
367+
datasets={"dataset": dataset}, data_map={("dataset", "data"): "data"}, sampling_sizes={"dataset": sampling_size},
368+
data_dims={"data": (1, 1)}, seq_ordering="random")
369+
from returnn.config import get_global_config
370+
global_config = get_global_config(auto_create=True)
371+
global_config.set("use_horovod", True)
372+
global_config.set("horovod_dataset_distribution", "partition")
373+
from returnn.tf import horovod
374+
375+
horovod_size = 3
376+
data_out = []
377+
for rank in range(horovod_size):
378+
# Simulating a multi-gpu setup.
379+
def get_dummy_ctx(config=None):
380+
class DummyHorovodContext(horovod.HorovodContext):
381+
def __init__(self, config):
382+
self._rank = rank
383+
self._size = horovod_size
384+
self._config = config
385+
return DummyHorovodContext(config or global_config)
386+
horovod.get_ctx = get_dummy_ctx
387+
combined_dataset.init_seq_order(epoch=None)
388+
seq_idx = 0
389+
while combined_dataset.is_less_than_num_seqs(seq_idx):
390+
combined_dataset.load_seqs(seq_idx, seq_idx + 1)
391+
data = combined_dataset.get_data(seq_idx, "data")
392+
data_out.extend(data.tolist())
393+
seq_idx += 1
394+
# We sample 12 values from range(10) "in order", so 0 and 1 should appear twice, all other values once. This e.g.
395+
# would not be the case if the sub-dataset is partitioned before sampling,
396+
# see Dataset.disable_horovod_partition.
397+
assert len(data_out) == sampling_size
398+
assert set(data_out) == set(range(num_seqs))
399+
assert data_out.count(0) == 2
400+
assert data_out.count(1) == 2
401+
402+
360403
if __name__ == "__main__":
361404
better_exchook.install()
362405
if len(sys.argv) <= 1:

0 commit comments

Comments
 (0)