From 76e4c057a4fc7158d117917b75b40c37601a2d9f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 10 Oct 2024 14:11:52 -0400 Subject: [PATCH 1/7] ckp path handling fixes and get_ fn updates Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/checkpointing_utils.py | 30 ++++++++++++++++++--------- fms_fsdp/utils/dataset_utils.py | 6 +++--- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index e146ac94..f953aaa2 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -20,9 +20,14 @@ from torch.distributed.fsdp import StateDictType -def get_latest(targdir, qualifier=lambda x: True): - """Fetch the latest file or folder written to target directory, subject to name passing the qualifier fn. - If directory is empty or nonexistent or no items qualify, return None.""" +def get_latest(targdir, qualifier=lambda x: True, key=os.path.getctime): + """ + Fetch the full path of the latest file or folder written to target directory, + subject to name passing the qualifier fn. + Optional key fn can be used for custom sorting. + Both functions take full path arguments. + If directory is empty or nonexistent or no items qualify, return None. + """ if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: latest = max( [ @@ -30,15 +35,20 @@ def get_latest(targdir, qualifier=lambda x: True): for x in os.listdir(targdir) if qualifier(os.path.join(targdir, x)) ], - key=lambda path: int(path.split("/")[-1].split("_")[1]), + key=key, ) - return os.path.join(targdir, latest) + return latest return None -def get_oldest(targdir, qualifier=lambda x: True): - """Fetch the oldest file or folder written to target directory, subject to name passing the qualifier fn. - If directory is empty or nonexistent or no items qualify, return None.""" +def get_oldest(targdir, qualifier=lambda x: True, key=os.path.getctime): + """ + Fetch the full path of the oldest file or folder written to target directory, + subject to name passing the qualifier fn. + Optional key fn can be used for custom sorting. + Both functions take full path arguments. + If directory is empty or nonexistent or no items qualify, return None. + """ if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: oldest = min( [ @@ -46,9 +56,9 @@ def get_oldest(targdir, qualifier=lambda x: True): for x in os.listdir(targdir) if qualifier(os.path.join(targdir, x)) ], - key=os.path.getctime, + key=key, ) - return os.path.join(targdir, oldest) + return oldest return None diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index f8996a28..9cb360d8 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -32,7 +32,7 @@ rescaling (i.e. counters, RNG states), and `reshard_params`, which are lists that can be re-distributed over workers (i.e. buffers). -Our loaders obey the following type heirarchy: +Our loaders obey the following type hierarchy: torch.data.IterableDataset -> _StatefulDataset -> _WrapperDataset. `_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds a single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times, @@ -510,8 +510,8 @@ def _validate_ckp_path(self, path: str, verbose: bool = False): f" Dataset: No valid checkpoint detected at {path}, dataset starting from scratch." ) return "" - # Check latest path - latest = os.path.join(path, get_latest(path)) + # Check latest path, using ckp naming syntax + latest = get_latest(path, key= lambda path: int(path.split("_")[-2])) if verbose: self.report(f"Checkpoint detected at {latest}") # If item is not a folder, exit early From 518ccf48113d532e8eb02639e71fcbfa73d36fd9 Mon Sep 17 00:00:00 2001 From: Johannes Schmude <38668226+johannesschmude@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:53:02 +0100 Subject: [PATCH 2/7] Typo is_file vs isfile (#111) Signed-off-by: Johannes Schmude Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/checkpointing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index f953aaa2..2bbeef18 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -128,7 +128,7 @@ def _cleanup(self): ckp_to_remove = Path( get_oldest(self.ckp_path, qualifier=lambda x: "tmp" in x) ) - if os.path.is_file(ckp_to_remove): + if os.path.isfile(ckp_to_remove): ckp_to_remove.unlink() else: shutil.rmtree(ckp_to_remove) From 3280dc08f5db4563981194d5b033af9b34c4126b Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 10 Oct 2024 14:18:21 -0400 Subject: [PATCH 3/7] Format fixes Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/dataset_utils.py | 2 +- speculator/train_speculator_utils.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 9cb360d8..d1d442d7 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -511,7 +511,7 @@ def _validate_ckp_path(self, path: str, verbose: bool = False): ) return "" # Check latest path, using ckp naming syntax - latest = get_latest(path, key= lambda path: int(path.split("_")[-2])) + latest = get_latest(path, key=lambda path: int(path.split("_")[-2])) if verbose: self.report(f"Checkpoint detected at {latest}") # If item is not a folder, exit early diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py index 87b4e7b2..bb9e4d4f 100644 --- a/speculator/train_speculator_utils.py +++ b/speculator/train_speculator_utils.py @@ -440,7 +440,11 @@ def forward( x: torch.LongTensor, mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, + past_key_value_states: Optional[ + Tuple[ + torch.FloatTensor, + ] + ] = None, use_cache: bool = False, attn_algorithm: Optional[str] = None, include_embeds: bool = False, From 3adfb7d2119635380f57d8836b5d2549ffa4c7a2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 10 Oct 2024 14:21:05 -0400 Subject: [PATCH 4/7] Manual format fixing Signed-off-by: Davis Wertheimer --- speculator/train_speculator_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py index bb9e4d4f..87b4e7b2 100644 --- a/speculator/train_speculator_utils.py +++ b/speculator/train_speculator_utils.py @@ -440,11 +440,7 @@ def forward( x: torch.LongTensor, mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value_states: Optional[ - Tuple[ - torch.FloatTensor, - ] - ] = None, + past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, use_cache: bool = False, attn_algorithm: Optional[str] = None, include_embeds: bool = False, From cf93f60aa68598c936496960af57b637fae32215 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 10 Oct 2024 14:33:42 -0400 Subject: [PATCH 5/7] gptbigcode forward type fixes Signed-off-by: Davis Wertheimer --- speculator/train_speculator_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py index 87b4e7b2..0ce2a299 100644 --- a/speculator/train_speculator_utils.py +++ b/speculator/train_speculator_utils.py @@ -1,7 +1,7 @@ import os import re import time -from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union import torch import torch.distributed as dist @@ -437,11 +437,12 @@ class EmbedGPTBigCode(GPTBigCode): # Overrides the forward function of GPTBigCode to allow returning embedding vectors def forward( self, - x: torch.LongTensor, + x: torch.Tensor, mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value_states: Optional[List[Tuple[torch.Tensor,]]] = None, use_cache: bool = False, + only_last_token: bool = False, attn_algorithm: Optional[str] = None, include_embeds: bool = False, ): From 5ac7d959303a0fa9c9b0e0bc933e3cabc63270e0 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 10 Oct 2024 14:41:41 -0400 Subject: [PATCH 6/7] gptbigcode forward type fixes pt2 Signed-off-by: Davis Wertheimer --- speculator/train_speculator_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py index 0ce2a299..11e8505e 100644 --- a/speculator/train_speculator_utils.py +++ b/speculator/train_speculator_utils.py @@ -440,7 +440,7 @@ def forward( x: torch.Tensor, mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - past_key_value_states: Optional[List[Tuple[torch.Tensor,]]] = None, + past_key_value_states: Optional[List[Tuple[torch.Tensor,torch.Tensor]]] = None, use_cache: bool = False, only_last_token: bool = False, attn_algorithm: Optional[str] = None, From f4b2e3707c915ece98d8ce4edea9cf4a4ff311de Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 10 Oct 2024 14:44:49 -0400 Subject: [PATCH 7/7] Final format fix Signed-off-by: Davis Wertheimer --- speculator/train_speculator_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py index 11e8505e..0a265a63 100644 --- a/speculator/train_speculator_utils.py +++ b/speculator/train_speculator_utils.py @@ -440,7 +440,7 @@ def forward( x: torch.Tensor, mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - past_key_value_states: Optional[List[Tuple[torch.Tensor,torch.Tensor]]] = None, + past_key_value_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: bool = False, only_last_token: bool = False, attn_algorithm: Optional[str] = None,