Skip to content

Commit ace0aa9

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f686a49 commit ace0aa9

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/pytorch/distributed/run_layer_with_overlap.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,9 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
365365
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
366366

367367
# Initialize the Transformer Engine layer with overlap
368-
args, kwargs, input_shape = _get_layer_args(opts, nccl_world, opts.tp, num_layers=opts.num_layers)
368+
args, kwargs, input_shape = _get_layer_args(
369+
opts, nccl_world, opts.tp, num_layers=opts.num_layers
370+
)
369371
# Intialize userbuffers
370372
ub_cfgs = None
371373
if opts.overlap_rs_dgrad:
@@ -391,7 +393,9 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
391393
dist.barrier()
392394

393395
# Initialize the reference model and copy all parameters
394-
ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True)
396+
ref_args, ref_kwargs, _ = _get_layer_args(
397+
opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True
398+
)
395399
with te.fp8_model_init(enabled=opts.fp8_init):
396400
ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs)
397401
dist_print("Initialized reference model...", debug=True)

0 commit comments

Comments
 (0)