Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions benchmarks/ddp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ def rank_0_print(msg: str) -> None:
if dist.get_rank() == 0:
print(msg)


import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.state_dict import (
_patch_model_state_dict,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--work-dir", default="/tmp")
parser.add_argument("--work-dir", default="tmp")
parser.add_argument("--param-size", type=int, default=int(100_000_000))
parser.add_argument("--num-params", type=int, default=200)
args = parser.parse_args()
Expand All @@ -50,21 +53,37 @@ def rank_0_print(msg: str) -> None:
sz = sum(t.nelement() * t.element_size() for t in model.parameters())
rank_0_print(f"Model size: {sz / 1_000_000_000.0} GB")

if dist.get_rank() == 0:
print("Saving the model with torch.save...")
t_begin = time.time()
with open(f"{args.work_dir}/{uuid.uuid4()}.pt", "wb+") as f:
torch.save(model.state_dict(), f)
print(f"Took {time.time() - t_begin} seconds with torch.save")
dist.barrier()
# rank_0_print("Saving the model with torchsnapshot...")
# t_begin = time.monotonic()
# app_state = {"model": model}
# snapshot = torchsnapshot.Snapshot.take(
# path=f"{args.work_dir}/{uuid.uuid4()}",
# app_state=app_state,
# replicated=["**"],
# )
# os.sync()
# rank_0_print(f"Snapshot path: {snapshot.path}")
# rank_0_print(f"Took {time.monotonic() - t_begin} seconds with torchsnapshot")
# dist.barrier()

_patch_model_state_dict(model)
rank_0_print("Saving the model with DCP...")
for num_threads in range(32, 33):
checkpointer = DCP.FileSystemCheckpointer(
f"{args.work_dir}/{uuid.uuid4()}",
thread_count=num_threads
)

begin_ts = time.monotonic()
checkpointer.save(state_dict={"model": model})
end_ts = time.monotonic()
rank_0_print(f"{num_threads}, {time.monotonic() - begin_ts}")

dist.barrier()
if dist.get_rank() == 0:
import shutil
# Delete a directory and all its contents
shutil.rmtree(args.work_dir)
shutil.os.makedirs(args.work_dir)

rank_0_print("Saving the model with torchsnapshot...")
t_begin = time.time()
app_state = {"model": model}
snapshot = torchsnapshot.Snapshot.take(
path=f"{args.work_dir}/{uuid.uuid4()}",
app_state=app_state,
replicated=["**"],
)
rank_0_print(f"Snapshot path: {snapshot.path}")
rank_0_print(f"Took {time.time() - t_begin} seconds with torchsnapshot")
dist.barrier()
54 changes: 50 additions & 4 deletions benchmarks/fsdp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
class BenchmarkType(Enum):
TORCHSNAPSHOT = "torchsnapshot"
TORCH_SAVE = "torch_save"
DCP = "dcp"


def __str__(self):
return self.value
Expand Down Expand Up @@ -78,7 +80,7 @@ def benchmark_torchsnapshot(
rank_0_print("Saving a checkpoint with torchsnapshot...")
app_state = {"model": model}
begin_ts = time.monotonic()
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
Snapshot.take(
path=save_dir,
app_state=app_state,
Expand Down Expand Up @@ -112,7 +114,7 @@ def benchmark_torchsave(model: nn.Module, save_dir: str, benchmark_load: bool) -
begin_ts = time.monotonic()
with FSDP.state_dict_type(
model,
StateDictType.LOCAL_STATE_DICT,
StateDictType.SHARDED_STATE_DICT,
):
state_dict = model.state_dict()
torch.save(state_dict, save_file)
Expand All @@ -134,6 +136,48 @@ def benchmark_torchsave(model: nn.Module, save_dir: str, benchmark_load: bool) -
f"Took {end_ts - begin_ts:.2f} seconds."
)

import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.state_dict import (
_patch_model_state_dict,
)
import shutil

def benchmark_dcp(model: nn.Module, save_dir: str, benchmark_load: bool) -> None:
rank_0_print("Saving a checkpoint with DCP.save...")

os.makedirs(save_dir, exist_ok=True)
save_file = f"{save_dir}/state_dict-{dist.get_rank()}.pt"
_patch_model_state_dict(model)

for num_threads in range(1, 17):
checkpointer = DCP.FileSystemCheckpointer(save_file, thread_count=num_threads)

begin_ts = time.monotonic()
checkpointer.save(state_dict={"model": model})
end_ts = time.monotonic()
rank_0_print(end_ts - begin_ts)
dist.barrier()

if dist.get_rank() == 0:
import shutil
# Delete a directory and all its contents
shutil.rmtree(args.work_dir)
os.makedirs(args.work_dir)

dist.barrier()
print()

if benchmark_load:
begin_ts = time.monotonic()
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
model.load_state_dict(torch.load(save_file))
dist.barrier()
end_ts = time.monotonic()
rank_0_print(
f"Completed loading with torch.save.\n"
f"Took {end_ts - begin_ts:.2f} seconds."
)


@record
def main(benchmark_type: BenchmarkType, work_dir: str, benchmark_load: bool) -> None:
Expand All @@ -155,6 +199,8 @@ def main(benchmark_type: BenchmarkType, work_dir: str, benchmark_load: bool) ->
benchmark_torchsnapshot(model, save_dir, benchmark_load)
elif benchmark_type == BenchmarkType.TORCH_SAVE:
benchmark_torchsave(model, save_dir, benchmark_load)
elif benchmark_type == BenchmarkType.DCP:
benchmark_dcp(model, save_dir, benchmark_load)
else:
raise ValueError(f"Unrecognized benchmark type: {benchmark_type}")

Expand All @@ -165,9 +211,9 @@ def main(benchmark_type: BenchmarkType, work_dir: str, benchmark_load: bool) ->
"--benchmark-type",
type=BenchmarkType,
choices=list(BenchmarkType),
default=BenchmarkType.TORCHSNAPSHOT,
default=BenchmarkType.DCP,
)
parser.add_argument("--work-dir", default="/tmp")
parser.add_argument("--work-dir", default="~/tmp")
parser.add_argument("--benchmark-load", action="store_true", default=False)

args: argparse.Namespace = parser.parse_args()
Expand Down
1 change: 1 addition & 0 deletions torchsnapshot/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def async_take(
``.done()`` method for querying the progress and a ``.wait()``
method for waiting for the snapshot's completion.
"""
print("USING ASYNC")
torch._C._log_api_usage_once("torchsnapshot.Snapshot.async_take")
cls._validate_app_state(app_state)

Expand Down