Skip to content

Commit 78a239d

Browse files
authored
Resync branch (#121)
1 parent c54c7e5 commit 78a239d

File tree

3 files changed

+29
-18
lines changed

3 files changed

+29
-18
lines changed

fms_fsdp/utils/checkpointing_utils.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,45 @@
2020
from torch.distributed.fsdp import StateDictType
2121

2222

23-
def get_latest(targdir, qualifier=lambda x: True):
24-
"""Fetch the latest file or folder written to target directory, subject to name passing the qualifier fn.
25-
If directory is empty or nonexistent or no items qualify, return None."""
23+
def get_latest(targdir, qualifier=lambda x: True, key=os.path.getctime):
24+
"""
25+
Fetch the full path of the latest file or folder written to target directory,
26+
subject to name passing the qualifier fn.
27+
Optional key fn can be used for custom sorting.
28+
Both functions take full path arguments.
29+
If directory is empty or nonexistent or no items qualify, return None.
30+
"""
2631
if os.path.exists(targdir) and len(os.listdir(targdir)) > 0:
2732
latest = max(
2833
[
2934
os.path.join(targdir, x)
3035
for x in os.listdir(targdir)
3136
if qualifier(os.path.join(targdir, x))
3237
],
33-
key=lambda path: int(path.split("/")[-1].split("_")[1]),
38+
key=key,
3439
)
35-
return os.path.join(targdir, latest)
40+
return latest
3641
return None
3742

3843

39-
def get_oldest(targdir, qualifier=lambda x: True):
40-
"""Fetch the oldest file or folder written to target directory, subject to name passing the qualifier fn.
41-
If directory is empty or nonexistent or no items qualify, return None."""
44+
def get_oldest(targdir, qualifier=lambda x: True, key=os.path.getctime):
45+
"""
46+
Fetch the full path of the oldest file or folder written to target directory,
47+
subject to name passing the qualifier fn.
48+
Optional key fn can be used for custom sorting.
49+
Both functions take full path arguments.
50+
If directory is empty or nonexistent or no items qualify, return None.
51+
"""
4252
if os.path.exists(targdir) and len(os.listdir(targdir)) > 0:
4353
oldest = min(
4454
[
4555
os.path.join(targdir, x)
4656
for x in os.listdir(targdir)
4757
if qualifier(os.path.join(targdir, x))
4858
],
49-
key=os.path.getctime,
59+
key=key,
5060
)
51-
return os.path.join(targdir, oldest)
61+
return oldest
5262
return None
5363

5464

@@ -118,7 +128,7 @@ def _cleanup(self):
118128
ckp_to_remove = Path(
119129
get_oldest(self.ckp_path, qualifier=lambda x: "tmp" in x)
120130
)
121-
if os.path.is_file(ckp_to_remove):
131+
if os.path.isfile(ckp_to_remove):
122132
ckp_to_remove.unlink()
123133
else:
124134
shutil.rmtree(ckp_to_remove)

fms_fsdp/utils/dataset_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
rescaling (i.e. counters, RNG states), and `reshard_params`, which are lists that can be
3333
re-distributed over workers (i.e. buffers).
3434
35-
Our loaders obey the following type heirarchy:
35+
Our loaders obey the following type hierarchy:
3636
torch.data.IterableDataset -> _StatefulDataset -> _WrapperDataset.
3737
`_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds a
3838
single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times,
@@ -510,8 +510,8 @@ def _validate_ckp_path(self, path: str, verbose: bool = False):
510510
f" Dataset: No valid checkpoint detected at {path}, dataset starting from scratch."
511511
)
512512
return ""
513-
# Check latest path
514-
latest = os.path.join(path, get_latest(path))
513+
# Check latest path, using ckp naming syntax
514+
latest = get_latest(path, key=lambda path: int(path.split("_")[-2]))
515515
if verbose:
516516
self.report(f"Checkpoint detected at {latest}")
517517
# If item is not a folder, exit early

speculator/train_speculator_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import re
33
import time
4-
from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union
4+
from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
55

66
import torch
77
import torch.distributed as dist
@@ -437,11 +437,12 @@ class EmbedGPTBigCode(GPTBigCode):
437437
# Overrides the forward function of GPTBigCode to allow returning embedding vectors
438438
def forward(
439439
self,
440-
x: torch.LongTensor,
440+
x: torch.Tensor,
441441
mask: Optional[torch.Tensor] = None,
442-
position_ids: Optional[torch.LongTensor] = None,
443-
past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None,
442+
position_ids: Optional[torch.Tensor] = None,
443+
past_key_value_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
444444
use_cache: bool = False,
445+
only_last_token: bool = False,
445446
attn_algorithm: Optional[str] = None,
446447
include_embeds: bool = False,
447448
):

0 commit comments

Comments (0)