diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py b/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py index fb944b3ed76..7923cbdc4bd 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import logging +import math import shutil from contextlib import nullcontext from copy import deepcopy @@ -211,6 +212,14 @@ def build_distributed_environment(mesh_dim_config: tuple): """ from torch.distributed.device_mesh import init_device_mesh + required_world_size = math.prod(mesh_dim_config) + world_size = torch.distributed.get_world_size() + if world_size < required_world_size: + pytest.skip( + f"This test requires {required_world_size} GPUs for mesh " + f"{mesh_dim_config}, but only {world_size} are available" + ) + # Construct device mesh. device_mesh = init_device_mesh( "cuda", mesh_shape=mesh_dim_config, mesh_dim_names=(DP_OUTER, DP_SHARD, CP, TP)