device agnostic testing (#2123)
* device agnostic testing

* initialize accelerate state before using the logging utility

* apply review suggestion

* apply review suggestion

Co-authored-by: Zach Mueller <[email protected]>

* use `hardware accelerator` to disambiguate

* remove redundant guard code

* rename variable for consistency

* remove overkill code

* fix ci-error

---------

Co-authored-by: Zach Mueller <[email protected]>
ji-huazhong and muellerzr authored Dec 8, 2023
1 parent 54d670b commit 0a37e20
Showing 4 changed files with 65 additions and 15 deletions.
5 changes: 5 additions & 0 deletions src/accelerate/test_utils/__init__.py
@@ -1,21 +1,26 @@
 from .testing import (
     are_the_same_tensors,
     assert_exception,
+    device_count,
     execute_subprocess_async,
     require_bnb,
     require_cpu,
     require_cuda,
     require_huggingface_suite,
     require_mps,
+    require_multi_device,
     require_multi_gpu,
     require_multi_xpu,
+    require_non_cpu,
+    require_single_device,
     require_single_gpu,
     require_single_xpu,
     require_torch_min_version,
     require_tpu,
     require_xpu,
     skip,
     slow,
+    torch_device,
 )
 from .training import RegressionDataset, RegressionModel, RegressionModel4XPU

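For orientation only (this snippet is not part of the commit), a test module can now pull the device-agnostic helpers straight from `accelerate.test_utils`; which subset to import depends on the test:

# Illustrative import of the helpers re-exported by this commit.
from accelerate.test_utils import (
    device_count,   # number of devices reported by the detected backend (1 on CPU and MPS)
    require_multi_device,
    require_non_cpu,
    require_single_device,
    torch_device,   # "cuda", "mps", "npu", "xpu", or "cpu"
)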
11 changes: 7 additions & 4 deletions src/accelerate/test_utils/scripts/external_deps/test_metrics.py
@@ -27,8 +27,8 @@
 
 from accelerate import Accelerator
 from accelerate.data_loader import DataLoaderDispatcher
-from accelerate.test_utils import RegressionDataset, RegressionModel
-from accelerate.utils import is_tpu_available, set_seed
+from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device
+from accelerate.utils import set_seed
 
 
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
@@ -87,7 +87,10 @@ def get_mrpc_setup(dispatch_batches, split_batches):
         "hf-internal-testing/mrpc-bert-base-cased", return_dict=True
     )
     ddp_model, ddp_dataloader = accelerator.prepare(model, dataloader)
-    return {"ddp": [ddp_model, ddp_dataloader, "cuda:0"], "no": [model, dataloader, accelerator.device]}, accelerator
+    return {
+        "ddp": [ddp_model, ddp_dataloader, torch_device],
+        "no": [model, dataloader, accelerator.device],
+    }, accelerator
 
 
 def generate_predictions(model, dataloader, accelerator):
@@ -247,7 +250,7 @@ def main():
     datasets.utils.logging.set_verbosity_error()
     transformers.utils.logging.set_verbosity_error()
     # These are a bit slower so they should only be ran on the GPU or TPU
-    if torch.cuda.is_available() or is_tpu_available():
+    if accelerator.device.type != "cpu":
         if accelerator.is_local_main_process:
             print("**Testing gather_for_metrics**")
         for split_batches in [True, False]:
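Editorial note, not part of the diff: the rewritten guard keys off the `Accelerator`'s device type instead of probing CUDA or TPU directly, so any backend that `Accelerator` resolves to counts as a hardware accelerator. A minimal standalone sketch of the same pattern:

from accelerate import Accelerator

accelerator = Accelerator()
# Run the slower metric checks only when a hardware accelerator backs this process.
if accelerator.device.type != "cpu":
    print(f"Accelerator backend detected: {accelerator.device}")
else:
    print("CPU-only run; slower checks are skipped")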
45 changes: 44 additions & 1 deletion src/accelerate/test_utils/testing.py
@@ -37,6 +37,7 @@
     is_deepspeed_available,
     is_dvclive_available,
     is_mps_available,
+    is_npu_available,
     is_pandas_available,
     is_tensorboard_available,
     is_timm_available,
@@ -49,6 +50,22 @@
 )
 
 
+def get_backend():
+    if torch.cuda.is_available():
+        return "cuda", torch.cuda.device_count()
+    elif is_mps_available():
+        return "mps", 1
+    elif is_npu_available():
+        return "npu", torch.npu.device_count()
+    elif is_xpu_available():
+        return "xpu", torch.xpu.device_count()
+    else:
+        return "cpu", 1
+
+
+torch_device, device_count = get_backend()
+
+
 def parse_flag_from_env(key, default=False):
     try:
         value = os.environ[key]
@@ -85,7 +102,15 @@ def require_cpu(test_case):
     """
     Decorator marking a test that must be only ran on the CPU. These tests are skipped when a GPU is available.
     """
-    return unittest.skipUnless(not torch.cuda.is_available(), "test requires only a CPU")(test_case)
+    return unittest.skipUnless(torch_device == "cpu", "test requires only a CPU")(test_case)
+
+
+def require_non_cpu(test_case):
+    """
+    Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
+    hardware accelerator available.
+    """
+    return unittest.skipUnless(torch_device != "cpu", "test requires a GPU")(test_case)
 
 
 def require_cuda(test_case):
@@ -147,6 +172,16 @@ def require_tpu(test_case):
     return unittest.skipUnless(is_tpu_available(), "test requires TPU")(test_case)
 
 
+def require_single_device(test_case):
+    """
+    Decorator marking a test that requires a single device. These tests are skipped when there is no hardware
+    accelerator available or number of devices is more than one.
+    """
+    return unittest.skipUnless(torch_device != "cpu" and device_count == 1, "test requires a hardware accelerator")(
+        test_case
+    )
+
+
 def require_single_gpu(test_case):
     """
     Decorator marking a test that requires CUDA on a single GPU. These tests are skipped when there are no GPU
@@ -163,6 +198,14 @@ def require_single_xpu(test_case):
     return unittest.skipUnless(torch.xpu.device_count() == 1, "test requires a XPU")(test_case)
 
 
+def require_multi_device(test_case):
+    """
+    Decorator marking a test that requires a multi-device setup. These tests are skipped on a machine without multiple
+    devices.
+    """
+    return unittest.skipUnless(device_count > 1, "test requires multiple hardware accelerators")(test_case)
+
+
 def require_multi_gpu(test_case):
     """
     Decorator marking a test that requires a multi-GPU setup. These tests are skipped on a machine without multiple
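Illustrative usage, not part of the commit: with `torch_device`, `device_count`, and the new `require_*` decorators in place, a downstream test can be written once and skipped per backend. A minimal sketch of a hypothetical test class:

import unittest

import torch

from accelerate.test_utils import require_multi_device, require_non_cpu, torch_device


class DeviceAgnosticTest(unittest.TestCase):
    @require_non_cpu
    def test_tensor_on_accelerator(self):
        # torch_device resolves to "cuda", "mps", "npu", or "xpu" here, never "cpu".
        x = torch.ones(2, 2, device=torch_device)
        self.assertEqual(str(x.device).split(":")[0], torch_device)

    @require_multi_device
    def test_needs_multiple_devices(self):
        # Skipped automatically unless the detected backend reports more than one device.
        pass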
19 changes: 9 additions & 10 deletions tests/test_metrics.py
@@ -16,16 +16,15 @@
 import os
 import unittest
 
-import torch
-
 import accelerate
 from accelerate import debug_launcher
 from accelerate.test_utils import (
+    device_count,
     execute_subprocess_async,
     require_cpu,
     require_huggingface_suite,
-    require_multi_gpu,
-    require_single_gpu,
+    require_multi_device,
+    require_single_device,
 )
 from accelerate.utils import patch_environment
 
@@ -50,13 +49,13 @@ def test_metric_cpu_noop(self):
     def test_metric_cpu_multi(self):
         debug_launcher(self.test_metrics.main)
 
-    @require_single_gpu
-    def test_metric_gpu(self):
+    @require_single_device
+    def test_metric_accelerator(self):
         self.test_metrics.main()
 
-    @require_multi_gpu
-    def test_metric_gpu_multi(self):
-        print(f"Found {torch.cuda.device_count()} devices.")
-        cmd = ["torchrun", f"--nproc_per_node={torch.cuda.device_count()}", self.test_file_path]
+    @require_multi_device
+    def test_metric_accelerator_multi(self):
+        print(f"Found {device_count} devices.")
+        cmd = ["torchrun", f"--nproc_per_node={device_count}", self.test_file_path]
         with patch_environment(omp_num_threads=1, ACCELERATE_LOG_LEVEL="INFO"):
             execute_subprocess_async(cmd, env=os.environ.copy())
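Usage note, not part of the diff: the multi-device test now sizes the `torchrun` launch from the detected `device_count` rather than from `torch.cuda.device_count()`. On a host where the backend reports two devices, the constructed command is equivalent to the list below (the path placeholder stands in for `self.test_file_path`):

# Equivalent launch command when device_count == 2; the real path comes from self.test_file_path.
cmd = ["torchrun", "--nproc_per_node=2", "<path-to-metrics-test-script>"]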
