
Commit f35befa

TroyGarden authored and facebook-github-bot committed
refactor train_pipeline utils (#2926)
Summary:
Pull Request resolved: #2926

# context
add comments explaining the `_rewrite_model` util function, which is a critical part of the sparse-dist train pipeline.

Reviewed By: aporialiao

Differential Revision: D73868739

fbshipit-source-id: 0318dd6f5d328a84c8bd0278f4bf3e032d9f242c
1 parent 97f8dea commit f35befa

2 files changed (+87, -50 lines)


torchrec/distributed/test_utils/test_input.py

Lines changed: 3 additions & 3 deletions
@@ -442,9 +442,9 @@ def _assemble_kjt(
         lengths = None
         if pin_memory:
             indices = indices.pin_memory()
-            lengths = lengths.pin_memory() if lengths else None
-            weights = weights.pin_memory() if weights else None
-            offsets = offsets.pin_memory() if offsets else None
+            lengths = lengths.pin_memory() if lengths is not None else None
+            weights = weights.pin_memory() if weights is not None else None
+            offsets = offsets.pin_memory() if offsets is not None else None
         return KeyedJaggedTensor(features, indices, weights, lengths, offsets)

     @staticmethod
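
Note on the change above: truthiness checks on optional torch.Tensor values are unsafe, which is what the `is not None` guard fixes. A minimal standalone sketch (plain PyTorch semantics, not TorchRec code) of the failure mode:

import torch

lengths = torch.tensor([2, 0, 1])

# `if lengths:` on a multi-element tensor raises:
#     RuntimeError: Boolean value of Tensor with more than one element is ambiguous
# so an explicit None check is the correct guard for an optional tensor.
if torch.cuda.is_available():  # Tensor.pin_memory() requires a CUDA-enabled build
    lengths = lengths.pin_memory() if lengths is not None else None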

torchrec/distributed/train_pipeline/utils.py

Lines changed: 84 additions & 47 deletions
@@ -758,12 +758,15 @@ def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta:

 class Tracer(torch.fx.Tracer):
     """
-    Disables proxying buffers during tracing. Ideally, proxying buffers would be
-    disabled, but some models are currently mutating buffer values, which causes errors
-    during tracing. If those models can be rewritten to not do that, we can likely
-    remove this line.
+    The Tracer class used in `_rewrite_model`. It treats all ShardedModules and
+    ShardedModule-free modules as leaf modules. A module that is not itself a
+    ShardedModule but contains a ShardedModule is NOT considered a leaf module.
     """

+    # Disables proxying buffers during tracing. Ideally, proxying buffers would be
+    # disabled, but some models are currently mutating buffer values, which causes errors
+    # during tracing. If those models can be rewritten to not do that, we can likely
+    # remove this line.
     proxy_buffer_attributes = False

     def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
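
To make the leaf-module behavior described in the new docstring concrete, here is a hypothetical, standalone sketch of the same pattern with plain torch.fx (the class name `LeafAwareTracer` is illustrative, not TorchRec's API): any module whose fully-qualified name is passed in `leaf_modules` stays opaque as a single `call_module` node instead of being traced through.

from typing import List, Optional

import torch
import torch.fx


class LeafAwareTracer(torch.fx.Tracer):
    # Do not proxy buffer reads/writes: some models mutate buffers in forward(),
    # which breaks symbolic tracing (same rationale as the comment in the diff).
    proxy_buffer_attributes = False

    def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
        super().__init__()
        self._leaf_modules: List[str] = leaf_modules or []

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # keep user-specified modules (e.g., sharded modules) as opaque call_module nodes
        if module_qualified_name in self._leaf_modules:
            return True
        return super().is_leaf_module(m, module_qualified_name)

A trace would then be taken with something like LeafAwareTracer(leaf_modules=[...]).trace(model).
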
@@ -1344,6 +1347,12 @@ def _get_leaf_module_names_helper(
     path: str,
     leaf_module_names: Set[str],
 ) -> bool:
+    """
+    Recursive helper that returns True if any of the sub-modules is a ShardedModule.
+    It also adds the FQNs of sub-modules that do not contain any ShardedModule to
+    `leaf_module_names`, unless a sub-module is marked `_is_pytorch_fx_traceable = True`,
+    which indicates that this ShardedModule-free module should NOT be treated as a leaf module.
+    """
     sharded_children = set()
     for name, child in model.named_children():
         curr_path = path + name
@@ -1358,6 +1367,7 @@ def _get_leaf_module_names_helper(
         if child_sharded:
             sharded_children.add(name)

+    # only do this for hybrid modules (i.e., modules with a sharded child)
     if len(sharded_children) > 0:
         for name, child in model.named_children():
             if name in sharded_children:
@@ -1371,8 +1381,9 @@ def _get_leaf_module_names_helper(
 def _get_leaf_module_names(model: torch.nn.Module) -> List[str]:
     """
     Returns a list of top level modules to be used as leaf modules for FX tracing.
-    This is a shallow FX trace that only goes the minimum depth required to pipeline
-    the model unless child modules are explicitly tagged as `_is_pytorch_fx_traceable`.
+    This is a shallow FX trace that only goes to the minimum depth required to pipeline
+    the model. Any sub-module that does not contain a ShardedModule is considered a
+    leaf module, unless it is explicitly tagged as `_is_pytorch_fx_traceable = True`.
     """

     leaf_module_names: Set[str] = set()
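
As a rough illustration of the recursion these docstrings describe, here is a hedged, self-contained sketch; the helper name `collect_leaf_names`, the `Marker` module, and the `is_sharded` predicate are stand-ins for TorchRec's internal helper and its ShardedModule check, and details of the real implementation may differ.

from typing import Callable, Set

import torch


def collect_leaf_names(
    module: torch.nn.Module,
    path: str,
    leaf_names: Set[str],
    is_sharded: Callable[[torch.nn.Module], bool],
) -> bool:
    # returns True if `module` is or contains a "sharded" sub-module
    sharded_children = set()
    for name, child in module.named_children():
        if is_sharded(child) or collect_leaf_names(child, f"{path}{name}.", leaf_names, is_sharded):
            sharded_children.add(name)

    # only hybrid modules (those with at least one sharded child) contribute leaves:
    # their sharded-free siblings become leaves unless tagged as fx traceable
    if sharded_children:
        for name, child in module.named_children():
            if name not in sharded_children and not getattr(child, "_is_pytorch_fx_traceable", False):
                leaf_names.add(f"{path}{name}")
    return len(sharded_children) > 0


class Marker(torch.nn.Module):
    def __init__(self, sharded: bool = False) -> None:
        super().__init__()
        self._sharded = sharded


root = torch.nn.Module()
root.sparse = Marker(sharded=True)  # pretend this is a ShardedModule
root.dense = Marker()
root.over = torch.nn.Module()
root.over.inner = Marker()

leaves: Set[str] = set()
collect_leaf_names(root, "", leaves, lambda m: getattr(m, "_sharded", False))
print(leaves)  # {'dense', 'over'}: sharded-free siblings of the sharded child become leaves
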
@@ -1454,7 +1465,7 @@ def _pipeline_detach_model(
         setattr(model, postproc_mod.fqn, postproc_mod.postproc_module)


-# pyre-ignore[3]
+# pyre-ignore[3] Return type must be specified as type that does not contain `Any`
 def _rewrite_model(  # noqa C901
     model: torch.nn.Module,
     context: TForwardContext,
@@ -1471,32 +1482,50 @@ def _rewrite_model(  # noqa C901
     List[PipelinedPostproc],
     List[str],
 ]:
+    """
+    This is a very important util function used by TorchRec's sparse-dist (and other) train pipelines.
+
+    The high-level idea of the sparse-dist train pipeline is to extract the forward calls of the sharded
+    modules (e.g., ShardedEBC, ShardedEC, etc.) from the model's forward call, so that the sparse-dist
+    pipeline can apply optimization techniques like overlapping the comms (i.e., input_dist) with
+    compute (e.g., dense forward, emb lookup, etc.). This "extraction of the sharded forward" is done by
+    this `_rewrite_model` util function.
+
+    Currently, `_rewrite_model` uses an fx tracer to capture the graph of the sharded model
+    and find the "call_module" nodes for the sharded modules.
+
+    Theoretically a ShardedModule (EBC, EC, etc.) takes a KJT as its only input; `_rewrite_model` calls
+    `_get_node_args` to work out how each input argument is produced, so it can be rebuilt by the pipelined forward.
+    """
     input_model = model
-    # Get underlying nn.Module
+    # Get the underlying model (nn.Module) from DistributedModelParallel,
+    # so it is not wrapped in DDP, FSDP, DMP, or any other parallelism wrapper.
     if isinstance(model, DistributedModelParallel):
         model = model.module

     # Collect a list of sharded modules.
-    sharded_modules = {}
+    sharded_modules: Dict[str, ShardedModule] = {}  # fqn -> ShardedModule
     for name, m in model.named_modules():
         if isinstance(m, ShardedModule):
             sharded_modules[name] = m

-    # Trace a model.
+    # Trace the model. For more, see: https://pytorch.org/docs/stable/fx.html
     concrete_args = {}
+    """
+    concrete_args allows you to partially specialize your function, whether it's to remove
+    control flow or data structures.
+    """
+
+    # special handling of the placeholder, adding meta/label to the PH node
     if batch:
         if hasattr(batch, "to_proxy"):
-            # for some special models, it requires using "input"
-            # as the key for input
+            # some special models require using "input" as the key for the input
             # pyre-ignore[16]: Variable[In (bound to Pipelineable)] has no attribute to_proxy.
             concrete_args["inputs"] = copy.copy(batch).to_proxy()
         elif hasattr(batch, "to_proxy_tuple"):
-            # when the model is pre-fx traced or dynamo exported, the
-            # inputs are already flattened, and therefore we use
-            # tuple as concrete args that fx.trace will automatically
-            # match with the argument names.
-            # We pass in the model for the caller side to customize
-            # the batch
+            # when the model is pre-fx-traced or dynamo-exported, the inputs are already flattened,
+            # and therefore we use a tuple as concrete args that fx.trace will automatically match
+            # with the argument names. We pass in the model so the caller side can customize the batch.
             # pyre-ignore[16]: Variable[In (bound to Pipelineable)] has no attribute to_proxy_tuple.
             concrete_args = batch.to_proxy_tuple(model)

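The note about `concrete_args` can be seen in isolation with plain torch.fx. This is a generic sketch of the mechanism (not TorchRec code), following the pattern from the torch.fx documentation: specializing a non-tensor argument bakes its control flow into the traced graph, which is the same mechanism `_rewrite_model` leans on when it pins the batch via `to_proxy` / `to_proxy_tuple`.

import torch
import torch.fx


def f(x: torch.Tensor, use_relu: bool):
    # data-dependent control flow on a non-tensor flag
    return torch.relu(x) if use_relu else x


# Partially specialize `use_relu` so the branch is removed from the traced graph.
gm = torch.fx.symbolic_trace(f, concrete_args={"use_relu": True})
print(gm.code)
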
@@ -1512,37 +1541,45 @@ def _rewrite_model(  # noqa C901
     non_pipelined_sharded_modules = []

     for node in graph.nodes:
-        if node.op == "call_module" and node.target in sharded_modules:
-            total_num_args = len(node.args) + len(node.kwargs)
-            if total_num_args == 0:
-                continue
-            arg_info_list, num_found = _get_node_args(
-                model,
-                node,
-                pipelined_postprocs,
+        # only work on call_module nodes that target a sharded module
+        if node.op != "call_module" or node.target not in sharded_modules:
+            continue
+
+        total_num_args = len(node.args) + len(node.kwargs)
+        # only work on nodes with input(s); we don't expect a zero input count for a sharded module
+        if total_num_args == 0:
+            logger.warning(f"Module '{node.target}' is a ShardedModule with zero input")
+            continue
+
+        # List[ArgInfo]: used to rebuild the input arguments; the count verifies that none are missing
+        arg_info_list, num_found = _get_node_args(
+            model,
+            node,
+            pipelined_postprocs,
+            context,
+            pipeline_postproc,
+            default_stream=default_stream,
+            dist_stream=dist_stream,
+        )
+
+        if num_found == total_num_args:
+            logger.info(f"Module '{node.target}' will be pipelined")
+            child = sharded_modules[node.target]
+            original_forwards.append(child.forward)
+            # pyre-ignore[8] Incompatible attribute type
+            child.forward = pipelined_forward(
+                node.target,
+                arg_info_list,
+                child,
                 context,
-                pipeline_postproc,
-                default_stream=default_stream,
-                dist_stream=dist_stream,
+                dist_stream,
             )
-
-            if num_found == total_num_args:
-                logger.info(f"Module '{node.target}' will be pipelined")
-                child = sharded_modules[node.target]
-                original_forwards.append(child.forward)
-                child.forward = pipelined_forward(
-                    node.target,
-                    arg_info_list,
-                    child,
-                    context,
-                    dist_stream,
-                )
-                pipelined_forwards.append(child)
-            else:
-                logger.warning(
-                    f"Module '{node.target}'' will not be pipelined, due to input modifications"
-                )
-                non_pipelined_sharded_modules.append(node.target)
+            pipelined_forwards.append(child)
+        else:
+            logger.warning(
+                f"Module '{node.target}' will NOT be pipelined, due to input modifications"
+            )
+            non_pipelined_sharded_modules.append(node.target)

     # JIT script unsharded modules if applicable.
     if apply_jit:
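
Taken together, the rewritten loop above does three things: find `call_module` nodes that target sharded modules, recover their argument info, and swap each module's `forward` for a pipelined wrapper. A heavily simplified, hypothetical sketch of that shape with plain torch.fx (the `TwoTower` model, the `targets` dict, and the trivial wrapper are illustrative stand-ins for the real model, `sharded_modules`, and `pipelined_forward`):

from typing import Dict

import torch
import torch.fx


class TwoTower(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sparse_arch = torch.nn.EmbeddingBag(10, 4)  # stand-in for a sharded module
        self.dense_arch = torch.nn.Linear(4, 1)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return self.dense_arch(self.sparse_arch(ids))


model = TwoTower()
targets: Dict[str, torch.nn.Module] = {"sparse_arch": model.sparse_arch}

graph = torch.fx.Tracer().trace(model)
original_forwards = []
for node in graph.nodes:
    # same filtering order as the refactored loop: skip anything that is not a
    # call_module node targeting one of the modules we want to pipeline
    if node.op != "call_module" or node.target not in targets:
        continue
    child = targets[node.target]
    original_forwards.append(child.forward)
    # swap in a wrapper; the real pipeline installs `pipelined_forward` here
    child.forward = lambda *args, _orig=child.forward, **kwargs: _orig(*args, **kwargs)
    print(f"Module '{node.target}' would be pipelined")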
