Disable shape_padding in TorchRec pipelines to prevent cross-PG deadlock on AMD#4128
Open
kaanbaloglu wants to merge 1 commit into meta-pytorch:main
Conversation
Contributor
@kaanbaloglu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D101241634.
Summary:
torch.compile's `should_pad_mm` heuristic benchmarks GPU kernels via `benchmark_gpu()`, which calls `torch.cuda.synchronize()`. This device-wide sync blocks on pending NCCL collectives from other process groups, causing a circular deadlock in distributed training with multiple PGs (e.g. mesh_shard + mesh_replicate).

The deadlock has only been observed on AMD MI350X (maz5 datacenter). To minimize blast radius, this diff scopes the workaround to AMD/ROCm builds via `torch.version.hip is not None`. NVIDIA jobs continue to benefit from the shape_padding optimization.

This diff disables `torch._inductor.config.shape_padding` (AMD-only) in TrainPipelinePT2 and TrainPipelineSparseDistCompAutograd, matching the precedent set by Simple FSDP (`set_configs_for_simple_fsdp`). The MM padding optimization is a nice-to-have that is not critical for model performance; Simple FSDP already disables it for all its users without reported issues.
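A minimal sketch of the gating described above (the helper names `should_disable_shape_padding` and `_maybe_disable_shape_padding` are hypothetical; the actual call sites live in the TorchRec pipeline classes and are not shown here). On ROCm builds `torch.version.hip` is a version string, while on CUDA/NVIDIA builds it is `None`:

```python
def should_disable_shape_padding(hip_version):
    """Return True only on AMD/ROCm builds.

    Mirrors the PR's gating condition: torch.version.hip is a version
    string (e.g. "6.2.41133") on ROCm builds and None on CUDA builds.
    """
    return hip_version is not None


def _maybe_disable_shape_padding():
    # Hypothetical helper sketching the workaround: disable inductor's
    # MM shape padding only on ROCm, leaving NVIDIA jobs unchanged so
    # they keep the optimization.
    import torch
    import torch._inductor.config as inductor_config

    if should_disable_shape_padding(torch.version.hip):
        inductor_config.shape_padding = False
```

Keeping the check on `torch.version.hip` (rather than a runtime device query) means the workaround applies based on the build, which matches how the diff scopes the blast radius.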
Differential Revision: D101241634