Commit 51a90bd
MSC Checkpointing Changes (#789)
* Working changes to be cleaned up.
* Rename msc_config.yaml
* Fixed pytorch test issue by removing MSC Cache
* Updated project dependencies
* Find MSC config using absolute path.
* Re-added cuda test parameter.
* Add test to read from public S3 bucket using MSC.
* Revert save_checkpoint_freq value.
* Remove temporary printing
* Remove unnecessary dependency
* Switched to use consistent mechanism for detecting MSC URIs
* Moved fsspec.filesystem logic into filesystem.py
* Change to cache for non-file protocols when reading non-modulus models.
* Moved code to generate checkpoint directory.
* Added get_checkpoint_dir import
* Address review feedback.
* Changes from code review.
* Addressed file test issue from review.
* Fix to file existence check.
* Fix merge conflicts due to project name change.
* Updated CHANGELOG.
* Added Multi-Storage Client to allow checkpointing to/from Object Storage
* Addressed issues identified by pre-commit.
* Update filesystem.py
* Update __init__.py
* Update Dockerfile

---------

Signed-off-by: Chris Hawes <[email protected]>
Co-authored-by: Nicholas Geneva <[email protected]>
1 parent e0c7389 commit 51a90bd
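In short, this change lets the existing checkpoint helpers accept Multi-Storage Client (MSC) URIs in addition to Posix paths. A minimal sketch of the intended usage, assuming an MSC profile named `checkpoint-test` is configured (the keyword arguments follow the existing `save_checkpoint`/`load_checkpoint` signatures in `checkpoint.py`; the URIs are illustrative):

```python
import torch

from physicsnemo.launch.utils import load_checkpoint, save_checkpoint

model = torch.nn.Linear(4, 4)  # stand-in for any torch.nn.Module

# Posix path: behaves exactly as before.
save_checkpoint("./checkpoints", models=model, epoch=1)

# MSC URI: the "checkpoint-test" profile is defined in an MSC config file
# (see test/utils/msc_config_checkpoint.yaml below) and maps to a bucket.
save_checkpoint("msc://checkpoint-test/checkpoints", models=model, epoch=1)
epoch = load_checkpoint("msc://checkpoint-test/checkpoints", models=model)
```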

File tree

10 files changed: +254 −40 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Added ReGen score-based data assimilation example
+- Added Multi-Storage Client to allow checkpointing to/from Object Storage
 
 ### Changed

examples/generative/corrdiff/train.py

Lines changed: 7 additions & 3 deletions
@@ -25,7 +25,11 @@
 from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
 from physicsnemo.metrics.diffusion import RegressionLoss, ResLoss, RegressionLossCE
 from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
-from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
+from physicsnemo.launch.utils import (
+    load_checkpoint,
+    save_checkpoint,
+    get_checkpoint_dir,
+)
 from datasets.dataset import init_train_valid_datasets_from_config
 from helpers.train_helpers import (
     set_patch_shape,
@@ -66,8 +70,8 @@ def main(cfg: DictConfig) -> None:
     enable_amp = fp_optimizations.startswith("amp")
     amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16
     logger.info(f"Saving the outputs in {os.getcwd()}")
-    checkpoint_dir = os.path.join(
-        cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}"
+    checkpoint_dir = get_checkpoint_dir(
+        str(cfg.training.io.get("checkpoint_dir", ".")), cfg.model.name
     )
     if cfg.training.hp.batch_size_per_gpu == "auto":
         cfg.training.hp.batch_size_per_gpu = (
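The practical effect for corrdiff is that `checkpoint_dir` is no longer forced through `os.path.join`, which assumes Posix semantics. A small sketch of the helper's behavior (return values inferred from its implementation in checkpoint.py below; the `msc://` values are illustrative):

```python
from physicsnemo.launch.utils import get_checkpoint_dir

# Posix base directory: joined with os.path.join, as before.
get_checkpoint_dir("./outputs", "diffusion")
# -> "./outputs/checkpoints_diffusion"

# MSC URI: concatenated with "/" so the scheme and profile are preserved.
get_checkpoint_dir("msc://profile/run1", "diffusion")
# -> "msc://profile/run1/checkpoints_diffusion"
```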

physicsnemo/launch/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,4 +14,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .checkpoint import load_checkpoint, save_checkpoint
+from .checkpoint import get_checkpoint_dir, load_checkpoint, save_checkpoint

physicsnemo/launch/utils/checkpoint.py

Lines changed: 80 additions & 23 deletions
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import glob
+import os
 import re
 from pathlib import Path
 from typing import Any, Dict, List, NewType, Optional, Union
 
+import fsspec
+import fsspec.utils
 import torch
 from torch.cuda.amp import GradScaler
 from torch.optim.lr_scheduler import _LRScheduler
@@ -27,6 +29,7 @@
 from physicsnemo.distributed import DistributedManager
 from physicsnemo.launch.logging import PythonLogger
 from physicsnemo.utils.capture import _StaticCapture
+from physicsnemo.utils.filesystem import LOCAL_CACHE, _download_cached
 
 optimizer = NewType("optimizer", torch.optim)
 scheduler = NewType("scheduler", _LRScheduler)
@@ -86,10 +89,13 @@ def _get_checkpoint_filename(
         else 0
     )
 
-    # Input file name
-    checkpoint_filename = str(
-        Path(path).resolve() / f"{base_name}.{model_parallel_rank}"
-    )
+    # Determine input file name. Get absolute file path if Posix path.
+    # pathlib does not support custom schemes (eg: msc://...) so only perform resolve() for Posix.
+    protocol = fsspec.utils.get_protocol(path)
+    fs = fsspec.filesystem(protocol)
+    if protocol == "file":
+        path = str(Path(path).resolve())
+    checkpoint_filename = f"{path}/{base_name}.{model_parallel_rank}"
 
     # File extension for PhysicsNeMo models or PyTorch models
     file_extension = ".mdlus" if model_type == "mdlus" else ".pt"
@@ -101,20 +107,21 @@ def _get_checkpoint_filename(
     # Otherwise try loading the latest epoch or rolling checkpoint
     else:
         file_names = [
-            Path(fname).name
-            for fname in glob.glob(
-                checkpoint_filename + "*" + file_extension, recursive=False
-            )
+            fname for fname in fs.glob(checkpoint_filename + "*" + file_extension)
         ]
 
         if len(file_names) > 0:
             # If checkpoint from a null index save exists load that
             # This is the most likely line to error since it will fail with
             # invalid checkpoint names
+
+            # Remove protocol prefix if present to allow generic matching
+            _, path_without_protocol = fsspec.core.split_protocol(path)
             file_idx = [
                 int(
                     re.sub(
-                        f"^{base_name}.{model_parallel_rank}.|" + file_extension,
+                        f"^{path_without_protocol}/{base_name}.{model_parallel_rank}.|"
+                        + file_extension,
                         "",
                         fname,
                     )
@@ -212,8 +219,11 @@ def save_checkpoint(
     metadata : Optional[Dict[str, Any]], optional
         Additional metadata to save, by default None
     """
-    # Create checkpoint directory if it does not exist
-    if not Path(path).is_dir():
+    protocol = fsspec.utils.get_protocol(path)
+    fs = fsspec.filesystem(protocol)
+    # Create checkpoint directory if it does not exist.
+    # Only applicable to Posix filesystems ("file" protocol), not object stores.
+    if protocol == "file" and not Path(path).is_dir():
         checkpoint_logging.warning(
             f"Output directory {path} does not exist, will " "attempt to create"
         )
@@ -239,7 +249,8 @@
     if isinstance(model, physicsnemo.models.Module):
         model.save(file_name)
     else:
-        torch.save(model.state_dict(), file_name)
+        with fs.open(file_name, "wb") as fp:
+            torch.save(model.state_dict(), fp)
     checkpoint_logging.success(f"Saved model state dictionary: {file_name}")
 
     # == Saving training checkpoint ==
@@ -270,10 +281,11 @@
 
     # Save checkpoint to memory
     if bool(checkpoint_dict):
-        torch.save(
-            checkpoint_dict,
-            output_filename,
-        )
+        with fs.open(output_filename, "wb") as fp:
+            torch.save(
+                checkpoint_dict,
+                fp,
+            )
         checkpoint_logging.success(f"Saved training checkpoint: {output_filename}")
@@ -318,8 +330,14 @@
     int
         Loaded epoch
     """
+    fs = fsspec.filesystem(fsspec.utils.get_protocol(path))
     # Check if checkpoint directory exists
-    if not Path(path).is_dir():
+    if fs.exists(path):
+        if fs.isfile(path):
+            raise FileNotFoundError(
+                f"Provided checkpoint directory {path} is a file, not directory"
+            )
+    else:
         checkpoint_logging.warning(
             f"Provided checkpoint directory {path} does not exist, skipping load"
         )
@@ -340,7 +358,7 @@
     file_name = _get_checkpoint_filename(
         path, name, index=epoch, model_type=model_type
     )
-    if not Path(file_name).exists():
+    if not fs.exists(file_name):
         checkpoint_logging.error(
             f"Could not find valid model file {file_name}, skipping load"
         )
@@ -349,21 +367,22 @@
     if isinstance(model, physicsnemo.models.Module):
         model.load(file_name)
     else:
-        model.load_state_dict(torch.load(file_name, map_location=device))
-
+        file_to_load = _cache_if_needed(file_name)
+        model.load_state_dict(torch.load(file_to_load, map_location=device))
     checkpoint_logging.success(
         f"Loaded model state dictionary {file_name} to device {device}"
     )
 
     # == Loading training checkpoint ==
     checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt")
-    if not Path(checkpoint_filename).is_file():
+    if not fs.exists(checkpoint_filename):
         checkpoint_logging.warning(
             "Could not find valid checkpoint file, skipping load"
         )
         return 0
 
-    checkpoint_dict = torch.load(checkpoint_filename, map_location=device)
+    file_to_load = _cache_if_needed(checkpoint_filename)
+    checkpoint_dict = torch.load(file_to_load, map_location=device)
     checkpoint_logging.success(
         f"Loaded checkpoint file {checkpoint_filename} to device {device}"
     )
@@ -397,3 +416,41 @@
             metadata_dict[key] = value
 
     return epoch
+
+
+def get_checkpoint_dir(base_dir: str, model_name: str) -> str:
+    """Get a checkpoint directory based on a given base directory and model name
+
+    Parameters
+    ----------
+    base_dir : str
+        Path to the base directory where checkpoints are stored
+    model_name : str
+        Name of the model which is generating the checkpoint
+
+    Returns
+    -------
+    str
+        Checkpoint directory
+    """
+    top_level_dir = f"checkpoints_{model_name}"
+    protocol = fsspec.utils.get_protocol(base_dir)
+    if protocol == "msc":
+        if not base_dir.endswith("/"):
+            base_dir += "/"
+        return base_dir + top_level_dir
+    else:
+        return os.path.join(base_dir, top_level_dir)
+
+
+# Read via cache and return the cached path for non-file protocols, otherwise just return the path
+def _cache_if_needed(path: str) -> str:
+    protocol = fsspec.utils.get_protocol(path)
+    if protocol == "file":
+        return path
+    else:
+        return _download_cached(
+            path,
+            recursive=False,
+            local_cache_path=os.path.join(LOCAL_CACHE, f"checkpoint_pid_{os.getpid()}"),
+        )
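One detail worth spelling out: `torch.load` wants a local file, so remote checkpoints are staged through `_download_cached` first. A sketch of the resulting behavior, with illustrative paths; the exact cache root depends on `LOCAL_CACHE` in `physicsnemo/utils/filesystem.py`:

```python
from physicsnemo.launch.utils.checkpoint import _cache_if_needed

# "file" protocol: the path is returned untouched, no copy is made.
assert _cache_if_needed("/tmp/ckpt/model.0.0.pt") == "/tmp/ckpt/model.0.0.pt"

# Any other protocol: the object is downloaded into a per-process cache
# directory (checkpoint_pid_<pid> under LOCAL_CACHE) and that local path
# is returned for torch.load to read.
local_copy = _cache_if_needed("msc://checkpoint-test/checkpoints/model.0.0.pt")
print(local_copy)  # e.g. <LOCAL_CACHE>/checkpoint_pid_12345/model.0.0.pt
```

Keying the cache directory on `os.getpid()` presumably sidesteps the race condition flagged in `_download_cached`'s TODO when several processes stage the same checkpoint at once.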

physicsnemo/utils/filesystem.py

Lines changed: 4 additions & 3 deletions
@@ -24,6 +24,7 @@
 
 import fsspec
 import fsspec.implementations.cached
+import fsspec.utils
 import requests
 import s3fs
 from tqdm import tqdm
@@ -46,7 +47,7 @@ def _get_fs(path):
     if path.startswith("s3://"):
         return s3fs.S3FileSystem(client_kwargs=dict(endpoint_url="https://pbss.s8k.io"))
     else:
-        return fsspec.filesystem("file")
+        return fsspec.filesystem(fsspec.utils.get_protocol(path))
 
 
 def _download_ngc_model_file(path: str, out_path: str, timeout: int = 300) -> str:
@@ -175,8 +176,8 @@ def _download_cached(
     # TODO watch for race condition here
     if not os.path.exists(cache_path):
         logger.debug("Downloading %s to cache: %s", path, cache_path)
-        if path.startswith("s3://"):
-            fs = _get_fs(path)
+        if url.scheme in ("s3", "msc"):
+            fs = fsspec.filesystem(fsspec.utils.get_protocol(path))
             fs.get(path, cache_path, recursive=recursive)
         elif path.startswith("ngc://models/"):
             path = _download_ngc_model_file(path, cache_path)
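The `fsspec.utils.get_protocol` calls are what make the URI handling uniform: the function falls back to `"file"` for plain paths, so one code path covers local files, `s3://`, and `msc://` alike. For instance:

```python
import fsspec.utils

print(fsspec.utils.get_protocol("/tmp/checkpoints"))           # "file"
print(fsspec.utils.get_protocol("s3://bucket/key"))            # "s3"
print(fsspec.utils.get_protocol("msc://profile/bucket/path"))  # "msc"
```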

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -57,6 +57,7 @@ dev = [
     "interrogate==1.5.0",
     "coverage==6.5.0",
     "ruff==0.0.290",
+    "moto[s3]>=5.0.28",
 ]
 
 makani = [
@@ -76,7 +77,7 @@ fignet = [
 ]
 
 storage = [
-    "multi-storage-client>=0.14.0",
+    "multi-storage-client[boto3]>=0.14.0",
 ]
 
 shardtensor = [
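With the extras spelled this way, `pip install "physicsnemo[storage]"` should pull in multi-storage-client together with its boto3-backed S3 provider, which the `type: s3` entries in the test configs below rely on; `moto[s3]` lands in the dev extras so those providers can be exercised against a mocked S3 endpoint in tests.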

test/utils/msc_config_checkpoint.yaml

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This is an example MSC configuration file for testing checkpoint logic.
+profiles:
+  checkpoint-test:
+    storage_provider:
+      type: s3
+      options:
+        region_name: us-east-1
+        base_path: checkpoint-test-bucket
+    credentials_provider:
+      type: S3Credentials
+      options:
+        access_key: "access-key-id"
+        secret_key: "secret-access-key"
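Together with the new `moto[s3]` dev dependency, this profile supports an offline round-trip test. A hedged sketch of the shape such a test could take — the fixture usage, the `MSC_CONFIG` environment variable for config discovery, and the bucket setup are assumptions here, not code from this commit:

```python
import boto3
import torch
from moto import mock_aws  # moto >= 5 replaces the per-service mocks


@mock_aws
def test_msc_checkpoint_roundtrip(monkeypatch):
    # Create the bucket the "checkpoint-test" profile points at.
    boto3.client("s3", region_name="us-east-1").create_bucket(
        Bucket="checkpoint-test-bucket"
    )
    # Assumption: MSC locates its config through this environment variable.
    monkeypatch.setenv("MSC_CONFIG", "test/utils/msc_config_checkpoint.yaml")

    from physicsnemo.launch.utils import load_checkpoint, save_checkpoint

    model = torch.nn.Linear(2, 2)
    save_checkpoint("msc://checkpoint-test/ckpt", models=model, epoch=0)
    assert load_checkpoint("msc://checkpoint-test/ckpt", models=model) == 0
```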
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This is an example MSC configuration file for accessing the CMIP6 archive on AWS:
+# https://registry.opendata.aws/cmip6/
+profiles:
+  cmip6-pds:
+    storage_provider:
+      type: s3
+      options:
+        region_name: us-west-2
+        base_path: cmip6-pds
+        signature_version: UNSIGNED
+cache:
+  location: /tmp/.cache
+  size_mb: 5000
+
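Since multi-storage-client registers an `msc` protocol with fsspec, the profile above can be exercised without credentials (requests are unsigned). A small sketch, assuming this config file is active and MSC's fsspec integration is importable; the object prefix under the bucket is illustrative:

```python
import fsspec

# The "cmip6-pds" profile resolves to the public cmip6-pds bucket; reads
# are cached locally under /tmp/.cache per the config's cache settings.
fs = fsspec.filesystem("msc")
print(fs.ls("msc://cmip6-pds/CMIP6/")[:5])  # first few top-level prefixes
```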
