 # See the License for the specific language governing permissions and
 # limitations under the License.

-import glob
+import os
 import re
 from pathlib import Path
 from typing import Any, Dict, List, NewType, Optional, Union

+import fsspec
+import fsspec.utils
 import torch
 from torch.cuda.amp import GradScaler
 from torch.optim.lr_scheduler import _LRScheduler

 from physicsnemo.distributed import DistributedManager
 from physicsnemo.launch.logging import PythonLogger
 from physicsnemo.utils.capture import _StaticCapture
+from physicsnemo.utils.filesystem import LOCAL_CACHE, _download_cached

 optimizer = NewType("optimizer", torch.optim)
 scheduler = NewType("scheduler", _LRScheduler)
@@ -86,10 +89,13 @@ def _get_checkpoint_filename(
         else 0
     )

-    # Input file name
-    checkpoint_filename = str(
-        Path(path).resolve() / f"{base_name}.{model_parallel_rank}"
-    )
+    # Determine input file name. Get absolute file path if Posix path.
+    # pathlib does not support custom schemes (e.g. msc://...), so only perform resolve() for Posix paths.
+    protocol = fsspec.utils.get_protocol(path)
+    fs = fsspec.filesystem(protocol)
+    if protocol == "file":
+        path = str(Path(path).resolve())
+    checkpoint_filename = f"{path}/{base_name}.{model_parallel_rank}"

     # File extension for PhysicsNeMo models or PyTorch models
     file_extension = ".mdlus" if model_type == "mdlus" else ".pt"
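
For context, a minimal standalone sketch (not part of the diff; the paths, base name, and rank below are made up) of how the protocol-aware filename logic behaves for a plain Posix path versus a URL-style path:

```python
# Illustration only: fsspec.utils.get_protocol returns "file" for plain paths and
# the scheme (e.g. "s3" or "msc") for URL-style paths; resolve() is applied only
# to local paths, since pathlib would mangle a scheme like "s3://".
from pathlib import Path

import fsspec.utils

for path in ["./outputs/ckpts", "s3://my-bucket/ckpts"]:  # hypothetical locations
    protocol = fsspec.utils.get_protocol(path)
    if protocol == "file":
        path = str(Path(path).resolve())
    print(protocol, f"{path}/FourierNO.0")  # base_name and rank are made up here
```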
@@ -101,20 +107,21 @@ def _get_checkpoint_filename(
     # Otherwise try loading the latest epoch or rolling checkpoint
     else:
         file_names = [
-            Path(fname).name
-            for fname in glob.glob(
-                checkpoint_filename + "*" + file_extension, recursive=False
-            )
+            fname for fname in fs.glob(checkpoint_filename + "*" + file_extension)
         ]

         if len(file_names) > 0:
             # If checkpoint from a null index save exists load that
             # This is the most likely line to error since it will fail with
             # invalid checkpoint names
+
+            # Remove protocol prefix if present to allow generic matching
+            _, path_without_protocol = fsspec.core.split_protocol(path)
             file_idx = [
                 int(
                     re.sub(
-                        f"^{base_name}.{model_parallel_rank}.|" + file_extension,
+                        f"^{path_without_protocol}/{base_name}.{model_parallel_rank}.|"
+                        + file_extension,
                         "",
                         fname,
                     )
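
A hedged aside on why the index-extraction pattern now includes the directory prefix: unlike the old `glob.glob` results, which were reduced to bare names with `Path(fname).name`, `fs.glob` returns full paths (with the protocol stripped), so the prefix must be removed before parsing the epoch index. A toy example of the matching, with a made-up bucket and base name:

```python
# The directory prefix and the file extension are both stripped, leaving the epoch index.
import re

fname = "my-bucket/ckpts/FourierNO.0.25.pt"   # what fs.glob might return (no "s3://")
path_without_protocol = "my-bucket/ckpts"     # from fsspec.core.split_protocol(path)
idx = int(re.sub(f"^{path_without_protocol}/FourierNO.0.|" + ".pt", "", fname))
print(idx)  # -> 25
```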
@@ -212,8 +219,11 @@ def save_checkpoint(
     metadata : Optional[Dict[str, Any]], optional
         Additional metadata to save, by default None
     """
-    # Create checkpoint directory if it does not exist
-    if not Path(path).is_dir():
+    protocol = fsspec.utils.get_protocol(path)
+    fs = fsspec.filesystem(protocol)
+    # Create checkpoint directory if it does not exist.
+    # Only applicable to Posix filesystems ("file" protocol), not object stores.
+    if protocol == "file" and not Path(path).is_dir():
         checkpoint_logging.warning(
             f"Output directory {path} does not exist, will " "attempt to create"
         )
@@ -239,7 +249,8 @@ def save_checkpoint(
         if isinstance(model, physicsnemo.models.Module):
             model.save(file_name)
         else:
-            torch.save(model.state_dict(), file_name)
+            with fs.open(file_name, "wb") as fp:
+                torch.save(model.state_dict(), fp)
         checkpoint_logging.success(f"Saved model state dictionary: {file_name}")

     # == Saving training checkpoint ==
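
The change above works because `torch.save` accepts any writable file-like object, so routing the write through `fs.open` lets the same call target local disk or a remote store. A minimal standalone sketch (the output path is hypothetical):

```python
# Write a small state dict through fsspec's file-like handle; with a remote
# protocol (e.g. "s3" or "msc") the same pattern writes to an object store.
import fsspec
import torch

fs = fsspec.filesystem("file")
with fs.open("/tmp/example_state.pt", "wb") as fp:  # hypothetical local target
    torch.save({"weight": torch.zeros(2, 2)}, fp)
```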
@@ -270,10 +281,11 @@ def save_checkpoint(

     # Save checkpoint to memory
     if bool(checkpoint_dict):
-        torch.save(
-            checkpoint_dict,
-            output_filename,
-        )
+        with fs.open(output_filename, "wb") as fp:
+            torch.save(
+                checkpoint_dict,
+                fp,
+            )
         checkpoint_logging.success(f"Saved training checkpoint: {output_filename}")

@@ -318,8 +330,14 @@ def load_checkpoint(
     int
         Loaded epoch
     """
+    fs = fsspec.filesystem(fsspec.utils.get_protocol(path))
     # Check if checkpoint directory exists
-    if not Path(path).is_dir():
+    if fs.exists(path):
+        if fs.isfile(path):
+            raise FileNotFoundError(
+                f"Provided checkpoint directory {path} is a file, not directory"
+            )
+    else:
         checkpoint_logging.warning(
             f"Provided checkpoint directory {path} does not exist, skipping load"
         )
@@ -340,7 +358,7 @@ def load_checkpoint(
         file_name = _get_checkpoint_filename(
             path, name, index=epoch, model_type=model_type
         )
-        if not Path(file_name).exists():
+        if not fs.exists(file_name):
             checkpoint_logging.error(
                 f"Could not find valid model file {file_name}, skipping load"
             )
@@ -349,21 +367,22 @@ def load_checkpoint(
         if isinstance(model, physicsnemo.models.Module):
             model.load(file_name)
         else:
-            model.load_state_dict(torch.load(file_name, map_location=device))
-
+            file_to_load = _cache_if_needed(file_name)
+            model.load_state_dict(torch.load(file_to_load, map_location=device))
         checkpoint_logging.success(
             f"Loaded model state dictionary {file_name} to device {device}"
         )

     # == Loading training checkpoint ==
     checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt")
-    if not Path(checkpoint_filename).is_file():
+    if not fs.exists(checkpoint_filename):
         checkpoint_logging.warning(
             "Could not find valid checkpoint file, skipping load"
         )
         return 0

-    checkpoint_dict = torch.load(checkpoint_filename, map_location=device)
+    file_to_load = _cache_if_needed(checkpoint_filename)
+    checkpoint_dict = torch.load(file_to_load, map_location=device)
     checkpoint_logging.success(
         f"Loaded checkpoint file {checkpoint_filename} to device {device}"
     )
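
On the load path, remote checkpoints are first materialized to a local cache via `_cache_if_needed` (defined at the end of this diff) and then read with `torch.load`. A hedged, standalone sketch of the same download-then-load idea, with a made-up URL and cache directory (valid credentials would be needed for the URL to actually resolve):

```python
# Illustration only: fetch a remote checkpoint to local disk, then load it.
import os

import fsspec
import fsspec.utils
import torch

url = "s3://my-bucket/ckpts/checkpoint.0.0.pt"                 # hypothetical remote object
local_copy = os.path.join("/tmp/ckpt_cache", os.path.basename(url))
os.makedirs(os.path.dirname(local_copy), exist_ok=True)
fsspec.filesystem(fsspec.utils.get_protocol(url)).get(url, local_copy)
checkpoint_dict = torch.load(local_copy, map_location="cpu")
```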
@@ -397,3 +416,41 @@ def load_checkpoint(
             metadata_dict[key] = value

     return epoch
+
+
+def get_checkpoint_dir(base_dir: str, model_name: str) -> str:
+    """Get a checkpoint directory based on a given base directory and model name
+
+    Parameters
+    ----------
+    base_dir : str
+        Path to the base directory where checkpoints are stored
+    model_name : str
+        Name of the model which is generating the checkpoint
+
+    Returns
+    -------
+    str
+        Checkpoint directory
+    """
+    top_level_dir = f"checkpoints_{model_name}"
+    protocol = fsspec.utils.get_protocol(base_dir)
+    if protocol == "msc":
+        if not base_dir.endswith("/"):
+            base_dir += "/"
+        return base_dir + top_level_dir
+    else:
+        return os.path.join(base_dir, top_level_dir)
+
+
+# Read via cache and return the cached path for non-file protocols; otherwise just return the path
+def _cache_if_needed(path: str) -> str:
+    protocol = fsspec.utils.get_protocol(path)
+    if protocol == "file":
+        return path
+    else:
+        return _download_cached(
+            path,
+            recursive=False,
+            local_cache_path=os.path.join(LOCAL_CACHE, f"checkpoint_pid_{os.getpid()}"),
+        )