@@ -365,7 +365,9 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
365
365
dist_print (f"Initialized default NCCL process group with { WORLD_SIZE } GPUs" )
366
366
367
367
# Initialize the Transformer Engine layer with overlap
368
- args , kwargs , input_shape = _get_layer_args (opts , nccl_world , opts .tp , num_layers = opts .num_layers )
368
+ args , kwargs , input_shape = _get_layer_args (
369
+ opts , nccl_world , opts .tp , num_layers = opts .num_layers
370
+ )
369
371
# Initialize userbuffers
370
372
ub_cfgs = None
371
373
if opts .overlap_rs_dgrad :
@@ -391,7 +393,9 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
391
393
dist .barrier ()
392
394
393
395
# Initialize the reference model and copy all parameters
394
- ref_args , ref_kwargs , _ = _get_layer_args (opts , nccl_world , opts .tp , num_layers = opts .num_layers , reference = True )
396
+ ref_args , ref_kwargs , _ = _get_layer_args (
397
+ opts , nccl_world , opts .tp , num_layers = opts .num_layers , reference = True
398
+ )
395
399
with te .fp8_model_init (enabled = opts .fp8_init ):
396
400
ref_model = multi_module_model (opts .layer_type , opts .num_layers , * ref_args , ** ref_kwargs )
397
401
dist_print ("Initialized reference model..." , debug = True )
0 commit comments