MSC Checkpointing Changes #789

Merged: 47 commits, Apr 2, 2025

Commits (47)
c2e5db8
Initial commit
chris-hawes Feb 11, 2025
61742fc
Working changes to be cleaned up.
chris-hawes Feb 13, 2025
9d74552
Rename msc_config.yaml
chris-hawes Feb 13, 2025
12c8c88
Fixed pytorch test issue by removing MSC Cache
chris-hawes Feb 14, 2025
4e90e42
Clean up
chris-hawes Feb 14, 2025
08b1557
Clean up
chris-hawes Feb 14, 2025
546285c
Clean up
chris-hawes Feb 14, 2025
cbbdf11
Merge branch 'main' into chawes/initial-msc-checkpointing
chris-hawes Feb 14, 2025
5226264
Updated project dependencies
chris-hawes Feb 14, 2025
6be3b18
Find MSC config using absolute path.
chris-hawes Feb 18, 2025
d681bec
Re-added cuda test parameter.
chris-hawes Feb 18, 2025
ef07910
Moved MSC Config file
chris-hawes Feb 18, 2025
35c75f5
Rename MSC config file
chris-hawes Feb 18, 2025
1c68712
Add test to read from public S3 bucket using MSC.
chris-hawes Feb 18, 2025
3906bb8
Added MSC comment
chris-hawes Feb 18, 2025
8b1f17f
Clean up
chris-hawes Feb 18, 2025
a9f0f29
Revert save_checkpoint_freq value.
chris-hawes Feb 18, 2025
82b1885
Remove temporary printing
chris-hawes Feb 18, 2025
1a5dd4f
Remove unnecessary dependency
chris-hawes Feb 24, 2025
0f70aac
Switched to use consistent mechanism for detecting msc URIs
chris-hawes Feb 24, 2025
e70fc58
Changes from code review.
chris-hawes Feb 25, 2025
68f88b3
Fix missing variable.
chris-hawes Feb 25, 2025
3624a88
Fix missing variable.
chris-hawes Feb 25, 2025
9b48137
Moved fsspec.filesystem logic into filesystem.py
chris-hawes Feb 25, 2025
ba294c9
Change to cache for non-file protocols when reading non-modulus models.
chris-hawes Feb 27, 2025
e0c0881
Moved code to generate checkpoint directory.
chris-hawes Mar 11, 2025
afedf58
Added get_checkpoint_dir import
chris-hawes Mar 11, 2025
8a57a9f
Address review feedback.
chris-hawes Mar 11, 2025
d238968
Changes from code review.
chris-hawes Mar 11, 2025
4643d83
Add comment per code review.
chris-hawes Mar 11, 2025
38b73d7
Addressed file test issue from review.
chris-hawes Mar 19, 2025
b0d1db8
Fix to file existence check.
chris-hawes Mar 19, 2025
77025fb
Merge branch 'main' into chawes/initial-msc-checkpointing
chris-hawes Mar 19, 2025
450ea6e
Fix merge conflicts due to project name change.
chris-hawes Mar 20, 2025
19b6f78
Merge branch 'main' into chawes/initial-msc-checkpointing
chris-hawes Mar 20, 2025
f61124c
Merge branch 'main' into chawes/initial-msc-checkpointing
chris-hawes Mar 21, 2025
e4a9a90
Updated CHANGELOG.
chris-hawes Mar 25, 2025
48b76be
Added Multi-Storage Client to allow checkpointing to/from Object Storage
chris-hawes Mar 25, 2025
089b6f6
Addressed issues identified by pre-commit.
chris-hawes Mar 31, 2025
9f3f59b
Merge branch 'main' into chawes/initial-msc-checkpointing
chris-hawes Mar 31, 2025
d36cc64
Merge branch 'main' into chawes/initial-msc-checkpointing
NickGeneva Mar 31, 2025
e746142
Update filesystem.py
NickGeneva Mar 31, 2025
13f7fb0
Update __init__.py
NickGeneva Mar 31, 2025
782cb46
Update Dockerfile
NickGeneva Mar 31, 2025
c3eed5e
Merge branch 'main' into chawes/initial-msc-checkpointing
NickGeneva Apr 1, 2025
8c74f19
Update Dockerfile
NickGeneva Apr 1, 2025
82e503b
Merge branch 'main' into chawes/initial-msc-checkpointing
NickGeneva Apr 2, 2025
6 changes: 2 additions & 4 deletions examples/generative/corrdiff/train.py
@@ -25,7 +25,7 @@
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
from modulus.metrics.diffusion import RegressionLoss, ResLoss, RegressionLossCE
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
from modulus.launch.utils import load_checkpoint, save_checkpoint
from modulus.launch.utils import load_checkpoint, save_checkpoint, get_checkpoint_dir
from datasets.dataset import init_train_valid_datasets_from_config
from helpers.train_helpers import (
set_patch_shape,
@@ -66,9 +66,7 @@ def main(cfg: DictConfig) -> None:
enable_amp = fp_optimizations.startswith("amp")
amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16
logger.info(f"Saving the outputs in {os.getcwd()}")
checkpoint_dir = os.path.join(
cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}"
)
checkpoint_dir = get_checkpoint_dir(str(cfg.training.io.get("checkpoint_dir", ".")), cfg.model.name)
if cfg.training.hp.batch_size_per_gpu == "auto":
cfg.training.hp.batch_size_per_gpu = (
cfg.training.hp.total_batch_size // dist.world_size
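For reference, a minimal sketch of what the new `get_checkpoint_dir` helper returns for a local path versus an MSC URI. The `corrdiff` model name and the `msc://checkpoint-test` profile are illustrative assumptions, not values defined by this diff.

```python
# Illustrative sketch only: "corrdiff" and the msc://checkpoint-test profile are
# hypothetical values, not defined by this PR.
from modulus.launch.utils import get_checkpoint_dir

# Local (POSIX) base directory -> os.path.join behaviour
print(get_checkpoint_dir("./outputs", "corrdiff"))
# ./outputs/checkpoints_corrdiff

# MSC profile URI -> plain string concatenation (os.path is not URI-aware)
print(get_checkpoint_dir("msc://checkpoint-test", "corrdiff"))
# msc://checkpoint-test/checkpoints_corrdiff
```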
2 changes: 1 addition & 1 deletion modulus/launch/utils/__init__.py
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .checkpoint import load_checkpoint, save_checkpoint
from .checkpoint import load_checkpoint, save_checkpoint, get_checkpoint_dir
90 changes: 65 additions & 25 deletions modulus/launch/utils/checkpoint.py
@@ -14,11 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import re
import os, re
from pathlib import Path
from typing import Any, Dict, List, NewType, Optional, Union

import fsspec
import fsspec.utils
import torch
from torch.cuda.amp import GradScaler
from torch.optim.lr_scheduler import _LRScheduler
@@ -27,6 +28,7 @@
from modulus.distributed import DistributedManager
from modulus.launch.logging import PythonLogger
from modulus.utils.capture import _StaticCapture
from modulus.utils.filesystem import _download_cached, LOCAL_CACHE

optimizer = NewType("optimizer", torch.optim)
scheduler = NewType("scheduler", _LRScheduler)
@@ -86,10 +88,13 @@ def _get_checkpoint_filename(
else 0
)

# Input file name
checkpoint_filename = str(
Path(path).resolve() / f"{base_name}.{model_parallel_rank}"
)
# Determine input file name. Get absolute file path if Posix path.
# pathlib does not support custom schemes (eg: msc://...) so only perform resolve() for Posix.
protocol = fsspec.utils.get_protocol(path)
fs = fsspec.filesystem(protocol)
if protocol == "file":
path = str(Path(path).resolve())
checkpoint_filename = f"{path}/{base_name}.{model_parallel_rank}"

# File extension for Modulus models or PyTorch models
file_extension = ".mdlus" if model_type == "mdlus" else ".pt"
@@ -101,20 +106,21 @@
# Otherwise try loading the latest epoch or rolling checkpoint
else:
file_names = [
Path(fname).name
for fname in glob.glob(
checkpoint_filename + "*" + file_extension, recursive=False
)
fname for fname in fs.glob(checkpoint_filename + "*" + file_extension)
]

if len(file_names) > 0:
# If checkpoint from a null index save exists load that
# This is the most likely line to error since it will fail with
# invalid checkpoint names

# Remove protocol prefix if present to allow generic matching
_, path_without_protocol = fsspec.core.split_protocol(path)
file_idx = [
int(
re.sub(
f"^{base_name}.{model_parallel_rank}.|" + file_extension,
f"^{path_without_protocol}/{base_name}.{model_parallel_rank}.|"
+ file_extension,
"",
fname,
)
@@ -212,8 +218,11 @@ def save_checkpoint(
metadata : Optional[Dict[str, Any]], optional
Additional metadata to save, by default None
"""
# Create checkpoint directory if it does not exist
if not Path(path).is_dir():
protocol = fsspec.utils.get_protocol(path)
fs = fsspec.filesystem(protocol)
# Create checkpoint directory if it does not exist.
# Only applicable to Posix filesystems ("file" protocol), not object stores.
if protocol == "file" and not Path(path).is_dir():
checkpoint_logging.warning(
f"Output directory {path} does not exist, will " "attempt to create"
)
@@ -237,7 +246,8 @@
if isinstance(model, modulus.models.Module):
model.save(file_name)
else:
torch.save(model.state_dict(), file_name)
with fs.open(file_name, "wb") as fp:
torch.save(model.state_dict(), fp)
checkpoint_logging.success(f"Saved model state dictionary: {file_name}")

# == Saving training checkpoint ==
@@ -268,10 +278,11 @@

# Save checkpoint to memory
if bool(checkpoint_dict):
torch.save(
checkpoint_dict,
output_filename,
)
with fs.open(output_filename, "wb") as fp:
torch.save(
checkpoint_dict,
fp,
)
checkpoint_logging.success(f"Saved training checkpoint: {output_filename}")


@@ -316,8 +327,15 @@ def load_checkpoint(
int
Loaded epoch
"""
fs = fsspec.filesystem(fsspec.utils.get_protocol(path))
# Check if checkpoint directory exists
if not Path(path).is_dir():
try:
info = fs.info(path)
if info["type"] == "file":
raise FileNotFoundError(
f"Provided checkpoint directory {path} is a file, not directory"
)
except FileNotFoundError:
checkpoint_logging.warning(
f"Provided checkpoint directory {path} does not exist, skipping load"
)
@@ -336,7 +354,7 @@
file_name = _get_checkpoint_filename(
path, name, index=epoch, model_type=model_type
)
if not Path(file_name).exists():
if not fs.exists(file_name):
checkpoint_logging.error(
f"Could not find valid model file {file_name}, skipping load"
)
@@ -345,21 +363,22 @@
if isinstance(model, modulus.models.Module):
model.load(file_name)
else:
model.load_state_dict(torch.load(file_name, map_location=device))

file_to_load = _cache_if_needed(file_name)
model.load_state_dict(torch.load(file_to_load, map_location=device))
checkpoint_logging.success(
f"Loaded model state dictionary {file_name} to device {device}"
)

# == Loading training checkpoint ==
checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt")
if not Path(checkpoint_filename).is_file():
if not fs.exists(checkpoint_filename):
checkpoint_logging.warning(
"Could not find valid checkpoint file, skipping load"
)
return 0

checkpoint_dict = torch.load(checkpoint_filename, map_location=device)

file_to_load = _cache_if_needed(checkpoint_filename)
checkpoint_dict = torch.load(file_to_load, map_location=device)
checkpoint_logging.success(
f"Loaded checkpoint file {checkpoint_filename} to device {device}"
)
@@ -393,3 +412,24 @@ def load_checkpoint(
metadata_dict[key] = value

return epoch

# Get a checkpoint directory based on a given base directory and model name
def get_checkpoint_dir(base_dir: str, model_name: str) -> str:
top_level_dir = f"checkpoints_{model_name}"
protocol = fsspec.utils.get_protocol(base_dir)
if protocol == "msc":
if not base_dir.endswith("/"):
base_dir += "/"
return base_dir + top_level_dir
else:
return os.path.join(
base_dir, top_level_dir
)

# Read via cache and return the cached path for non-file protocols, otherwise just return the path
def _cache_if_needed(path: str) -> str:
protocol = fsspec.utils.get_protocol(path)
if protocol == "file":
return path
else:
return _download_cached(path, recursive=False, local_cache_path=os.path.join(LOCAL_CACHE, f"checkpoint_pid_{os.getpid()}"))
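Taken together, the path helpers and the `fs.open`-based save/load paths let an MSC URI be passed wherever a local directory was used before. A hedged sketch of a round trip, assuming the existing `save_checkpoint`/`load_checkpoint` keyword names (`models`, `epoch`, `device`) are unchanged and that a `checkpoint-test` MSC profile is configured:

```python
# Hedged sketch of an end-to-end MSC round trip; the "checkpoint-test" profile is
# an assumption, and keyword names follow the existing checkpoint API.
import torch
from modulus.launch.utils import get_checkpoint_dir, load_checkpoint, save_checkpoint

model = torch.nn.Linear(4, 4)
ckpt_dir = get_checkpoint_dir("msc://checkpoint-test/experiments", "demo")

# Saves go through fs.open(..., "wb"), so the same call works for local and msc:// paths
save_checkpoint(ckpt_dir, models=model, epoch=0)

# Loads stage non-file protocols into LOCAL_CACHE via _cache_if_needed before torch.load
epoch = load_checkpoint(ckpt_dir, models=model, device="cpu")
```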
7 changes: 4 additions & 3 deletions modulus/utils/filesystem.py
@@ -23,6 +23,7 @@
import zipfile

import fsspec
import fsspec.utils
import fsspec.implementations.cached
import requests
import s3fs
@@ -46,7 +47,7 @@ def _get_fs(path):
if path.startswith("s3://"):
return s3fs.S3FileSystem(client_kwargs=dict(endpoint_url="https://pbss.s8k.io"))
else:
return fsspec.filesystem("file")
return fsspec.filesystem(fsspec.utils.get_protocol(path))


def _download_ngc_model_file(path: str, out_path: str, timeout: int = 300) -> str:
Expand Down Expand Up @@ -175,8 +176,8 @@ def _download_cached(
# TODO watch for race condition here
if not os.path.exists(cache_path):
logger.debug("Downloading %s to cache: %s", path, cache_path)
if path.startswith("s3://"):
fs = _get_fs(path)
if url.scheme in ("s3", "msc"):
fs = fsspec.filesystem(fsspec.utils.get_protocol(path))
fs.get(path, cache_path, recursive=recursive)
elif path.startswith("ngc://models/"):
path = _download_ngc_model_file(path, cache_path)
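With these changes, `filesystem.py` dispatches on the fsspec protocol of the path instead of special-casing `s3://` only. A small sketch of that dispatch, assuming multi-storage-client registers the `msc` protocol with fsspec when installed:

```python
# Minimal sketch of the protocol dispatch used above; resolving "msc" assumes the
# multi-storage-client package has registered its fsspec implementation.
import fsspec
import fsspec.utils

for path in ("/tmp/checkpoints", "s3://bucket/checkpoints", "msc://profile/checkpoints"):
    protocol = fsspec.utils.get_protocol(path)  # "file", "s3", "msc"
    fs = fsspec.filesystem(protocol)            # matching AbstractFileSystem instance
    print(path, "->", protocol, type(fs).__name__)
```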
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -57,6 +57,7 @@ dev = [
"interrogate==1.5.0",
"coverage==6.5.0",
"ruff==0.0.290",
"moto[s3]>=5.0.28",
]

makani = [
@@ -76,7 +77,7 @@ fignet = [
]

storage = [
"multi-storage-client>=0.14.0",
"multi-storage-client[boto3]>=0.14.0",
]

all = [
30 changes: 30 additions & 0 deletions test/utils/msc_config_checkpoint.yaml
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This is an example MSC configuration file for testing checkpoint logic.
profiles:
checkpoint-test:
storage_provider:
type: s3
options:
region_name: us-east-1
base_path: checkpoint-test-bucket
credentials_provider:
type: S3Credentials
options:
access_key: "access-key-id"
secret_key: "secret-access-key"
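A hedged sketch of how a test might exercise this profile. The `MSC_CONFIG` discovery variable and the `msc://checkpoint-test/...` URI shape are assumptions about multi-storage-client rather than anything this diff asserts, and the S3 endpoint would be mocked (for example with the moto dev dependency added above) rather than a real bucket.

```python
# Hedged sketch: pointing MSC at this test config; MSC_CONFIG and the URI shape
# are assumptions about multi-storage-client, and S3 would be mocked in tests.
import os

os.environ["MSC_CONFIG"] = "test/utils/msc_config_checkpoint.yaml"

from modulus.launch.utils import load_checkpoint, save_checkpoint  # noqa: E402

# save_checkpoint("msc://checkpoint-test/ckpts", models=model, epoch=1)
# load_checkpoint("msc://checkpoint-test/ckpts", models=model, device="cpu")
```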
31 changes: 31 additions & 0 deletions test/utils/msc_config_public_read.yaml
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This is an example MSC configuration file for accessing the CMIP6 archive on AWS:
# https://registry.opendata.aws/cmip6/
profiles:
cmip6-pds:
storage_provider:
type: s3
options:
region_name: us-west-2
base_path: cmip6-pds
signature_version: UNSIGNED
cache:
location: /tmp/.cache
size_mb: 5000
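
A sketch of what the public-read test (commit 1c68712) presumably exercises with this profile, assuming multi-storage-client registers an `msc` fsspec protocol and that this config file is discoverable (e.g. via `MSC_CONFIG`):

```python
# Hedged sketch: anonymous (UNSIGNED) listing of the open CMIP6 bucket through the
# "cmip6-pds" profile above; protocol registration and config discovery are assumed.
import fsspec

fs = fsspec.filesystem("msc")
print(fs.ls("msc://cmip6-pds/")[:5])
```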
