Commit b602030

Parallel tests v2 (#276)
1 parent 65d7e03 commit b602030

14 files changed: 296 additions, 185 deletions

fast_llm/config.py

Lines changed: 0 additions & 1 deletion
@@ -274,7 +274,6 @@ def __init__(self, **kwargs):
 
         if dynamic_type is not None:
             for cls_, name in dynamic_type.items():
-                print(cls_, name, wrapped)
                 cls_.register_subclass(name, wrapped)
 
         return wrapped

setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ DEV =
     # Required for testing
     pytest>=8.3.2
     pytest-depends>=1.0.1
+    pytest-xdist>=3.6.1
     # Somehow needed for Megatron to work with base image 24.11
     setuptools>=75.6.0
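
Note: pytest-xdist joins pytest-depends in the DEV extras. With it installed, the suite can presumably be run across several workers with something like `pytest -n 4 tests/`; xdist's plain load distribution (the default for `-n`) is what the `pytest_xdist_make_scheduler` hook in tests/conftest.py below checks for before swapping in group-aware scheduling.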

tests/common.py

Lines changed: 76 additions & 70 deletions
@@ -32,7 +32,7 @@
 
 # TODO: Use `pytest_addoption` instead?
 # Keep all results in one place to allow recovering them for debugging in case of failure.
-TEST_RESULTS_PATH = pathlib.Path(os.environ.get("TEST_RESULTS_PATH", "/tmp/fast_llm_tests"))
+TEST_RESULTS_PATH = pathlib.Path(os.environ.get("TEST_RESULTS_PATH", "/tmp/fast_llm_tests")).resolve()
 FORCE_REUSE_RESULTS = int(os.environ.get("FORCE_REUSE_RESULTS", 0)) != 0
 REUSE_RESULTS = FORCE_REUSE_RESULTS or int(os.environ.get("REUSE_RESULTS", 0)) != 0
 _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13))
@@ -350,78 +350,84 @@ def get_test_concatenated_memmap_dataset(
     index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)])
 
 
-def run_test_script(
-    name: str,
-    script: list[str],
-    num_gpus: int = 1,
-    *,
-    model_type: str = TEST_MODEL_TYPE,
-    is_megatron: bool = False,
-    compare: str | None = None,
-    config: CompareConfig | None = None,
-    prepare_fn=None,
-    compare_fn=None,
-    do_compare: bool = True,
-):
-    if torch.cuda.device_count() < num_gpus:
-        pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})")
-    env = os.environ.copy()
-    if is_megatron:
-        # Prevent Megatron from complaining.
-        env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-        env["NVTE_FLASH_ATTN"] = "0"
-    path = TEST_RESULTS_PATH.resolve() / name
-    skip = False
-    artifact_path = path / ARTIFACT_PATH
-    if path.exists():
-        assert path.is_dir()
-        # TODO: Better way to check if the previous attempt succeeded.
-        if (
-            REUSE_RESULTS
-            and artifact_path.is_dir()
-            and len(list((artifact_path / "0").iterdir())) >= (1 if is_megatron else 3)
-        ):
-            skip = True
+@pytest.fixture(scope="session")
+def run_test_script(worker_resources):
+    def do_run_test_script(
+        name: str,
+        script: list[str],
+        num_gpus: int = 1,
+        *,
+        model_type: str = TEST_MODEL_TYPE,
+        is_megatron: bool = False,
+        compare: str | None = None,
+        config: CompareConfig | None = None,
+        prepare_fn=None,
+        compare_fn=None,
+        do_compare: bool = True,
+    ):
+        if torch.cuda.device_count() < num_gpus:
+            pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})")
+        env = os.environ.copy()
+        if is_megatron:
+            # Prevent Megatron from complaining.
+            env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+            env["NVTE_FLASH_ATTN"] = "0"
+        path = TEST_RESULTS_PATH / name
+        skip = False
+        artifact_path = path / ARTIFACT_PATH
+        if path.exists():
+            assert path.is_dir()
+            # TODO: Better way to check if the previous attempt succeeded.
+            if (
+                REUSE_RESULTS
+                and artifact_path.is_dir()
+                and len(list((artifact_path / "0").iterdir())) >= (1 if is_megatron else 3)
+            ):
+                skip = True
+            elif FORCE_REUSE_RESULTS:
+                raise RuntimeError(artifact_path)
+            else:
+                shutil.rmtree(path)
         elif FORCE_REUSE_RESULTS:
-            raise RuntimeError(artifact_path)
+            raise RuntimeError(path)
+        if prepare_fn is not None:
+            skip = prepare_fn(TEST_RESULTS_PATH / name, None if compare is None else TEST_RESULTS_PATH / compare, skip)
+        if is_megatron:
+            script = [*script, f"--structured-logs-dir={path}", f"--data-cache-path={path}"]
         else:
-            shutil.rmtree(path)
-    elif FORCE_REUSE_RESULTS:
-        raise RuntimeError(path)
-    if prepare_fn is not None:
-        skip = prepare_fn(TEST_RESULTS_PATH / name, None if compare is None else TEST_RESULTS_PATH / compare, skip)
-    if is_megatron:
-        script = [*script, f"--structured-logs-dir={path}", f"--data-cache-path={path}"]
-    else:
-        script = [model_type, *script, f"run.experiment_dir={path}"]
-    header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"]
-    command = [
-        "python",
-        "-m",
-        "torch.distributed.run",
-        f"--nproc-per-node={num_gpus}",
-        *header,
-        *script,
-    ]
-    print(" ".join(command))
-    if skip:
-        print("Reusing existing run.")
-    else:
-        get_test_dataset()
-        if num_gpus == 1 and not is_megatron:
-            CliTrainingConfig.parse_and_run(script)
+            script = [model_type, *script, f"run.experiment_dir={path}"]
+        header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"]
+        command = [
+            "python",
+            "-m",
+            "torch.distributed.run",
+            f"--nproc-per-node={num_gpus}",
+            f"--rdzv-endpoint=localhost:{worker_resources.rendezvous_port}",
+            f"--master-port={worker_resources.torchrun_port}",
+            *header,
+            *script,
+        ]
+        print(" ".join(command))
+        if skip:
+            print("Reusing existing run.")
         else:
-            completed_proc = subprocess.run(command, env=env, timeout=60)
-            if completed_proc.returncode:
-                raise RuntimeError(f"Process failed with return code {completed_proc.returncode}")
-    if compare and do_compare:
-        if compare_fn is not None:
-            compare_fn(TEST_RESULTS_PATH / name, TEST_RESULTS_PATH / compare)
-        compare_tensor_logs(
-            TEST_RESULTS_PATH / compare / ARTIFACT_PATH,
-            TEST_RESULTS_PATH / name / ARTIFACT_PATH,
-            config,
-        )
+            get_test_dataset()
+            if num_gpus == 1 and not is_megatron:
+                CliTrainingConfig.parse_and_run(script)
+            else:
+                completed_proc = subprocess.run(command, env=env, timeout=60)
+                if completed_proc.returncode:
+                    raise RuntimeError(f"Process failed with return code {completed_proc.returncode}")
+        if compare and do_compare:
+            if compare_fn is not None:
+                compare_fn(TEST_RESULTS_PATH / name, TEST_RESULTS_PATH / compare)
+            compare_tensor_logs(
+                TEST_RESULTS_PATH / compare / ARTIFACT_PATH,
+                TEST_RESULTS_PATH / name / ARTIFACT_PATH,
+                config,
+            )
+
+    return do_run_test_script
 
 
 def materialize_meta_tensors(model, tensor_space):
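
With this change `run_test_script` becomes a session-scoped fixture instead of an importable helper, so tests take it as a parameter, and each pytest-xdist worker gets its own torchrun/rendezvous ports injected via `worker_resources`. A minimal usage sketch (the test name and argument list here are hypothetical, not taken from this commit):

def test_model_simple(run_test_script):
    # Hypothetical invocation: the first argument names the run directory under
    # TEST_RESULTS_PATH, the second is the Fast-LLM argument list for the run.
    run_test_script("test_model_simple", ["training.train_iters=2"], num_gpus=1)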

tests/conftest.py

Lines changed: 105 additions & 0 deletions
@@ -1,4 +1,16 @@
+import dataclasses
+import math
+import os
+
+import networkx
 import pytest
+import pytest_depends
+import pytest_depends.main
+import torch
+from xdist.scheduler import LoadGroupScheduling
+
+# Make fixtures available globally without import
+from tests.common import run_test_script  # isort: skip
 
 
 def pytest_addoption(parser):
@@ -11,13 +23,73 @@ def pytest_addoption(parser):
     )
 
 
+@dataclasses.dataclass
+class WorkerResources:
+    worker_id: int
+    gpu_id: int | None
+    num_gpus: int
+    torchrun_port: int
+    rendezvous_port: int
+
+
+MAX_TEST_MEMORY = 5e9
+CUDA_CONTEXT_SIZE = 7e8
+TORCHRUN_DEFAULT_PORT = 25900
+
+
 def pytest_configure(config):
     config.addinivalue_line("markers", "slow: Test is slow.")
     config.addinivalue_line(
         "markers", "extra_slow: Mark test as extra slow and skip unless --run-extra-slow is given."
     )
+    # TODO: Spawned processes (multi-gpu, Megatron) ignore resource allocation.
+    is_parallel = hasattr(config, "workerinput")
+    if is_parallel:
+        worker_name = config.workerinput["workerid"]
+        assert worker_name.startswith("gw")
+        worker_id = int(worker_name[2:])
+    else:
+        worker_id = 0
+
+    num_gpus = torch.cuda.device_count()
+    if num_gpus > 0 and is_parallel:
+        # We spread workers across GPUs.
+        gpu_id = worker_id % num_gpus
+        # We set the device through "CUDA_VISIBLE_DEVICES", and this needs to happen before cuda initialization.
+        # The `device_count` call above doesn't initialize, but `mem_get_info` below does.
+        assert not torch.cuda.is_initialized()
+        # TODO: Support this?
+        assert "CUDA_VISIBLE_DEVICES" not in os.environ
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str((gpu_id + i) % num_gpus) for i in range(num_gpus))
+    elif num_gpus > 0:
+        gpu_id = 0
+    else:
+        gpu_id = None
 
+    gpu_memory = torch.cuda.mem_get_info(0)[1] if num_gpus > 0 else 0
+    if num_gpus > 0:
+        torch.cuda.set_per_process_memory_fraction(MAX_TEST_MEMORY / gpu_memory, 0)
 
+    num_workers = config.workerinput["workercount"] if is_parallel else 1
+    if num_gpus > 0:
+        memory_needed = (MAX_TEST_MEMORY + CUDA_CONTEXT_SIZE) * math.ceil(num_workers / num_gpus)
+        if memory_needed > gpu_memory:
+            raise ValueError(
+                f"Not enough GPU memory to support this many parallel workers {num_workers}."
+                f"Please reduce the number of workers to {int(gpu_memory/(MAX_TEST_MEMORY + CUDA_CONTEXT_SIZE))*num_gpus} or less."
+            )
+
+    config.worker_resources = WorkerResources(
+        worker_id=worker_id,
+        gpu_id=gpu_id,
+        num_gpus=num_gpus,
+        # Each worker needs its own set of ports for safe distributed run. Hopefully these are free.
+        torchrun_port=TORCHRUN_DEFAULT_PORT + 2 * worker_id,
+        rendezvous_port=TORCHRUN_DEFAULT_PORT + 2 * worker_id + 1,
+    )
+
+
+@pytest.hookimpl(trylast=True)
 def pytest_collection_modifyitems(config, items):
     if config.getoption("--skip-slow"):
         skip_slow = pytest.mark.skip(reason="Skipping slow tests")
@@ -29,3 +101,36 @@ def pytest_collection_modifyitems(config, items):
         for item in items:
             if "extra_slow" in item.keywords:
                 item.add_marker(skip_extra_slow)
+
+    manager: pytest_depends.DependencyManager = pytest_depends.managers[-1]
+    # Build the undirected graph as in `DependencyManager.sorted_items`.
+    dag = networkx.DiGraph()
+    for item in manager.items:
+        node_id = pytest_depends.clean_nodeid(item.nodeid)
+        dag.add_node(node_id)
+        for dependency in manager.dependencies[node_id].dependencies:
+            dag.add_edge(dependency, node_id)
+    # Mark dependency groups for xdist.
+    manager.groups = {}
+    for i, node_ids in enumerate(sorted(networkx.weakly_connected_components(dag), key=len, reverse=True)):
+        if len(node_ids) > 1:
+            for node_id in node_ids:
+                manager.nodeid_to_item[node_id]._nodeid = (
+                    f"{manager.nodeid_to_item[node_id]._nodeid}@dependency_group_{i}"
+                )
+
+    old_clean_nodeid = pytest_depends.main.clean_nodeid
+    # Hack into `clean_nodeid` so pytest_depends recognizes the renamed nodes.
+    pytest_depends.main.clean_nodeid = lambda nodeid: old_clean_nodeid(nodeid.split("@dependency_group_")[0])
+
+
+@pytest.fixture(scope="session")
+def worker_resources(request) -> WorkerResources:
+    return request.config.worker_resources
+
+
+@pytest.mark.trylast
+def pytest_xdist_make_scheduler(config, log):
+    # Always use grouped load balancing to handle dependencies, and make it work with `-n`.
+    assert config.getvalue("dist") == "load"
+    return LoadGroupScheduling(config, log)
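
As an aside, a self-contained sketch of what the worker-to-GPU and port assignment in `pytest_configure` works out to, using the constants from this conftest (illustration only, not part of the commit):

# Illustration: reproduce the mapping from pytest_configure for 4 xdist workers and 2 GPUs.
TORCHRUN_DEFAULT_PORT = 25900
num_gpus, num_workers = 2, 4
for worker_id in range(num_workers):
    gpu_id = worker_id % num_gpus
    # Same rotation as above: the worker's preferred GPU is listed first in CUDA_VISIBLE_DEVICES.
    visible = ",".join(str((gpu_id + i) % num_gpus) for i in range(num_gpus))
    torchrun_port = TORCHRUN_DEFAULT_PORT + 2 * worker_id
    rendezvous_port = TORCHRUN_DEFAULT_PORT + 2 * worker_id + 1
    print(f"gw{worker_id}: CUDA_VISIBLE_DEVICES={visible} torchrun={torchrun_port} rendezvous={rendezvous_port}")

The dependency-group renaming works because pytest-xdist's load-group scheduling keys groups on the suffix after "@" in a test's node id; roughly (hypothetical test ids, not from the test suite):

# Hypothetical node ids after pytest_collection_modifyitems has renamed a dependency chain.
nodeids = [
    "tests/test_a.py::test_base@dependency_group_0",
    "tests/test_a.py::test_uses_base@dependency_group_0",
    "tests/test_b.py::test_unrelated",
]
# LoadGroupScheduling keeps every test sharing a suffix on one worker, so a pytest-depends
# chain never straddles two workers; ungrouped tests remain their own singleton groups.
groups = [n.rsplit("@", 1)[-1] if "@" in n else n for n in nodeids]
print(groups)  # ['dependency_group_0', 'dependency_group_0', 'tests/test_b.py::test_unrelated']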
