Skip to content

[DO NOT MERGE] fix main #2832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/gpu_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ jobs:
python-version: ['3.9', '3.10', '3.11']
torch-version: ["stable", "nightly"]
# Do not run against nightlies on PR
exclude:
- torch-version: ${{ github.event_name == 'pull_request' && 'nightly' }}
# exclude:
# - torch-version: ${{ github.event_name == 'pull_request' && 'nightly' }}
steps:
- name: Check out repo
uses: actions/checkout@v4
Expand Down
70 changes: 35 additions & 35 deletions tests/torchtune/utils/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,41 +78,41 @@ def test_batch_to_device(self):
with pytest.raises(ValueError):
batch_to_device(batch, device)

@pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
def test_get_gpu_device(self) -> None:
device_idx = torch.cuda.device_count() - 1
assert device_idx >= 0
with mock.patch.dict(os.environ, {"LOCAL_RANK": str(device_idx)}, clear=True):
device = get_device()
assert device.type == "cuda"
assert device.index == device_idx
assert device.index == torch.cuda.current_device()

# Test that we raise an error if the device index is specified on distributed runs
if device_idx > 0:
with pytest.raises(
RuntimeError,
match=(
f"You can't specify a device index when using distributed training. "
f"Device specified is cuda:0 but local rank is:{device_idx}"
),
):
device = get_device("cuda:0")

invalid_device_idx = device_idx + 10
with mock.patch.dict(os.environ, {"LOCAL_RANK": str(invalid_device_idx)}):
with pytest.raises(
RuntimeError,
match="The local rank is larger than the number of available GPUs",
):
device = get_device("cuda")

# Test that we fall back to 0 if LOCAL_RANK is not specified
device = torch.device(_get_device_type_from_env())
device = _setup_device(device)
assert device.type == "cuda"
assert device.index == 0
assert device.index == torch.cuda.current_device()
# @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
# def test_get_gpu_device(self) -> None:
# device_idx = torch.cuda.device_count() - 1
# assert device_idx >= 0
# with mock.patch.dict(os.environ, {"LOCAL_RANK": str(device_idx)}, clear=True):
# device = get_device()
# assert device.type == "cuda"
# assert device.index == device_idx
# assert device.index == torch.cuda.current_device()

# # Test that we raise an error if the device index is specified on distributed runs
# if device_idx > 0:
# with pytest.raises(
# RuntimeError,
# match=(
# f"You can't specify a device index when using distributed training. "
# f"Device specified is cuda:0 but local rank is:{device_idx}"
# ),
# ):
# device = get_device("cuda:0")

# invalid_device_idx = device_idx + 10
# with mock.patch.dict(os.environ, {"LOCAL_RANK": str(invalid_device_idx)}):
# with pytest.raises(
# RuntimeError,
# match="The local rank is larger than the number of available GPUs",
# ):
# device = get_device("cuda")

# # Test that we fall back to 0 if LOCAL_RANK is not specified
# device = torch.device(_get_device_type_from_env())
# device = _setup_device(device)
# assert device.type == "cuda"
# assert device.index == 0
# assert device.index == torch.cuda.current_device()

@pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
@patch("torch.cuda.is_available", return_value=True)
Expand Down
32 changes: 16 additions & 16 deletions torchtune/utils/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,22 @@ def _validate_device_from_env(device: torch.device) -> None:
"""
local_rank = _get_local_rank()

# Check if the device index is correct
if device.type != "cpu" and local_rank is not None:
# Ensure device index matches assigned index when distributed training
if device.index != local_rank:
raise RuntimeError(
f"You can't specify a device index when using distributed training. "
f"Device specified is {device} but local rank is:{local_rank}"
)

# Check if the device is available on this machine
try:
torch.empty(0, device=device)
except RuntimeError as e:
raise RuntimeError(
f"The device {device} is not available on this machine."
) from e
# # Check if the device index is correct
# if device.type != "cpu" and local_rank is not None:
# # Ensure device index matches assigned index when distributed training
# if device.index != local_rank:
# raise RuntimeError(
# f"You can't specify a device index when using distributed training. "
# f"Device specified is {device} but local rank is:{local_rank}"
# )

# # Check if the device is available on this machine
# try:
# torch.empty(0, device=device)
# except RuntimeError as e:
# raise RuntimeError(
# f"The device {device} is not available on this machine."
# ) from e


def get_device(device: Optional[str] = None) -> torch.device:
Expand Down
Loading