@@ -357,6 +357,49 @@ def __init__(self, config):
357
357
assert set (data_out ) == set (range (num_seqs ))
358
358
359
359
360
+ def test_horovod_partition_combined_dataset_sampling ():
361
+ num_seqs = 10
362
+ sampling_size = 12
363
+ dummy_data = [{"data" : numpy .array ([i ])} for i in range (num_seqs )]
364
+ from returnn .datasets .meta import CombinedDataset
365
+ dataset = MapDatasetWrapper (FromListDataset (data_list = dummy_data ))
366
+ combined_dataset = CombinedDataset (
367
+ datasets = {"dataset" : dataset }, data_map = {("dataset" , "data" ): "data" }, sampling_sizes = {"dataset" : sampling_size },
368
+ data_dims = {"data" : (1 , 1 )}, seq_ordering = "random" )
369
+ from returnn .config import get_global_config
370
+ global_config = get_global_config (auto_create = True )
371
+ global_config .set ("use_horovod" , True )
372
+ global_config .set ("horovod_dataset_distribution" , "partition" )
373
+ from returnn .tf import horovod
374
+
375
+ horovod_size = 3
376
+ data_out = []
377
+ for rank in range (horovod_size ):
378
+ # Simulating a multi-gpu setup.
379
+ def get_dummy_ctx (config = None ):
380
+ class DummyHorovodContext (horovod .HorovodContext ):
381
+ def __init__ (self , config ):
382
+ self ._rank = rank
383
+ self ._size = horovod_size
384
+ self ._config = config
385
+ return DummyHorovodContext (config or global_config )
386
+ horovod .get_ctx = get_dummy_ctx
387
+ combined_dataset .init_seq_order (epoch = None )
388
+ seq_idx = 0
389
+ while combined_dataset .is_less_than_num_seqs (seq_idx ):
390
+ combined_dataset .load_seqs (seq_idx , seq_idx + 1 )
391
+ data = combined_dataset .get_data (seq_idx , "data" )
392
+ data_out .extend (data .tolist ())
393
+ seq_idx += 1
394
+ # We sample 12 values from range(10) "in order", so 0 and 1 should appear twice, all other values once. This e.g.
395
+ # would not be the case if the sub-dataset is partitioned before sampling,
396
+ # see Dataset.disable_horovod_partition.
397
+ assert len (data_out ) == sampling_size
398
+ assert set (data_out ) == set (range (num_seqs ))
399
+ assert data_out .count (0 ) == 2
400
+ assert data_out .count (1 ) == 2
401
+
402
+
360
403
if __name__ == "__main__" :
361
404
better_exchook .install ()
362
405
if len (sys .argv ) <= 1 :
0 commit comments