Add DCP (torch.distributed.checkpoint) save benchmarks to the DDP and FSDP scripts.

* benchmarks/ddp/main.py: replaces the torch.save and torchsnapshot timings with
  a DCP.FileSystemCheckpointer save timing (thread_count=32). The torchsnapshot
  code is left commented out. Rank 0 removes and recreates --work-dir after the
  save, and the --work-dir default changes from "/tmp" to "tmp".
* benchmarks/fsdp/main.py: adds BenchmarkType.DCP and benchmark_dcp(), makes DCP
  the default benchmark type, and switches the torchsnapshot / torch.save paths
  from LOCAL_STATE_DICT to SHARDED_STATE_DICT. benchmark_dcp() cleans up only
  its own save_dir between iterations.
* torchsnapshot/snapshot.py: prints "USING ASYNC" from async_take.

NOTE(review): the --benchmark-load branch of benchmark_dcp() still calls
torch.load(save_file) after the save loop has already removed the checkpoint
directory; it looks unfinished -- confirm before relying on load numbers.
NOTE(review): the "USING ASYNC" print in async_take looks like debugging
output -- confirm whether it should ship.

diff --git a/benchmarks/ddp/main.py b/benchmarks/ddp/main.py
index 257cc99..435b6fa 100644
--- a/benchmarks/ddp/main.py
+++ b/benchmarks/ddp/main.py
@@ -31,10 +31,13 @@ def rank_0_print(msg: str) -> None:
     if dist.get_rank() == 0:
         print(msg)
 
-
+import torch.distributed.checkpoint as DCP
+from torch.distributed.checkpoint.state_dict import (
+    _patch_model_state_dict,
+)
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--work-dir", default="/tmp")
+    parser.add_argument("--work-dir", default="tmp")
     parser.add_argument("--param-size", type=int, default=int(100_000_000))
     parser.add_argument("--num-params", type=int, default=200)
     args = parser.parse_args()
@@ -50,21 +53,37 @@ def rank_0_print(msg: str) -> None:
     sz = sum(t.nelement() * t.element_size() for t in model.parameters())
     rank_0_print(f"Model size: {sz / 1_000_000_000.0} GB")
 
-    if dist.get_rank() == 0:
-        print("Saving the model with torch.save...")
-        t_begin = time.time()
-        with open(f"{args.work_dir}/{uuid.uuid4()}.pt", "wb+") as f:
-            torch.save(model.state_dict(), f)
-        print(f"Took {time.time() - t_begin} seconds with torch.save")
-    dist.barrier()
+    # rank_0_print("Saving the model with torchsnapshot...")
+    # t_begin = time.monotonic()
+    # app_state = {"model": model}
+    # snapshot = torchsnapshot.Snapshot.take(
+    #     path=f"{args.work_dir}/{uuid.uuid4()}",
+    #     app_state=app_state,
+    #     replicated=["**"],
+    # )
+    # os.sync()
+    # rank_0_print(f"Snapshot path: {snapshot.path}")
+    # rank_0_print(f"Took {time.monotonic() - t_begin} seconds with torchsnapshot")
+    # dist.barrier()
+
+    _patch_model_state_dict(model)
+    rank_0_print("Saving the model with DCP...")
+    for num_threads in range(32, 33):
+        checkpointer = DCP.FileSystemCheckpointer(
+            f"{args.work_dir}/{uuid.uuid4()}",
+            thread_count=num_threads
+        )
+
+        begin_ts = time.monotonic()
+        checkpointer.save(state_dict={"model": model})
+        end_ts = time.monotonic()
+        rank_0_print(f"{num_threads}, {end_ts - begin_ts}")
+
+        dist.barrier()
+        if dist.get_rank() == 0:
+            import shutil
+            # Delete a directory and all its contents
+            shutil.rmtree(args.work_dir)
+            shutil.os.makedirs(args.work_dir)
 
-    rank_0_print("Saving the model with torchsnapshot...")
-    t_begin = time.time()
-    app_state = {"model": model}
-    snapshot = torchsnapshot.Snapshot.take(
-        path=f"{args.work_dir}/{uuid.uuid4()}",
-        app_state=app_state,
-        replicated=["**"],
-    )
-    rank_0_print(f"Snapshot path: {snapshot.path}")
-    rank_0_print(f"Took {time.time() - t_begin} seconds with torchsnapshot")
+        dist.barrier()
diff --git a/benchmarks/fsdp/main.py b/benchmarks/fsdp/main.py
index 5c011db..20f9cba 100644
--- a/benchmarks/fsdp/main.py
+++ b/benchmarks/fsdp/main.py
@@ -22,6 +22,8 @@ class BenchmarkType(Enum):
     TORCHSNAPSHOT = "torchsnapshot"
     TORCH_SAVE = "torch_save"
 
+    DCP = "dcp"
+
     def __str__(self):
         return self.value
 
@@ -78,7 +80,7 @@ def benchmark_torchsnapshot(
     rank_0_print("Saving a checkpoint with torchsnapshot...")
     app_state = {"model": model}
     begin_ts = time.monotonic()
-    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
+    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
         Snapshot.take(
             path=save_dir,
             app_state=app_state,
@@ -112,7 +114,7 @@ def benchmark_torchsave(model: nn.Module, save_dir: str, benchmark_load: bool) -
     begin_ts = time.monotonic()
     with FSDP.state_dict_type(
         model,
-        StateDictType.LOCAL_STATE_DICT,
+        StateDictType.SHARDED_STATE_DICT,
     ):
         state_dict = model.state_dict()
         torch.save(state_dict, save_file)
@@ -134,6 +136,48 @@ def benchmark_torchsave(model: nn.Module, save_dir: str, benchmark_load: bool) -
             f"Took {end_ts - begin_ts:.2f} seconds."
         )
 
+import torch.distributed.checkpoint as DCP
+from torch.distributed.checkpoint.state_dict import (
+    _patch_model_state_dict,
+)
+import shutil
+
+def benchmark_dcp(model: nn.Module, save_dir: str, benchmark_load: bool) -> None:
+    """Time DCP checkpoint saves for thread_count values 1 through 16."""
+    rank_0_print("Saving a checkpoint with DCP.save...")
+
+    os.makedirs(save_dir, exist_ok=True)
+    save_file = f"{save_dir}/state_dict-{dist.get_rank()}.pt"
+    _patch_model_state_dict(model)
+
+    for num_threads in range(1, 17):
+        checkpointer = DCP.FileSystemCheckpointer(save_file, thread_count=num_threads)
+
+        begin_ts = time.monotonic()
+        checkpointer.save(state_dict={"model": model})
+        end_ts = time.monotonic()
+        rank_0_print(end_ts - begin_ts)
+        dist.barrier()
+
+        if dist.get_rank() == 0:
+            # Free disk space before the next iteration writes its checkpoint.
+            shutil.rmtree(save_dir)
+            os.makedirs(save_dir)
+
+        dist.barrier()
+    print()
+
+    if benchmark_load:
+        begin_ts = time.monotonic()
+        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
+            model.load_state_dict(torch.load(save_file))
+        dist.barrier()
+        end_ts = time.monotonic()
+        rank_0_print(
+            f"Completed loading with torch.save.\n"
+            f"Took {end_ts - begin_ts:.2f} seconds."
+        )
+
 
 @record
 def main(benchmark_type: BenchmarkType, work_dir: str, benchmark_load: bool) -> None:
@@ -155,6 +199,8 @@ def main(benchmark_type: BenchmarkType, work_dir: str, benchmark_load: bool) ->
         benchmark_torchsnapshot(model, save_dir, benchmark_load)
     elif benchmark_type == BenchmarkType.TORCH_SAVE:
         benchmark_torchsave(model, save_dir, benchmark_load)
+    elif benchmark_type == BenchmarkType.DCP:
+        benchmark_dcp(model, save_dir, benchmark_load)
     else:
         raise ValueError(f"Unrecognized benchmark type: {benchmark_type}")
 
@@ -165,9 +211,9 @@ def main(benchmark_type: BenchmarkType, work_dir: str, benchmark_load: bool) ->
         "--benchmark-type",
         type=BenchmarkType,
         choices=list(BenchmarkType),
-        default=BenchmarkType.TORCHSNAPSHOT,
+        default=BenchmarkType.DCP,
     )
-    parser.add_argument("--work-dir", default="/tmp")
+    parser.add_argument("--work-dir", default=os.path.expanduser("~/tmp"))
     parser.add_argument("--benchmark-load", action="store_true", default=False)
 
     args: argparse.Namespace = parser.parse_args()
diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py
index 0bf73e7..b6ef859 100644
--- a/torchsnapshot/snapshot.py
+++ b/torchsnapshot/snapshot.py
@@ -248,6 +248,7 @@ def async_take(
             ``.done()`` method for querying the progress and a ``.wait()``
             method for waiting for the snapshot's completion.
         """
+        print("USING ASYNC")
         torch._C._log_api_usage_once("torchsnapshot.Snapshot.async_take")
         cls._validate_app_state(app_state)
 