@@ -758,12 +758,15 @@ def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta:
 
 class Tracer(torch.fx.Tracer):
     """
-    Disables proxying buffers during tracing. Ideally, proxying buffers would be
-    disabled, but some models are currently mutating buffer values, which causes errors
-    during tracing. If those models can be rewritten to not do that, we can likely
-    remove this line.
+    The Tracer class used in `_rewrite_model`, treating all ShardedModules and
+    ShardedModule-free modules as leaf modules. A module that is not itself a
+    ShardedModule but contains a ShardedModule is NOT considered a leaf module.
     """
 
+    # Disables proxying buffers during tracing. Ideally, proxying buffers would be
+    # disabled, but some models are currently mutating buffer values, which causes errors
+    # during tracing. If those models can be rewritten to not do that, we can likely
+    # remove this line.
     proxy_buffer_attributes = False
 
     def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
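For context on the leaf-module behavior described in the new docstring, here is a hedged sketch of how a `torch.fx.Tracer` subclass can keep tracing from descending into selected modules by overriding `is_leaf_module`. The class name `LeafOnlyTracer` and its exact logic are illustrative assumptions, not the torchrec implementation.

```python
# Illustrative sketch only: a torch.fx Tracer that treats caller-specified
# module FQNs as leaves, so tracing does not recurse into them.
from typing import List, Optional

import torch
import torch.fx


class LeafOnlyTracer(torch.fx.Tracer):
    # Keep buffers un-proxied, matching the rationale in the comment above.
    proxy_buffer_attributes = False

    def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
        super().__init__()
        self._leaf_modules: List[str] = leaf_modules or []

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # A module whose fully-qualified name was registered is a leaf;
        # everything else falls back to torch.fx's default policy.
        if module_qualified_name in self._leaf_modules:
            return True
        return super().is_leaf_module(m, module_qualified_name)
```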
@@ -1344,6 +1347,12 @@ def _get_leaf_module_names_helper(
     path: str,
     leaf_module_names: Set[str],
 ) -> bool:
+    """
+    Recursive function that returns True if any of the sub-modules is a ShardedModule.
+    It also adds the fqns of sub-modules that do not contain any ShardedModule to
+    `leaf_module_names`, unless a sub-module is marked `_is_pytorch_fx_traceable = True`,
+    which indicates that the ShardedModule-free module should NOT be treated as a leaf module.
+    """
     sharded_children = set()
     for name, child in model.named_children():
         curr_path = path + name
@@ -1358,6 +1367,7 @@ def _get_leaf_module_names_helper(
         if child_sharded:
             sharded_children.add(name)
 
+    # only do this for a hybrid module (i.e., one that has a sharded child)
     if len(sharded_children) > 0:
         for name, child in model.named_children():
             if name in sharded_children:
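To make the recursion concrete, here is a self-contained, simplified re-statement of the logic described in the docstring above. `MySharded` and `collect_leaf_names` are hypothetical stand-ins for ShardedModule and `_get_leaf_module_names_helper`, so details may differ from the real implementation.

```python
from typing import Set

import torch


class MySharded(torch.nn.Module):
    """Stand-in marker for a sharded module."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def collect_leaf_names(model: torch.nn.Module, path: str, leaf_names: Set[str]) -> bool:
    """Returns True if `model` contains a MySharded module anywhere below it."""
    sharded_children = set()
    for name, child in model.named_children():
        if isinstance(child, MySharded):
            sharded_children.add(name)
            continue
        if collect_leaf_names(child, path + name + ".", leaf_names):
            sharded_children.add(name)

    # Only a "hybrid" module (one with a sharded child) registers its
    # ShardedModule-free children as leaf modules.
    if len(sharded_children) > 0:
        for name, child in model.named_children():
            if name in sharded_children:
                continue
            if getattr(child, "_is_pytorch_fx_traceable", False):
                continue  # explicitly traceable: do not treat as a leaf
            leaf_names.add(path + name)
    return len(sharded_children) > 0


class Toy(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(8, 8)
        self.sparse = MySharded()
        self.over = torch.nn.Linear(8, 1)


leaves: Set[str] = set()
collect_leaf_names(Toy(), "", leaves)
print(sorted(leaves))  # ['dense', 'over'] -- 'sparse' is sharded, so its siblings become leaves
```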
@@ -1371,8 +1381,9 @@ def _get_leaf_module_names_helper(
 def _get_leaf_module_names(model: torch.nn.Module) -> List[str]:
     """
     Returns a list of top level modules to be used as leaf modules for FX tracing.
-    This is a shallow FX trace that only goes the minimum depth required to pipeline
-    the model unless child modules are explicitly tagged as `_is_pytorch_fx_traceable`.
+    This is a shallow FX trace that only goes the minimum depth required to pipeline the model.
+    Any sub-module that does not contain a ShardedModule is considered a leaf module,
+    unless it is explicitly tagged with `_is_pytorch_fx_traceable = True`.
     """
 
     leaf_module_names: Set[str] = set()
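A hedged sketch of how these two pieces plausibly fit together inside `_rewrite_model`: the leaf names returned here are fed to the `Tracer` above so the trace stays shallow. The helper `shallow_trace` is illustrative only; the exact call sequence in `_rewrite_model` may differ.

```python
import torch
import torch.fx


def shallow_trace(model: torch.nn.Module) -> torch.fx.GraphModule:
    # `Tracer` and `_get_leaf_module_names` are the definitions from this file.
    leaf_names = _get_leaf_module_names(model)   # fqns of ShardedModule-free modules
    tracer = Tracer(leaf_modules=leaf_names)     # do not descend into those modules
    graph = tracer.trace(model)                  # torch.fx.Graph of the shallow trace
    return torch.fx.GraphModule(model, graph)
```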
@@ -1454,7 +1465,7 @@ def _pipeline_detach_model(
         setattr(model, postproc_mod.fqn, postproc_mod.postproc_module)
 
 
-# pyre-ignore[3]
+# pyre-ignore[3] Return type must be specified as type that does not contain `Any`
 def _rewrite_model(  # noqa C901
     model: torch.nn.Module,
     context: TForwardContext,
@@ -1471,32 +1482,50 @@ def _rewrite_model(  # noqa C901
     List[PipelinedPostproc],
     List[str],
 ]:
+    """
+    This is an important util function used by TorchRec's sparse-dist (and other) train pipelines.
+
+    The high-level idea of the sparse-dist train pipeline is to extract the forward calls of the
+    sharded modules (e.g., ShardedEBC, ShardedEC, etc.) from the model's forward call, so that the
+    pipeline can apply optimizations such as overlapping the comms (i.e., input_dist) with compute
+    (e.g., dense forward, embedding lookup, etc.). This "extraction of the sharded forward" is done
+    by this `_rewrite_model` util function.
+
+    Currently `_rewrite_model` uses an fx tracer to capture the graph of the sharded model and then
+    finds the "call_module" nodes that correspond to sharded modules.
+
+    Typically a ShardedModule (EBC, EC, etc.) takes a KJT as its only input; `_get_node_args` captures
+    how each input is produced (as a list of ArgInfo) so the pipelined forward can rebuild the arguments.
+    """
     input_model = model
-    # Get underlying nn.Module
+    # Get the underlying sharded model (nn.Module) from DistributedModelParallel,
+    # which is not wrapped in DDP, FSDP, DMP, or any other parallelism wrapper.
     if isinstance(model, DistributedModelParallel):
         model = model.module
 
     # Collect a list of sharded modules.
-    sharded_modules = {}
+    sharded_modules: Dict[str, ShardedModule] = {}  # fqn -> ShardedModule
     for name, m in model.named_modules():
         if isinstance(m, ShardedModule):
             sharded_modules[name] = m
 
-    # Trace a model.
+    # Trace the model. For more on FX tracing: https://pytorch.org/docs/stable/fx.html
     concrete_args = {}
+    """
+    concrete_args allows you to partially specialize your function, whether it's to remove
+    control flow or data structures.
+    """
+
+    # special handling of the placeholders: add meta/label info to the PH nodes
     if batch:
         if hasattr(batch, "to_proxy"):
-            # for some special models, it requires using "input"
-            # as the key for input
+            # some special models require using "inputs" as the key for the input
            # pyre-ignore[16]: Variable[In (bound to Pipelineable)] has no attribute to_proxy.
             concrete_args["inputs"] = copy.copy(batch).to_proxy()
         elif hasattr(batch, "to_proxy_tuple"):
-            # when the model is pre-fx traced or dynamo exported, the
-            # inputs are already flattened, and therefore we use
-            # tuple as concrete args that fx.trace will automatically
-            # match with the argument names.
-            # We pass in the model for the caller side to customize
-            # the batch
+            # when the model is pre-fx traced or dynamo exported, the inputs are already flattened,
+            # and therefore we use a tuple as concrete args that fx.trace will automatically match
+            # with the argument names. We pass in the model so the caller side can customize the batch
             # pyre-ignore[16]: Variable[In (bound to Pipelineable)] has no attribute to_proxy_tuple.
             concrete_args = batch.to_proxy_tuple(model)
 
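Since the docstring above leans on two torch.fx ideas, tracing with `concrete_args` and scanning the graph for `call_module` nodes, here is a small self-contained illustration. `ToySparse` and `ToyModel` are made-up stand-ins, not torchrec modules.

```python
import torch
import torch.fx


class ToySparse(torch.nn.Module):
    """Stand-in for a sharded module whose call_module node we want to find."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sparse = ToySparse()
        self.dense = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor, use_dense: bool = True) -> torch.Tensor:
        out = self.sparse(x)
        # Data-dependent control flow cannot be traced symbolically; fixing
        # `use_dense` via concrete_args specializes this branch away.
        if use_dense:
            out = self.dense(out)
        return out


model = ToyModel()
graph = torch.fx.Tracer().trace(model, concrete_args={"use_dense": True})

sparse_fqns = {name for name, m in model.named_modules() if isinstance(m, ToySparse)}
for node in graph.nodes:
    if node.op == "call_module" and node.target in sparse_fqns:
        print(f"found sharded-like call_module node: {node.target}, args={node.args}")
```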
@@ -1512,37 +1541,45 @@ def _rewrite_model(  # noqa C901
     non_pipelined_sharded_modules = []
 
     for node in graph.nodes:
-        if node.op == "call_module" and node.target in sharded_modules:
-            total_num_args = len(node.args) + len(node.kwargs)
-            if total_num_args == 0:
-                continue
-            arg_info_list, num_found = _get_node_args(
-                model,
-                node,
-                pipelined_postprocs,
+        # only work on call_module nodes that are also sharded modules
+        if node.op != "call_module" or node.target not in sharded_modules:
+            continue
+
+        total_num_args = len(node.args) + len(node.kwargs)
+        # only work on nodes with input(s); we don't expect a sharded module to have zero inputs
+        if total_num_args == 0:
+            logger.warning(f"Module '{node.target}' is a ShardedModule with zero input")
+            continue
+
+        # List[ArgInfo] for rebuilding the input arguments; num_found verifies none are missing
+        arg_info_list, num_found = _get_node_args(
+            model,
+            node,
+            pipelined_postprocs,
+            context,
+            pipeline_postproc,
+            default_stream=default_stream,
+            dist_stream=dist_stream,
+        )
+
+        if num_found == total_num_args:
+            logger.info(f"Module '{node.target}' will be pipelined")
+            child = sharded_modules[node.target]
+            original_forwards.append(child.forward)
+            # pyre-ignore[8] Incompatible attribute type
+            child.forward = pipelined_forward(
+                node.target,
+                arg_info_list,
+                child,
                 context,
-                pipeline_postproc,
-                default_stream=default_stream,
-                dist_stream=dist_stream,
+                dist_stream,
             )
-
-            if num_found == total_num_args:
-                logger.info(f"Module '{node.target}' will be pipelined")
-                child = sharded_modules[node.target]
-                original_forwards.append(child.forward)
-                child.forward = pipelined_forward(
-                    node.target,
-                    arg_info_list,
-                    child,
-                    context,
-                    dist_stream,
-                )
-                pipelined_forwards.append(child)
-            else:
-                logger.warning(
-                    f"Module '{node.target}'' will not be pipelined, due to input modifications"
-                )
-                non_pipelined_sharded_modules.append(node.target)
+            pipelined_forwards.append(child)
+        else:
+            logger.warning(
+                f"Module '{node.target}' will NOT be pipelined, due to input modifications"
+            )
+            non_pipelined_sharded_modules.append(node.target)
 
     # JIT script unsharded modules if applicable.
     if apply_jit:
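To close the loop on what `pipelined_forward` achieves here, the sketch below shows the bare forward-swapping mechanics: save the original forward, install a callable that ignores the traced arguments and uses input the pipeline staged ahead of time, and restore the original on detach. `CachedForward` is a hypothetical stand-in, not torchrec's `pipelined_forward`/`PipelinedForward` machinery.

```python
from typing import Any, Callable, Dict

import torch


class CachedForward:
    """Replaces a module's forward and reads the module input from a shared context."""

    def __init__(self, name: str, module: torch.nn.Module, context: Dict[str, Any]) -> None:
        self._name = name
        self._module = module
        self._context = context

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Ignore the arguments computed inside the model's forward and use the
        # input the pipeline prepared ahead of time (e.g., after input_dist).
        staged_input = self._context[self._name]
        # Call the class-level forward to bypass the swapped instance attribute.
        return type(self._module).forward(self._module, staged_input)


context: Dict[str, Any] = {}
module = torch.nn.Linear(4, 4)

original_forward: Callable[..., Any] = module.forward        # kept for detaching later
module.forward = CachedForward("sparse", module, context)    # swap in the pipelined forward

context["sparse"] = torch.randn(2, 4)   # the pipeline stages the real input ahead of time
out = module(torch.zeros(2, 4))         # the staged input is used, not the zeros
print(out.shape)                        # torch.Size([2, 4])

module.forward = original_forward       # "detach": restore the original forward
```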