
Add option to throttle checkpoint uploads to one rank from each node at a time (#142)
epwalsh authored Jan 21, 2025
1 parent 7633461 commit 212108f
Showing 8 changed files with 135 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added a callback for sending Slack notifications.
- Added `SkipStepAdamW` optimizer.
- The trainer can load model-only checkpoints now.
- Added the option to throttle checkpoint uploads to one rank from each node at a time.

### Changed

31 changes: 28 additions & 3 deletions src/olmo_core/distributed/checkpoint/__init__.py
@@ -61,9 +61,11 @@
def save_state_dict(
dir: PathOrStr,
state_dict: Dict[str, Any],
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
):
"""
Save an arbitrary state dictionary to a distributed format that can be loaded again with
@@ -77,11 +79,19 @@ def save_state_dict(
:param state_dict: The state dict to save.
:param process_group: The process group to use for distributed collectives.
:param save_overwrite: Overwrite existing files.
:param thread_count: Set this to override the number of threads used while writing data.
:param throttle_uploads: If this is set to ``True`` and ``dir`` is a URL then only one
rank from each node will upload data at a time.
"""
dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
)
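A minimal usage sketch for the new flag (the bucket URL and tensors below are hypothetical, and `torch.distributed` is assumed to already be initialized):

```python
import torch

from olmo_core.distributed.checkpoint import save_state_dict

# Hypothetical remote directory; throttling only applies when the target is a URL.
save_state_dict(
    "s3://my-bucket/checkpoints/step1000",
    {"x": torch.randn(4, 4)},
    save_overwrite=True,
    throttle_uploads=True,  # only one rank per node uploads at a time
)
```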

@@ -95,6 +105,7 @@ def save_model_and_optim_state(
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
) -> None:
"""
Save model and optimizer state dictionaries. The model state can be a sharded model, in which
@@ -117,6 +128,9 @@ def save_model_and_optim_state(
:param optim: The optimizer to save state from.
:param process_group: The process group to use for distributed collectives.
:param save_overwrite: Overwrite existing files.
:param thread_count: Set this to override the number of threads used while writing data.
:param throttle_uploads: If this is set to ``True`` and ``dir`` is a URL then only one
rank from each node will upload data at a time.
:raises FileExistsError: If the checkpoint dir exists and is non-empty unless ``save_overwrite=True``.
"""
@@ -125,7 +139,12 @@
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
planner=planner,
)
@@ -140,6 +159,7 @@ def async_save_model_and_optim_state(
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
) -> Future[None]:
"""
An async version of :func:`save_model_and_optim_state()`.
Expand All @@ -151,7 +171,12 @@ def async_save_model_and_optim_state(
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
return dist_cp.state_dict_saver.async_save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
planner=planner,
)
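The async variant accepts the same flag. A sketch with a hypothetical model, optimizer, and URL (again assuming an initialized distributed process group):

```python
import torch
import torch.nn as nn

from olmo_core.distributed.checkpoint import async_save_model_and_optim_state

model = nn.Linear(8, 8)  # stand-in model
optim = torch.optim.AdamW(model.parameters())

future = async_save_model_and_optim_state(
    "s3://my-bucket/checkpoints/step2000",  # hypothetical URL
    model,
    optim,
    save_overwrite=True,
    throttle_uploads=True,
)
# Training can continue while the upload runs; block only when the result is needed.
future.result()
```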
36 changes: 28 additions & 8 deletions src/olmo_core/distributed/checkpoint/filesystem.py
@@ -7,10 +7,12 @@
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, cast

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.filesystem import WriteResult
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
@@ -25,6 +27,7 @@
from torch.futures import Future

from olmo_core.aliases import PathOrStr
from olmo_core.distributed.utils import do_n_at_a_time
from olmo_core.exceptions import OLMoCheckpointError
from olmo_core.io import (
get_bytes_range,
@@ -154,12 +157,16 @@ def __init__(
self,
path: PathOrStr,
thread_count: Optional[int] = None,
process_group: Optional[dist.ProcessGroup] = None,
throttle_uploads: bool = False,
) -> None:
super().__init__()
if thread_count is not None and thread_count <= 0:
raise ValueError("thread count must be at least 1")
self.path = normalize_path(path)
self.thread_count = thread_count or get_default_thread_count()
self.process_group = process_group
self.throttle_uploads = throttle_uploads
self.save_id = generate_uuid()

def reset(self, checkpoint_id: Optional[PathOrStr] = None) -> None:
@@ -201,22 +208,35 @@ def gen_file_name() -> str:
file_count += 1
return file_name

def write_items(buckets: List[List[WriteItem]]) -> List[WriteResult]:
    results: List[WriteResult] = []
    for bucket in buckets:
        file_name = gen_file_name()
        path = f"{self.path}/{file_name}"
        try:
            results.extend(_write_items(path, file_name, bucket, planner))
        except BaseException:
            # NOTE: we might get an error here that can't be pickled, which causes a different failure
            # later when PyTorch tries to reduce that error across ranks. So here we just make
            # sure we're raising a simple error type that can be pickled.
            raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
    return results

results: List[WriteResult]
if self.throttle_uploads and is_url(self.path):
buckets = _split_by_size_and_type(1, plan.items)
results = do_n_at_a_time(
partial(write_items, buckets), process_group=self.process_group
)
else:
buckets = _split_by_size_and_type(self.thread_count, plan.items)
results = []
with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
futures = []
for bucket in buckets:
futures.append(executor.submit(write_items, [bucket]))
for f in as_completed(futures):
results.extend(f.result())

fut: Future[List[WriteResult]] = Future()
fut.set_result(results)
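In short: when `throttle_uploads` is set and the target is a URL, each rank writes its items as a single bucket and `do_n_at_a_time` staggers those writes so only one rank per node uploads at once; otherwise the existing thread-pool path is used. A sketch of using the writer directly with PyTorch's checkpoint saver, mirroring the test further down (the URL is hypothetical):

```python
import torch
import torch.distributed.checkpoint as dist_cp

from olmo_core.distributed.checkpoint.filesystem import RemoteFileSystemWriter

ckpt_dir = "s3://my-bucket/checkpoints/step3000"  # hypothetical
dist_cp.state_dict_saver.save(
    {"x": torch.randn(4, 4)},
    checkpoint_id=ckpt_dir,
    storage_writer=RemoteFileSystemWriter(ckpt_dir, thread_count=2, throttle_uploads=True),
)
```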
36 changes: 35 additions & 1 deletion src/olmo_core/distributed/utils.py
@@ -3,9 +3,10 @@
"""

import logging
import math
import os
from datetime import timedelta
from typing import List, Optional, TypeVar
from typing import Callable, List, Optional, TypeVar, cast

import torch
import torch.distributed as dist
@@ -421,3 +422,36 @@ def get_local_tensor(x: torch.Tensor) -> torch.Tensor:
return x.to_local()
else:
return x


def do_n_at_a_time(
f: Callable[[], T],
*,
n: Optional[int] = None,
process_group: Optional[dist.ProcessGroup] = None,
world_size: Optional[int] = None,
local_rank: Optional[int] = None,
) -> T:
"""
Call a function ``f`` in a distributed context from at most ``n`` ranks at a time.
All ranks will eventually call the given function exactly once, at which point this function
will return.
:param f: The function to call from each rank.
:param n: The level of concurrency, i.e. how many ranks are allowed to call ``f`` at once.
This defaults to the number of nodes, in which case one rank from each node will
call ``f`` at a time.
:param process_group: The process group to use.
"""
world_size = world_size if world_size is not None else get_world_size(process_group)
local_rank = local_rank if local_rank is not None else get_rank(process_group)
n = n if n is not None else get_num_nodes()
group_count = math.ceil(world_size / n)
group_rank = local_rank % group_count
result: Optional[T] = None
for active_group in range(group_count):
if group_rank == active_group:
result = f()
barrier(process_group)
return cast(T, result)
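Outside of checkpointing, the helper can stagger any per-rank work in a distributed run; a toy sketch (the upload function and path are hypothetical):

```python
from functools import partial

from olmo_core.distributed.utils import do_n_at_a_time


def upload_shard(path: str) -> str:
    # Hypothetical per-rank work, e.g. pushing a local file to object storage.
    return path


# With the default ``n`` (the number of nodes), at most one rank per node runs
# ``upload_shard`` at a time, and every rank runs it exactly once before this returns.
result = do_n_at_a_time(partial(upload_shard, "/tmp/shard.pt"))
```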
4 changes: 4 additions & 0 deletions src/olmo_core/train/checkpoint.py
@@ -57,6 +57,7 @@ class CheckpointerConfig(Config):
pre_download: bool = False
save_thread_count: Optional[int] = None
load_thread_count: Optional[int] = None
throttle_uploads: bool = False

def build(self, process_group: Optional[dist.ProcessGroup] = None, **kwargs) -> "Checkpointer":
kwargs = {**self.as_dict(exclude_none=True, recurse=False), **kwargs}
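With the trainer, the new flag is just another `CheckpointerConfig` field. A hedged sketch (the other checkpointer settings, including the `work_dir` kwarg passed to `build()`, are assumptions here):

```python
from olmo_core.train.checkpoint import CheckpointerConfig

config = CheckpointerConfig(
    save_thread_count=4,
    throttle_uploads=True,
)
# ``build()`` forwards remaining fields/kwargs to the Checkpointer; ``work_dir`` is assumed.
checkpointer = config.build(work_dir="/tmp/checkpoint-work-dir")
```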
@@ -86,6 +87,7 @@ class Checkpointer:
process_group: Optional[dist.ProcessGroup] = None
save_thread_count: Optional[int] = None
load_thread_count: Optional[int] = None
throttle_uploads: bool = False

def __post_init__(self):
self.work_dir = Path(self.work_dir)
@@ -112,6 +114,7 @@ def save(self, dir: PathOrStr, model: nn.Module, optim: Optimizer, train_state:
process_group=self.process_group,
save_overwrite=self.save_overwrite,
thread_count=self.save_thread_count,
throttle_uploads=self.throttle_uploads,
)

self._save_metadata(dir, CheckpointMetadata())
@@ -142,6 +145,7 @@ def save_async(
process_group=self.process_group,
save_overwrite=self.save_overwrite,
thread_count=self.save_thread_count,
throttle_uploads=self.throttle_uploads,
)

def done_callback(fut: Future):
16 changes: 11 additions & 5 deletions src/test/distributed/checkpoint/filesystem_test.py
@@ -13,7 +13,7 @@
from ..utils import BACKENDS, get_default_device, run_distributed_test


def run_save_and_load_with_dtensors(dir):
def run_save_and_load_with_dtensors(dir, throttle: bool = False):
mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),))

x_full = torch.randn(4, 4, device=get_default_device())
@@ -30,7 +30,7 @@ def run_save_and_load_with_dtensors(dir):
distcp.state_dict_saver.save(
{"x": x, "y": y},
checkpoint_id=dir,
storage_writer=RemoteFileSystemWriter(dir, thread_count=2),
storage_writer=RemoteFileSystemWriter(dir, thread_count=2, throttle_uploads=throttle),
)

# Now create new sharded copies with a different sharding strategy and load the checkpoint.
Expand All @@ -51,11 +51,17 @@ def run_save_and_load_with_dtensors(dir):

@pytest.mark.parametrize("backend", BACKENDS)
def test_save_and_load_locally_with_dtensors(backend, tmp_path):
run_distributed_test(run_save_and_load_with_dtensors, backend=backend, func_args=(tmp_path,))
run_distributed_test(
run_save_and_load_with_dtensors,
backend=backend,
func_args=(tmp_path,),
start_method="spawn",
)


@pytest.mark.parametrize("backend", BACKENDS)
def test_save_and_load_remotely_with_dtensors(backend, s3_checkpoint_dir):
@pytest.mark.parametrize("throttle", [True, False])
def test_save_and_load_remotely_with_dtensors(backend, s3_checkpoint_dir, throttle):
from botocore.exceptions import NoCredentialsError

try:
@@ -66,6 +72,6 @@ def test_save_and_load_remotely_with_dtensors(backend, s3_checkpoint_dir):
run_distributed_test(
run_save_and_load_with_dtensors,
backend=backend,
func_args=(s3_checkpoint_dir,),
func_args=(s3_checkpoint_dir, throttle),
start_method="spawn", # NOTE: forking causes a crash with boto3
)
10 changes: 9 additions & 1 deletion src/test/distributed/utils.py
@@ -1,5 +1,6 @@
import datetime
import logging
import os
import sys
from typing import Any, Callable, Dict, Optional, Tuple

@@ -8,7 +9,11 @@
import torch.distributed as dist
import torch.multiprocessing as mp

from olmo_core.distributed.utils import is_distributed
from olmo_core.distributed.utils import (
OLMO_LOCAL_WORLD_SIZE_ENV_VAR,
OLMO_NUM_NODES_ENV_VAR,
is_distributed,
)

from ..utils import (
DEVICES,
@@ -115,6 +120,9 @@ def log_record_factory(*args, **kwargs) -> logging.LogRecord:
timeout=datetime.timedelta(seconds=120),
)

os.environ.setdefault(OLMO_NUM_NODES_ENV_VAR, "1")
os.environ.setdefault(OLMO_LOCAL_WORLD_SIZE_ENV_VAR, str(world_size))

log.info("Starting test...")

if "nccl" in backend:
19 changes: 19 additions & 0 deletions src/test/distributed/utils_test.py
@@ -1,3 +1,5 @@
from functools import partial

import pytest
import torch.distributed as dist

@@ -18,3 +20,20 @@ def scatter_object():
@pytest.mark.parametrize("backend", BACKENDS)
def test_scatter_object(backend: str):
run_distributed_test(scatter_object, backend=backend)


@pytest.mark.parametrize("n, world_size", [(2, 1), (8, 64)])
def test_do_n_at_a_time(n: int, world_size: int):
times_called = 0
calling_ranks = set()

def func(rank: int):
nonlocal times_called
times_called += 1
calling_ranks.add(rank)

for rank in range(world_size):
dist_utils.do_n_at_a_time(partial(func, rank), n=n, world_size=world_size, local_rank=rank)

assert times_called == world_size
assert calling_ranks == set(range(world_size))
