Skip to content

Commit

Permalink
Speed up tests by shaving off subprocess when not needed (#3042)
Browse files Browse the repository at this point in the history
* bookmark

* Continue making improvements

* Bookmark

* More

* Format
  • Loading branch information
muellerzr authored Sep 2, 2024
1 parent 758d624 commit a848592
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 53 deletions.
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DEFAULT_LAUNCH_COMMAND,
are_the_same_tensors,
assert_exception,
capture_call_output,
device_count,
execute_subprocess_async,
get_launch_command,
Expand Down
17 changes: 17 additions & 0 deletions src/accelerate/test_utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import asyncio
import inspect
import io
import os
import shutil
import subprocess
Expand Down Expand Up @@ -670,3 +671,19 @@ def assert_exception(exception_class: Exception, msg: str = None) -> bool:
assert msg in str(e), f"Expected message '{msg}' to be in exception but got '{str(e)}'"
if was_ran:
raise AssertionError(f"Expected exception of type {exception_class} but ran without issue.")


def capture_call_output(func, *args, **kwargs):
"""
Takes in a `func` with `args` and `kwargs` and returns the captured stdout as a string
"""
captured_output = io.StringIO()
original_stdout = sys.stdout
try:
sys.stdout = captured_output
func(*args, **kwargs)
except Exception as e:
raise e
finally:
sys.stdout = original_stdout
return captured_output.getvalue()
102 changes: 49 additions & 53 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from pathlib import Path
from unittest.mock import patch

import torch
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

import accelerate.commands.test as accelerate_test_cmd
from accelerate.commands.config.config_args import BaseConfig, ClusterConfig, SageMakerConfig, load_config_from_file
from accelerate.commands.estimate import estimate_command, estimate_command_parser, gather_data
from accelerate.commands.launch import _validate_launch_command, launch_command_parser
from accelerate.test_utils import execute_subprocess_async
from accelerate.commands.launch import _validate_launch_command, launch_command, launch_command_parser
from accelerate.commands.tpu import tpu_command_launcher, tpu_command_parser
from accelerate.test_utils.testing import (
DEFAULT_LAUNCH_COMMAND,
get_launch_command,
capture_call_output,
path_in_accelerate_package,
require_multi_device,
require_timm,
Expand All @@ -53,6 +52,7 @@ class AccelerateLauncherTester(unittest.TestCase):
changed_path = config_folder / "_default_config.yaml"

test_config_path = Path("tests/test_configs")
parser = launch_command_parser()

@classmethod
def setUpClass(cls):
Expand All @@ -65,33 +65,33 @@ def tearDownClass(cls):
cls.changed_path.rename(cls.config_path)

def test_no_config(self):
args = ["--monitor_interval", "0.1", str(self.test_file_path)]
if torch.cuda.is_available() and (torch.cuda.device_count() > 1):
cmd = get_launch_command(multi_gpu=True)
else:
cmd = DEFAULT_LAUNCH_COMMAND
cmd.append(self.test_file_path)
execute_subprocess_async(cmd, env=os.environ.copy())
args = ["--multi_gpu"] + args
args = self.parser.parse_args(["--monitor_interval", "0.1", str(self.test_file_path)])
launch_command(args)

def test_config_compatibility(self):
invalid_configs = ["fp8", "invalid", "mpi", "sagemaker"]
for config in sorted(self.test_config_path.glob("**/*.yaml")):
if any(invalid_config in str(config) for invalid_config in invalid_configs):
continue
with self.subTest(config_file=config):
cmd = get_launch_command(config_file=config) + [self.test_file_path]
execute_subprocess_async(cmd)
args = self.parser.parse_args(["--config_file", str(config), str(self.test_file_path)])
launch_command(args)

def test_invalid_keys(self):
config_path = self.test_config_path / "invalid_keys.yaml"
with self.assertRaises(
RuntimeError,
ValueError,
msg="The config file at 'invalid_keys.yaml' had unknown keys ('another_invalid_key', 'invalid_key')",
):
cmd = get_launch_command(config_file=config_path) + [self.test_file_path]
execute_subprocess_async(cmd)
args = self.parser.parse_args(["--config_file", str(config_path), str(self.test_file_path)])
launch_command(args)

def test_accelerate_test(self):
execute_subprocess_async(["accelerate", "test"])
args = accelerate_test_cmd.test_command_parser().parse_args([])
accelerate_test_cmd.test_command(args)

@require_multi_device
def test_notebook_launcher(self):
Expand Down Expand Up @@ -276,18 +276,19 @@ class TpuConfigTester(unittest.TestCase):
command_file = "tests/test_samples/test_command_file.sh"
gcloud = "Running gcloud compute tpus tpu-vm ssh"

def setUp(self):
self.parser = tpu_command_parser()

def test_base(self):
output = run_command(
self.cmd
+ ["--command", self.command, "--tpu_zone", self.tpu_zone, "--tpu_name", self.tpu_name, "--debug"],
return_stdout=True,
args = self.parser.parse_args(
["--command", self.command, "--tpu_zone", self.tpu_zone, "--tpu_name", self.tpu_name, "--debug"]
)
output = capture_call_output(tpu_command_launcher, args)
assert f"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all" in output

def test_base_backward_compatibility(self):
output = run_command(
self.cmd
+ [
args = self.parser.parse_args(
[
"--config_file",
"tests/test_configs/0_12_0.yaml",
"--command",
Expand All @@ -297,61 +298,57 @@ def test_base_backward_compatibility(self):
"--tpu_name",
self.tpu_name,
"--debug",
],
return_stdout=True,
]
)
output = capture_call_output(tpu_command_launcher, args)
assert f"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all" in output

def test_with_config_file(self):
output = run_command(
self.cmd + ["--config_file", "tests/test_configs/latest.yaml", "--debug"], return_stdout=True
)
args = self.parser.parse_args(["--config_file", "tests/test_configs/latest.yaml", "--debug"])
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo "hello world"; echo "this is a second command" --worker all'
in output
)

def test_with_config_file_and_command(self):
output = run_command(
self.cmd + ["--config_file", "tests/test_configs/latest.yaml", "--command", self.command, "--debug"],
return_stdout=True,
args = self.parser.parse_args(
["--config_file", "tests/test_configs/latest.yaml", "--command", self.command, "--debug"]
)
output = capture_call_output(tpu_command_launcher, args)
assert f"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all" in output

def test_with_config_file_and_multiple_command(self):
output = run_command(
self.cmd
+ [
args = self.parser.parse_args(
[
"--config_file",
"tests/test_configs/latest.yaml",
"--command",
self.command,
"--command",
'echo "Hello World"',
"--debug",
],
return_stdout=True,
]
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls; echo "Hello World" --worker all'
in output
)

def test_with_config_file_and_command_file(self):
output = run_command(
self.cmd
+ ["--config_file", "tests/test_configs/latest.yaml", "--command_file", self.command_file, "--debug"],
return_stdout=True,
args = self.parser.parse_args(
["--config_file", "tests/test_configs/latest.yaml", "--command_file", self.command_file, "--debug"]
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo "hello world"; echo "this is a second command" --worker all'
in output
)

def test_with_config_file_and_command_file_backward_compatibility(self):
output = run_command(
self.cmd
+ [
args = self.parser.parse_args(
[
"--config_file",
"tests/test_configs/0_12_0.yaml",
"--command_file",
Expand All @@ -361,37 +358,36 @@ def test_with_config_file_and_command_file_backward_compatibility(self):
"--tpu_name",
self.tpu_name,
"--debug",
],
return_stdout=True,
]
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo "hello world"; echo "this is a second command" --worker all'
in output
)

def test_accelerate_install(self):
output = run_command(
self.cmd + ["--config_file", "tests/test_configs/latest.yaml", "--install_accelerate", "--debug"],
return_stdout=True,
args = self.parser.parse_args(
["--config_file", "tests/test_configs/latest.yaml", "--install_accelerate", "--debug"]
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; pip install accelerate -U; echo "hello world"; echo "this is a second command" --worker all'
in output
)

def test_accelerate_install_version(self):
output = run_command(
self.cmd
+ [
args = self.parser.parse_args(
[
"--config_file",
"tests/test_configs/latest.yaml",
"--install_accelerate",
"--accelerate_version",
"12.0.0",
"--debug",
],
return_stdout=True,
]
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; pip install accelerate==12.0.0; echo "hello world"; echo "this is a second command" --worker all'
in output
Expand Down

0 comments on commit a848592

Please sign in to comment.