From 6a46c16efd45b21526602976a78f0b154267ff56 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer
Date: Fri, 10 Jan 2025 16:23:48 -0500
Subject: [PATCH 01/15] Add fim

Signed-off-by: Davis Wertheimer
---
 fms_fsdp/config/training.py        |   8 ++
 fms_fsdp/utils/dataloader_utils.py |  34 ++++++---
 fms_fsdp/utils/dataset_utils.py    | 118 +++++++++++++++++++++++++++++
 tests/test_datasets.py             |   4 +
 4 files changed, 155 insertions(+), 9 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 1d072958..eadef50f 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -72,3 +72,11 @@ class train_config:
     stage2_prompt_length: int = 64
     stage2_batch_size: int = 96
     stage2_seq_length: int = 256
+
+    # FIM training
+    fim_training: bool = False
+    psm_rate: float = 0.0
+    spm_rate: float = 0.0
+    fim_pre: int = 1
+    fim_mid: int = 2
+    fim_suf: int = 3
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 4b811d6d..a3f0703a 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -5,6 +5,7 @@
     AutoHandler,
     BufferDataset,
     CheckpointDataset,
+    FIMDataset,
     ParquetHandler,
     PreloadBufferDataset,
     PreprocessDataset,
@@ -57,9 +58,9 @@ def __iter__(self):
     return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)
 
 
-def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
+def get_data_loader(cfg, rank, world_size):
     """
-    Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
+    Pytorch dataloader for stateful, distributed, and rescalable language model training.
     Assumes underlying data is sequences of integer values.
     ...
     Args
@@ -70,11 +71,11 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         Rank of current distributed worker. Used for handling dataset sharding logic.
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
-    postprocess : List[Callable]
-        Any task-specific postprocessing to apply before handing over data. Steps will apply in
-        the order provided by the user. For CLM training, use postprocess=[causal_lm].
     """
 
+    if cfg.fim_training:
+        assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
+
     datasets, weights = parse_data_args(cfg.datasets, cfg.weights)
 
     # Base streaming dataset. Returns doc chunks in sequence.
     # Implements dataset sampling and rescalability.
@@ -118,9 +119,10 @@ def get_data_loader(cfg, rank, world_size):
         verbose=(rank == 0),
     )
     # Wrap above dataset in packing logic to form constant-length lines.
+    # CLM removes 1 token, FIM adds at least 3.
     data = BufferDataset(
         data,
-        cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
+        cfg.seq_length - 3 if cfg.fim_training else cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
     )
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
     data = PreloadBufferDataset(data, 10000)
 
-    # Apply desired postprocessing steps in sequence
+    # Apply FIM transformation if needed
+    if cfg.fim_training:
+        data = FIMDataset(
+            data,
+            cfg.eos_token,
+            cfg.psm_rate,
+            cfg.spm_rate,
+            pre_token=cfg.fim_pre,
+            mid_token=cfg.fim_mid,
+            suf_token=cfg.fim_suf,
+        )
+
+    # Transform to tensors
     data = PreprocessDataset(data, torch.IntTensor)
-    for p in postprocess:
-        data = PreprocessDataset(data, p)
+
+    # Apply CLM transformation if needed
+    if not cfg.fim_training:
+        data = PreprocessDataset(data, causal_lm)
 
     # Enable auto-saving
     data = CheckpointDataset(
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index aedc5862..4a72e6c2 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -694,6 +694,124 @@ def load_state_dict(self, state_dicts, sharded_input=False):
         # Manually set buffer size
         self.buffer_size = len(self.buffer)
         return sharded_dicts
+    
+
+class FIMDataset(_WrapperDataset):
+    """
+    Wrapper for a StatefulDataset that implements Fill-In-the-Middle training
+    (https://arxiv.org/pdf/2207.14255).
+    Input should be a packed sequence (i.e. call BufferDataset before FIMDataset).
+    Breaks sequence apart into component document spans, and for each document span 
+    of sufficient length, transforms with specified probability into:
+    PSM mode:
+    [pre_token] (prefix) [suf_token] (suffix) [mid_token] (middle) [delimiter]
+    SPM mode: 
+    [pre_token] [suf_token] (suffix) [mid_token] (prefix) (middle) [delimiter]
+    The new delimiter tokens can be omitted by passing in None.
+    Any extra tokens after transformation are dropped from the end of the sequence.
+    ...
+    Args
+    ----
+    dataset : _StatefulDataset
+        Fully instantiated dataset
+    delimiter_token : any
+        Token used to indicate document boundaries
+    psm_rate : float
+        Chance to transform into PSM. Cannot exceed 1.
+    spm_rate : float
+        Chance to transform into SPM. Cannot exceed 1.
+    min_len : int
+        Minimum document length to perform FIM transformation
+    pre_token : any | none
+        Token used to indicate prefix section of the document
+    mid_token : any | none
+        Token used to indicate middle infill section of the document
+    suf_token : any | none
+        Token used to indicate suffix section of the document
+    """
+
+    def __init__(
+        self, 
+        dataset: _StatefulDataset, 
+        delimiter_token: Any,
+        psm_rate: float = 0.0,
+        spm_rate: float = 0.0,
+        min_len: int = 10,
+        pre_token = None,
+        mid_token = None,
+        suf_token = None,
+    ):
+        super().__init__(dataset)
+        assert (
+            psm_rate + spm_rate > 0
+        ), f"FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
+        assert (
+            psm_rate + spm_rate <= 1
+        ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
+        self.psm = psm_rate
+        self.spm = spm_rate
+        self.delimiter = delimiter_token
+        self.min_len = min_len
+        self.pref = pre_token
+        self.suff = suf_token
+        self.midd = mid_token
+
+        self.g_state = None
+        self.generator = torch.Generator().manual_seed(self.rank)
+        self.state_params = ["g_state"]
+        
+    def __iter__(self):
+        dataset = iter(self.dataset)
+        while True:
+            inp = next(dataset)
+            len_ = len(inp)
+            i_eos = [0] + [i for i,x in enumerate(inp) if x==self.delimiter] + [len_]
+            docs = [inp[i_eos[j]+1:i_eos[j+1]] for j in range(len(i_eos)-1)]  # list[list[any]]
+            out = []
+            for i in range(len(docs)):
+                doc = docs[i]
+                if len(docs[i]) >= self.min_len:
+                    # decide psm, spm, or nothing
+                    thresh = torch.rand([1], generator=self.generator).item()
+                    if thresh < self.psm + self.spm:
+                        # Split doc
+                        doc = []
+                        if self.pref:
+                            doc = [self.pref]
+                        splits = torch.randint(0, len(docs[i]), [2], generator=self.generator).tolist()
+                        pre = docs[i][:min(splits)]
+                        mid = docs[i][min(splits):max(splits)]
+                        suf = docs[i][max(splits):]
+
+                        if thresh < self.psm:
+                            # PSM transformation
+                            doc += pre
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += mid
+                        else:
+                            # SPM transformation
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += pre + mid
+                out += doc + [self.delimiter]
+            yield out[:len_]
+
+    def state_dict(self):
+        # Write generator state manually
+        self.g_state = self.generator.get_state()
+        return super().state_dict()
+
+    def load_state_dict(self, state_dicts, sharded_input=False):
+        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
+        # Manually set generator state if it exists
+        if self.g_state is not None:
+            self.generator.set_state(self.g_state)
+        return sharded_dicts
 
 
 class BufferDataset(_WrapperDataset):
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 83b2426b..e78febad 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -632,6 +632,10 @@ def test_multi_reload_stress():
     # preload / sample / scale / doc pipeline
     multi_reload_stress_check(lambda: d6(d5(d4())))
 
+    # Add FIM dataset
+    d7 = lambda x: [FIMDataset(d, -1, .25, .25, 10, -2, -3, -4) for d in x]
+    multi_reload_stress_check(lambda: d7(d6(d5(d4()))))
+
 
 # SCALABLEDATASET TESTS
 

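For intuition about what patch 01 does to each document span in a packed line: below is a minimal standalone sketch of the PSM/SPM rearrangement performed in FIMDataset.__iter__, using made-up sentinel ids (-2/-3/-4) in place of fim_pre/fim_mid/fim_suf. It is an illustration under those assumptions, not code from the patch.

    import torch

    def fim_rearrange(doc, pre=-2, mid=-3, suf=-4, spm=False, generator=None):
        # Two random split points, chosen the same way as in FIMDataset.__iter__
        splits = torch.randint(0, len(doc), [2], generator=generator).tolist()
        lo, hi = min(splits), max(splits)
        prefix, middle, suffix = doc[:lo], doc[lo:hi], doc[hi:]
        if spm:
            # SPM: [pre] [suf] suffix [mid] prefix middle
            return [pre, suf] + suffix + [mid] + prefix + middle
        # PSM: [pre] prefix [suf] suffix [mid] middle
        return [pre] + prefix + [suf] + suffix + [mid] + middle

    g = torch.Generator().manual_seed(0)
    print(fim_rearrange(list(range(10)), generator=g))            # PSM ordering
    print(fim_rearrange(list(range(10)), spm=True, generator=g))  # SPM ordering

The downstream consumer still sees a line of the original length: FIMDataset appends the delimiter after each transformed span and truncates the result back to the packed length before yielding.
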
From 673e6e5736b97de66571c2fb7a18f245749d1cfd Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 10 Jan 2025 16:24:12 -0500
Subject: [PATCH 02/15] Blacking

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/config/training.py     |  4 +++-
 fms_fsdp/utils/dataset_utils.py | 38 ++++++++++++++++++---------------
 main_training_llama.py          |  8 ++++---
 main_training_mamba.py          |  8 ++++---
 tests/test_datasets.py          |  2 +-
 5 files changed, 35 insertions(+), 25 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index eadef50f..f985ed0d 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -15,7 +15,9 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    datasets: str = (
+        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    )
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 4a72e6c2..e6480dab 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -694,14 +694,14 @@ def load_state_dict(self, state_dicts, sharded_input=False):
         # Manually set buffer size
         self.buffer_size = len(self.buffer)
         return sharded_dicts
-    
+
 
 class FIMDataset(_WrapperDataset):
     """
     Wrapper for a StatefulDataset that implements Fill-In-the-Middle training
     (https://arxiv.org/pdf/2207.14255).
     Input should be a packed sequence (i.e. call BufferDataset before FIMDataset).
-    Breaks sequence apart into component document spans, and for each document span 
+    Breaks sequence apart into component document spans, and for each document span
     of sufficient length, transforms with specified probability into:
     PSM mode: 
     [pre_token] (prefix) [suf_token] (suffix) [mid_token] (middle) [delimiter]
     SPM mode: 
     [pre_token] [suf_token] (suffix) [mid_token] (prefix) (middle) [delimiter]
@@ -729,15 +729,15 @@ class FIMDataset(_WrapperDataset):
     """
 
     def __init__(
-        self, 
-        dataset: _StatefulDataset, 
+        self,
+        dataset: _StatefulDataset,
         delimiter_token: Any,
         psm_rate: float = 0.0,
         spm_rate: float = 0.0,
         min_len: int = 10,
-        pre_token = None,
-        mid_token = None,
-        suf_token = None,
+        pre_token=None,
+        mid_token=None,
+        suf_token=None,
     ):
         super().__init__(dataset)
         assert (
@@ -757,14 +757,16 @@ def __init__(
         self.g_state = None
         self.generator = torch.Generator().manual_seed(self.rank)
         self.state_params = ["g_state"]
-        
+
     def __iter__(self):
         dataset = iter(self.dataset)
         while True:
             inp = next(dataset)
             len_ = len(inp)
-            i_eos = [0] + [i for i,x in enumerate(inp) if x==self.delimiter] + [len_]
-            docs = [inp[i_eos[j]+1:i_eos[j+1]] for j in range(len(i_eos)-1)]  # list[list[any]]
+            i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
+            docs = [
+                inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
+            ]  # list[list[any]]
             out = []
             for i in range(len(docs)):
                 doc = docs[i]
@@ -776,10 +778,12 @@ def __iter__(self):
                         doc = []
                         if self.pref:
                             doc = [self.pref]
-                        splits = torch.randint(0, len(docs[i]), [2], generator=self.generator).tolist()
-                        pre = docs[i][:min(splits)]
-                        mid = docs[i][min(splits):max(splits)]
-                        suf = docs[i][max(splits):]
+                        splits = torch.randint(
+                            0, len(docs[i]), [2], generator=self.generator
+                        ).tolist()
+                        pre = docs[i][: min(splits)]
+                        mid = docs[i][min(splits) : max(splits)]
+                        suf = docs[i][max(splits) :]
 
                         if thresh < self.psm:
                             # PSM transformation
@@ -990,9 +994,9 @@ def __init__(
         self.bos = bos_token
         self.drop = strip_tokens
         self.verbose = verbose
-        self.docset: List[
-            Any
-        ] = []  # map of doc indices to (shardid, min docid, max docid)
+        self.docset: List[Any] = (
+            []
+        )  # map of doc indices to (shardid, min docid, max docid)
 
         # Position
         self.docset_index = 0
diff --git a/main_training_llama.py b/main_training_llama.py
index 67cccee2..a7e1020f 100644
--- a/main_training_llama.py
+++ b/main_training_llama.py
@@ -122,9 +122,11 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
-        if not os.path.isfile(cfg.ckpt_load_path)
-        else cfg.ckpt_load_path,
+        path=(
+            os.path.join(cfg.ckpt_load_path, "checkpoints/")
+            if not os.path.isfile(cfg.ckpt_load_path)
+            else cfg.ckpt_load_path
+        ),
         strict=False,
     )
     if not is_resuming:
diff --git a/main_training_mamba.py b/main_training_mamba.py
index 3619ea25..68a3c830 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -119,9 +119,11 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
-        if not os.path.isfile(cfg.ckpt_load_path)
-        else cfg.ckpt_load_path,
+        path=(
+            os.path.join(cfg.ckpt_load_path, "checkpoints/")
+            if not os.path.isfile(cfg.ckpt_load_path)
+            else cfg.ckpt_load_path
+        ),
         strict=False,
     )
     if not is_resuming:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e78febad..40bef481 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -633,7 +633,7 @@ def test_multi_reload_stress():
     multi_reload_stress_check(lambda: d6(d5(d4())))
 
     # Add FIM dataset
-    d7 = lambda x: [FIMDataset(d, -1, .25, .25, 10, -2, -3, -4) for d in x]
+    d7 = lambda x: [FIMDataset(d, -1, 0.25, 0.25, 10, -2, -3, -4) for d in x]
     multi_reload_stress_check(lambda: d7(d6(d5(d4()))))
 
 

From d3cc468c007d42935f32236525c5139a8527d59c Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 10 Jan 2025 16:27:14 -0500
Subject: [PATCH 03/15] Corrected fim/clm combo

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/utils/dataloader_utils.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index a3f0703a..d4bbc984 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -119,10 +119,10 @@ def get_data_loader(cfg, rank, world_size):
         verbose=(rank == 0),
     )
     # Wrap above dataset in packing logic to form constant-length lines.
-    # CLM removes 1 token, FIM adds at least 3.
+    # Increment seq len to counteract CLM's one token removal.
     data = BufferDataset(
         data,
-        cfg.seq_length - 3 if cfg.fim_training else cfg.seq_length + 1,
+        cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
@@ -145,9 +145,8 @@ def get_data_loader(cfg, rank, world_size):
     # Transform to tensors
     data = PreprocessDataset(data, torch.IntTensor)
 
-    # Apply CLM transformation if needed
-    if not cfg.fim_training:
-        data = PreprocessDataset(data, causal_lm)
+    # Apply CLM transformation
+    data = PreprocessDataset(data, causal_lm)
 
     # Enable auto-saving
     data = CheckpointDataset(

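The reason for always packing seq_length + 1 tokens: the CLM postprocess shifts each line by one position to build input/target pairs, so both halves come out exactly seq_length long. A rough sketch of that shift follows; the real causal_lm helper lives in fms_fsdp.utils.dataset_utils and may differ in details such as prompt masking.

    import torch

    def causal_lm_sketch(seq: torch.Tensor):
        # Next-token prediction: inputs are tokens 0..n-1, targets are tokens 1..n
        return seq[:-1], seq[1:]

    line = torch.arange(4097, dtype=torch.int)  # seq_length = 4096, packed with +1
    x, y = causal_lm_sketch(line)
    assert x.shape[0] == y.shape[0] == 4096

With this correction the FIM-transformed lines pass through the same shift as plain CLM lines, instead of using the seq_length - 3 packing from patch 01.
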
From d5da300404213cb3e65e63191f96236713a551ff Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 10 Jan 2025 16:34:57 -0500
Subject: [PATCH 04/15] reblacking

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/config/training.py     | 4 +---
 fms_fsdp/utils/dataset_utils.py | 5 ++---
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index f985ed0d..eadef50f 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -15,9 +15,7 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = (
-        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
-    )
+    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index e6480dab..9803ce94 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -994,9 +994,8 @@ def __init__(
         self.bos = bos_token
         self.drop = strip_tokens
         self.verbose = verbose
-        self.docset: List[Any] = (
-            []
-        )  # map of doc indices to (shardid, min docid, max docid)
+        # Map of doc indices to (shardid, min docid, max docid)
+        self.docset: List[Any] = []  
 
         # Position
         self.docset_index = 0

From 1981fa93a819b4badd45e6ddcd3411189f7e2a7b Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 10 Jan 2025 16:36:54 -0500
Subject: [PATCH 05/15] Rereblacking

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/utils/dataset_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 9803ce94..7d634161 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -995,7 +995,7 @@ def __init__(
         self.drop = strip_tokens
         self.verbose = verbose
         # Map of doc indices to (shardid, min docid, max docid)
-        self.docset: List[Any] = []  
+        self.docset: List[Any] = []
 
         # Position
         self.docset_index = 0

From 9365cb2c1c78940f831ed6877e6ed07b081e7724 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 23 May 2025 14:29:45 -0400
Subject: [PATCH 06/15] Add multicol support

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/config/training.py        |  1 -
 fms_fsdp/utils/dataloader_utils.py | 16 ++++++-----
 fms_fsdp/utils/dataset_utils.py    | 43 ++++++++++++++++++++++--------
 3 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index eadef50f..20eb0b76 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -74,7 +74,6 @@ class train_config:
     stage2_seq_length: int = 256
 
     # FIM training
-    fim_training: bool = False
     psm_rate: float = 0.0
     spm_rate: float = 0.0
     fim_pre: int = 1
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index d4bbc984..5720f239 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -73,10 +73,11 @@ def get_data_loader(cfg, rank, world_size):
         Number of distributed workers. Used for handling dataset sharding logic.
     """
 
-    if cfg.fim_training:
+    fim_training = cfg.psm_rate + cfg.spm_rate > 0
+    if fim_training:
         assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
 
-    datasets, weights = parse_data_args(cfg.datasets, cfg.weights)
+    datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)
 
     # Base streaming dataset. Returns doc chunks in sequence.
     # Implements dataset sampling and rescalability.
@@ -88,9 +89,9 @@ def get_data_loader(cfg, rank, world_size):
         cfg.file_type in _handler_map
     ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
     if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
-        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name)
+        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols)
     else:
-        filehandler = _handler_map[cfg.file_type]
+        filehandler = _handler_map[cfg.file_type](cols)
     # Base reader layer
     data = StreamingDocDataset(
         cfg.data_path,
@@ -131,7 +132,7 @@ def get_data_loader(cfg, rank, world_size):
     data = PreloadBufferDataset(data, 10000)
 
     # Apply FIM transformation if needed
-    if cfg.fim_training:
+    if fim_training:
         data = FIMDataset(
             data,
             cfg.eos_token,
@@ -161,7 +162,7 @@ def get_data_loader(cfg, rank, world_size):
     )
 
 
-def parse_data_args(datas, weights):
+def parse_data_args(datas, weights, cols):
     # Convert csv inputs into corresponding lists of values
     def splitstrip(x):
         if isinstance(x, str):
@@ -175,4 +176,5 @@ def splitstrip(x):
 
     datas = splitstrip(datas)
     weights = [float(x) for x in splitstrip(weights)]
-    return datas, weights
+    cols = splitstrip(cols)
+    return datas, weights, cols
\ No newline at end of file
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 7d634161..6dd074fa 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -15,6 +15,9 @@
 
 from fms_fsdp.utils.checkpointing_utils import get_latest
 
+# TODO: long doc breaking
+# TODO: titan PR adds
+# TODO: zero-len file asserts/check
 
 """
 The following distributed dataloaders are designed around 3 main principles:
@@ -343,8 +346,8 @@ class ArrowHandler(_ShardFileHandler):
     Non-standard data format, though.
     """
 
-    def __init__(self, col_name: str = "tokens"):
-        self.col_name = col_name
+    def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]):
+        self.col_names = col_names
 
     def is_legal(self, filepath: str):
         return "arrow" in os.path.splitext(filepath)[1]
@@ -356,7 +359,14 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
-        doc = reader.get_batch(index)[self.col_name]
+        frame = reader.get_batch(index)
+
+        doc = None
+        for name in self.col_names:
+            if name in frame.column_names:
+                doc = frame[name]
+                break
+        assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}"
         if len(doc) > 0 and doc[0].as_py() in drop_tokens:
             doc = doc.slice(1, len(doc) - 1)
         # Recheck len for edge case where doc=[eos]
@@ -376,17 +386,22 @@ class ParquetHandler(_ShardFileHandler):
     before getting/slicing. However, this is a standard and widely-used data format.
     """
 
-    def __init__(self, tokenizer_path: str, col_name: str = "text"):
+    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
-        self.col_name = col_name
+        self.col_names = col_names
 
     def is_legal(self, filepath: str):
         return "parquet" in os.path.splitext(filepath)[1]
 
     def open(self, path: str):
-        return pq.read_pandas(path, columns=[self.col_name], partitioning=None)[
-            self.col_name
-        ]
+        names = pq.read_metadata(path).schema.names
+        match = None
+        for name in self.col_names:
+            if name in names:
+                match = name
+                break
+        assert match is not None, f"None of column names {self.col_names} found in file headers {names}"
+        return pq.read_pandas(path, columns=[match], partitioning=None)[match]
 
     def length(self, path: str):
         return pq.read_metadata(path).num_rows
@@ -405,9 +420,9 @@ def slice(self, doc: List, index: int, n_pull: int) -> List:
 
 
 class AutoHandler(_ShardFileHandler):
-    def __init__(self, tokenizer_path: str, col_name: str = "text"):
-        self.PHandler = ParquetHandler(tokenizer_path, col_name)
-        self.AHandler = ArrowHandler()
+    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
+        self.PHandler = ParquetHandler(tokenizer_path, col_names)
+        self.AHandler = ArrowHandler(col_names)
         self.current = _ShardFileHandler()
 
     def is_legal(self, filepath: str):
@@ -1000,6 +1015,7 @@ def __init__(
         # Position
         self.docset_index = 0
         self.chunk_index = -1
+        self.has_yielded = False
 
         # Stats
         self.epochs_seen = -1
@@ -1231,6 +1247,7 @@ def __iter__(self):
                                 self.percent_seen = (
                                     self.docs_seen * 100 / (self._len + 1e-9)
                                 )
+                            self.has_yielded = True
                             yield self._construct_chunk(j, doc, n_chunks)
 
                 # Advance RNG state
@@ -1251,8 +1268,12 @@ def __iter__(self):
                 n_chunks = math.ceil(doclen / self.chunksize)
                 for j in range(residual_chunks):
                     self.chunk_index = j
+                    self.has_yielded = True
                     yield self._construct_chunk(j, doc, n_chunks)
 
+            # Check that epoch was non-empty
+            assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
+
     def load_state_dict(self, state_dicts, sharded_input=False):
         self.setup()
         assert (

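The multi-column support in this patch boils down to a first-match scan over candidate column names, shared by ArrowHandler, ParquetHandler and AutoHandler. A small standalone sketch of that lookup (schema contents invented):

    def pick_column(candidates, available):
        # Return the first candidate column that exists in the file schema
        for name in candidates:
            if name in available:
                return name
        raise AssertionError(
            f"None of column names {candidates} found in file headers {available}"
        )

    print(pick_column(["text", "contents", "tokens"], ["id", "tokens", "meta"]))  # tokens
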
From 189214991b5685e1f7387f93c4a806c3b6861adb Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 23 May 2025 15:02:20 -0400
Subject: [PATCH 07/15] Add everything else

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/utils/dataset_utils.py | 130 ++++++++++++++++++--------------
 1 file changed, 75 insertions(+), 55 deletions(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 6dd074fa..0e1be1a1 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -15,7 +15,6 @@
 
 from fms_fsdp.utils.checkpointing_utils import get_latest
 
-# TODO: long doc breaking
 # TODO: titan PR adds
 # TODO: zero-len file asserts/check
 
@@ -359,8 +358,8 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
+        assert index < reader.num_record_batches, f"Illegal index {index} in set of {reader.num_record_batches} documents"
         frame = reader.get_batch(index)
-
         doc = None
         for name in self.col_names:
             if name in frame.column_names:
@@ -407,7 +406,8 @@ def length(self, path: str):
         return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
-        doc = self.tokenizer(str(reader[index]))["input_ids"]
+        assert index < reader.length(), f"Illegal index {index} in set of {reader.length()} documents"
+        doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"]
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
         # Recheck len for edge case where doc=[eos]
@@ -978,10 +978,10 @@ class StreamingDocDataset(_StatefulDataset):
         Documents below this length are skipped
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
+    max_consecutive_chunks : int
+        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
-    shuffle : bool
-        Shuffle shard file and document orders? (Disable for simple testing)
     """
 
     def __init__(
@@ -996,6 +996,7 @@ def __init__(
         seed: int = 42,
         min_length: int = 1,
         max_chunksize: int = 1024,
+        max_consecutive_chunks: int = 64,
         verbose: bool = False,
     ):
         super().__init__(datapath, rank, worldsize)
@@ -1008,6 +1009,7 @@ def __init__(
         self.eos = delimiter_token
         self.bos = bos_token
         self.drop = strip_tokens
+        self.max_consec = max_consecutive_chunks
         self.verbose = verbose
         # Map of doc indices to (shardid, min docid, max docid)
         self.docset: List[Any] = []
@@ -1022,6 +1024,7 @@ def __init__(
         self.tokens_seen = 0
         self.docs_seen = 0
         self.percent_seen = 0
+        self.consec = 0
 
         self.state_params = [
             "dataset",
@@ -1032,6 +1035,7 @@ def __init__(
             "docs_seen",
             "percent_seen",
             "lcg_state",
+            "consec",
         ]
 
         # Setup flags
@@ -1064,17 +1068,8 @@ def setup(self):
                 if self.filehandler.is_legal(os.path.join(root, name))
             ]
             shards.sort()  # Ensure consistent sharding across machines
-            start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
-            end_frag = (
-                (self.rank + 1) * self.worldsize * len(shards)
-            ) // self.worldsize
-            shardfrags = [
-                (shards[i // self.worldsize], i % self.worldsize)
-                for i in range(start_frag, end_frag)
-            ]
-
-            # Assemble length of each owned shard file
 
+            # Find metadata file
             countfiles = []
             if os.path.exists(os.path.join(pardir, "meta")):
                 countfiles = [
@@ -1082,55 +1077,72 @@ def setup(self):
                     for x in os.listdir(os.path.join(pardir, "meta"))
                     if "counts" in x and "csv" in x
                 ]
-            doc_counts = {}
             if len(countfiles) > 0:
                 # Count file exists, use it
                 countpath = os.path.join(pardir, "meta", countfiles[0])
+            else:
+                countpath = ""
+
+            # Use shard file sizes to perform partitioning
+            # Create shardlist of form shardid -> [start%, end%]
+            if len(countfiles) > 0:
+                sizes = {}
+                with open(countpath, "r") as csvfile:
+                    reader = csv.DictReader(csvfile)
+                    for row in reader:
+                        fullpath = row["dataset/filename"]
+                        prefix = fullpath.find(dataset + "/")
+                        if prefix >= 0:
+                            key = fullpath[prefix + len(dataset) + 1 :]
+                            sizes[key] = int(row["size"])
+                shard_sizes = [sizes[shard] for shard in shards]
+            else:
+                shard_sizes = [
+                    os.path.getsize(os.path.join(datapath, shard)) for shard in shards
+                ]
+            shard_sizes = [s / sum(shard_sizes) for s in shard_sizes]
+            start = self.rank / self.worldsize
+            end = (self.rank + 1) / self.worldsize
+            shardset = {}
+            tally = 0
+            for i in range(len(shards)):
+                if tally <= end and tally + shard_sizes[i] >= start:
+                    shardset[shards[i]] = [
+                        min(max((start - tally) / shard_sizes[i], 0), 1),
+                        min(max((end - tally) / shard_sizes[i], 0), 1),
+                    ]
+                tally += shard_sizes[i]
+                # Count file exists, use it
                 with open(countpath, "r") as csvfile:
                     reader = csv.DictReader(csvfile)
                     for row in reader:
                         fullpath = row["dataset/filename"]
-                        prefix = fullpath.find("/" + dataset) + 1
-                        if prefix > 0:
+                        prefix = fullpath.find(dataset)
+                        if prefix >= 0:
                             key = fullpath[prefix + len(dataset) + 1 :]
                             doc_counts[key] = int(row["documents"])
             else:
                 # Count file does not exist, touch every owned file for length
-                unique_shardfiles = set(shard for shard, frag in shardfrags)
                 doc_counts = {
                     shard: self.filehandler.length(os.path.join(datapath, shard))
-                    for shard in unique_shardfiles
+                    for shard in shardset
                 }
 
-            # Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
-            ndocs = -1
-            docset = {}  # shardid -> (min docid, max docid)
-            for i, (shard, frag) in enumerate(shardfrags):
-                ndocs = doc_counts[shard]
-                doc_start = (ndocs * frag) // self.worldsize
-                doc_end = (
-                    ndocs * frag + ndocs
-                ) // self.worldsize - 1  # Inclusive upper bound
-                if shard not in docset:
-                    docset[shard] = [doc_start, doc_end]
-                min_d, max_d = docset[shard]
-                if doc_start < min_d:
-                    docset[shard][0] = doc_start
-                if doc_end > max_d:
-                    docset[shard][1] = doc_end
-
-            # Add shard entries to self.docset
+            # Assemble doc list for each file shard
+            # Create docset of form [shardid, min docid, max docid]
             doccount = 0
-            for shardid in docset:
-                min_d = docset[shardid][0]
-                max_d = docset[shardid][1]
-                self.docset.append((shardid, min_d, max_d))
-                doccount += max_d - min_d + 1
+            for shard in shardset:
+                ndocs = doc_counts[shard]
+                if ndocs > 0:
+                    doc_start = int(ndocs * shardset[shard][0])
+                    doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
+                    self.docset.append([shard, doc_start, doc_end])
+                    doccount += doc_end - doc_start + 1
             self._len = doccount
 
             if self.verbose:
                 logging.info(
-                    f"    Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
+                    f"    Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}"
                 )
 
             # Shuffle shard files - guaranteed inconsistent across workers
@@ -1185,8 +1197,11 @@ def _construct_chunk(self, j, doc, n_chunks):
         # Add bos/eos tokens if needed
         if self.bos is not None and j == 0:
             chunk = [self.bos] + chunk
-        if j == n_chunks - 1:
+        if j == n_chunks - 1 or self.consec == self.max_consec:
             chunk = chunk + [self.eos]
+            self.consec = 0
+        else:
+            self.consec += 1
         return chunk
 
     def _random_map_docid(self, size):
@@ -1231,10 +1246,8 @@ def __iter__(self):
                 doclcg = self._random_map_docid(docrange)
                 docid = doclcg + mindoc
                 doc = self.filehandler.get(reader, docid, self.drop)
-                if len(doc) == 0:
-                    continue
                 doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                if doclen >= self.min_length:
+                if len(doc) > 0 and doclen >= self.min_length:
                     n_chunks = math.ceil(doclen / self.chunksize)
                     for j in range(n_chunks):
                         if i == 0 and j < residual_chunks:
@@ -1345,12 +1358,12 @@ def setup(self):
         if not self.is_setup:
             _StatefulDataset.setup(self)
             n_logical_shards = self.total_shards
+            assert (
+                n_logical_shards % self.worldsize == 0
+            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
             logicals = list(range(n_logical_shards))
             self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
             self.n_logicals = n_logical_shards // self.worldsize
-            assert (
-                len(self.logicals_owned) == self.n_logicals
-            ), "(world size * num workers) does not divide logical shards evenly"
 
             # Build logical shards
             for i in range(self.n_logicals):
@@ -1367,21 +1380,26 @@ def setup(self):
                     )
             [d.setup() for d in self.data]
             self.n_docs_remaining = [d._len for d in self.data]
-
+            assert (
+                sum(self.n_docs_remaining) > 0
+            ), f"No documents detected in shard {self.rank} of {self.datapath}"
+                
             self.generator = torch.Generator().manual_seed(self.rank)
 
     def __iter__(self):
         self.setup()
         # Grab one doc at a time in random order
         data = [iter(d) for d in self.data]
+        # Reset if we're rescaling into a prematurely finished epoch
+        # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
+        if sum(self.n_docs_remaining) == 0:
+            self.n_docs_remaining = [d._len for d in self.data]
+            self.generator.manual_seed(self.rank)
         while True:
             # Sample logical shard (or load from ckp)
             if self.current_reader is not None:
                 ind = self.current_reader
             else:
-                assert (
-                    sum(self.n_docs_remaining) > 0
-                ), f"No documents detected in {self.datapath}"
                 ind = torch.multinomial(
                     torch.tensor(self.n_docs_remaining, dtype=torch.float),
                     1,
@@ -1473,6 +1491,8 @@ def __init__(
             ]
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
+        for d in datasets:
+            assert os.path.exists(os.path.join(datapath, d)), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(

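The new setup() above replaces per-fragment document splitting with partitioning by shard file size: shard sizes are normalized to fractions of the total, each rank owns the interval [rank/worldsize, (rank+1)/worldsize), and every shard overlapping that interval records which fraction of its own documents the rank should read. A simplified sketch of that bookkeeping, with shard ids replaced by list indices:

    def partition_shards(shard_sizes, rank, worldsize):
        # Per-shard [start fraction, end fraction] owned by this rank,
        # mirroring the size-based logic added to StreamingDocDataset.setup()
        total = sum(shard_sizes)
        fracs = [s / total for s in shard_sizes]
        start, end = rank / worldsize, (rank + 1) / worldsize
        owned, tally = {}, 0.0
        for i, frac in enumerate(fracs):
            if tally <= end and tally + frac >= start:
                owned[i] = [
                    min(max((start - tally) / frac, 0), 1),
                    min(max((end - tally) / frac, 0), 1),
                ]
            tally += frac
        return owned

    print(partition_shards([100, 50, 50], 0, 2))  # rank 0 covers all of shard 0
    print(partition_shards([100, 50, 50], 1, 2))  # rank 1 covers shards 1 and 2

The [start%, end%] pairs are then turned into inclusive document index ranges per shard, which is where the doc_start/doc_end arithmetic above comes in.
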
From 426dd1635f526f656401be6e8b2aa905c0ed77e9 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 23 May 2025 15:12:42 -0400
Subject: [PATCH 08/15] Some cleanup (no continue)

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/utils/dataset_utils.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 0e1be1a1..f1dea578 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -15,9 +15,6 @@
 
 from fms_fsdp.utils.checkpointing_utils import get_latest
 
-# TODO: titan PR adds
-# TODO: zero-len file asserts/check
-
 """
 The following distributed dataloaders are designed around 3 main principles:
 
@@ -1274,10 +1271,8 @@ def __iter__(self):
             newpath = os.path.join(self.datapath, shardid)
             path, reader = self._get_reader(path, newpath, reader)
             doc = self.filehandler.get(reader, docid, self.drop)
-            if len(doc) == 0:
-                continue
             doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-            if doclen >= self.min_length:
+            if len(doc) > 0 and doclen >= self.min_length:
                 n_chunks = math.ceil(doclen / self.chunksize)
                 for j in range(residual_chunks):
                     self.chunk_index = j

From 15f4d7e0d5140d9f689ee4ab69151835d66473fb Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 23 May 2025 15:16:21 -0400
Subject: [PATCH 09/15] Blacking

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/config/training.py        |  4 +++-
 fms_fsdp/utils/dataloader_utils.py |  2 +-
 fms_fsdp/utils/dataset_utils.py    | 38 ++++++++++++++++++++++--------
 3 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 20eb0b76..4a6e919f 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -15,7 +15,9 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    datasets: str = (
+        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    )
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 5720f239..5c2016bd 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -177,4 +177,4 @@ def splitstrip(x):
     datas = splitstrip(datas)
     weights = [float(x) for x in splitstrip(weights)]
     cols = splitstrip(cols)
-    return datas, weights, cols
\ No newline at end of file
+    return datas, weights, cols
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index f1dea578..a2d6eff9 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -355,14 +355,18 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
-        assert index < reader.num_record_batches, f"Illegal index {index} in set of {reader.num_record_batches} documents"
+        assert (
+            index < reader.num_record_batches
+        ), f"Illegal index {index} in set of {reader.num_record_batches} documents"
         frame = reader.get_batch(index)
         doc = None
         for name in self.col_names:
             if name in frame.column_names:
                 doc = frame[name]
                 break
-        assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}"
+        assert (
+            doc is not None
+        ), f"None of column names {self.col_names} found in file headers {frame.column_names}"
         if len(doc) > 0 and doc[0].as_py() in drop_tokens:
             doc = doc.slice(1, len(doc) - 1)
         # Recheck len for edge case where doc=[eos]
@@ -382,7 +386,9 @@ class ParquetHandler(_ShardFileHandler):
     before getting/slicing. However, this is a standard and widely-used data format.
     """
 
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
+    def __init__(
+        self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]
+    ):
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.col_names = col_names
 
@@ -396,14 +402,18 @@ def open(self, path: str):
             if name in names:
                 match = name
                 break
-        assert match is not None, f"None of column names {self.col_names} found in file headers {names}"
+        assert (
+            match is not None
+        ), f"None of column names {self.col_names} found in file headers {names}"
         return pq.read_pandas(path, columns=[match], partitioning=None)[match]
 
     def length(self, path: str):
         return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
-        assert index < reader.length(), f"Illegal index {index} in set of {reader.length()} documents"
+        assert (
+            index < reader.length()
+        ), f"Illegal index {index} in set of {reader.length()} documents"
         doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"]
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
@@ -417,7 +427,9 @@ def slice(self, doc: List, index: int, n_pull: int) -> List:
 
 
 class AutoHandler(_ShardFileHandler):
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
+    def __init__(
+        self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]
+    ):
         self.PHandler = ParquetHandler(tokenizer_path, col_names)
         self.AHandler = ArrowHandler(col_names)
         self.current = _ShardFileHandler()
@@ -1132,7 +1144,9 @@ def setup(self):
                 ndocs = doc_counts[shard]
                 if ndocs > 0:
                     doc_start = int(ndocs * shardset[shard][0])
-                    doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
+                    doc_end = max(
+                        doc_start, int(ndocs * shardset[shard][1]) - 1
+                    )  # inclusive upper bound
                     self.docset.append([shard, doc_start, doc_end])
                     doccount += doc_end - doc_start + 1
             self._len = doccount
@@ -1280,7 +1294,9 @@ def __iter__(self):
                     yield self._construct_chunk(j, doc, n_chunks)
 
             # Check that epoch was non-empty
-            assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
+            assert (
+                self.has_yielded
+            ), f"Empty logical shard detected: {self.dataset, self.docset}"
 
     def load_state_dict(self, state_dicts, sharded_input=False):
         self.setup()
@@ -1378,7 +1394,7 @@ def setup(self):
             assert (
                 sum(self.n_docs_remaining) > 0
             ), f"No documents detected in shard {self.rank} of {self.datapath}"
-                
+
             self.generator = torch.Generator().manual_seed(self.rank)
 
     def __iter__(self):
@@ -1487,7 +1503,9 @@ def __init__(
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
         for d in datasets:
-            assert os.path.exists(os.path.join(datapath, d)), f"Invalid subdataset path: {os.path.join(datapath, d)}"
+            assert os.path.exists(
+                os.path.join(datapath, d)
+            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(

From a58eecde2e37d5cd56d5f00acb5a1194a19c87d2 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 23 May 2025 15:19:37 -0400
Subject: [PATCH 10/15] Reblacking

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/config/training.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 4a6e919f..20eb0b76 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -15,9 +15,7 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = (
-        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
-    )
+    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000

From 4c26fb49ab66692f414329d81452540e1c08e2c3 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 23 May 2025 15:22:22 -0400
Subject: [PATCH 11/15] isorting

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/utils/dataset_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index a2d6eff9..029d450c 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -15,6 +15,7 @@
 
 from fms_fsdp.utils.checkpointing_utils import get_latest
 
+
 """
 The following distributed dataloaders are designed around 3 main principles:
 

From 0f667be41bbc6e38cf270af2a91dc67a2523b7ca Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Tue, 27 May 2025 11:10:05 -0400
Subject: [PATCH 12/15] Make doc cutoff an arg

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/config/training.py        |  5 ++++-
 fms_fsdp/utils/dataloader_utils.py |  4 +++-
 fms_fsdp/utils/dataset_utils.py    | 17 ++++++++++++-----
 3 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 20eb0b76..5fa56793 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -15,7 +15,9 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    datasets: str = (
+        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    )
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000
@@ -26,6 +28,7 @@ class train_config:
     strip_tokens: str = ""
     logical_shards: int = 1024
     num_workers: int = 1
+    doc_cutoff: int = 1_000_000
 
     # fsdp policies
     sharding_strategy: str = "hsdp"
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 5c2016bd..88840c0e 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -89,7 +89,9 @@ def get_data_loader(cfg, rank, world_size):
         cfg.file_type in _handler_map
     ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
     if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
-        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols)
+        filehandler = _handler_map[cfg.file_type](
+            cfg.tokenizer_path, cols, cfg.doc_cutoff
+        )
     else:
         filehandler = _handler_map[cfg.file_type](cols)
     # Base reader layer
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 029d450c..78854cd8 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -382,16 +382,20 @@ def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List:
 class ParquetHandler(_ShardFileHandler):
     """
     Reader for indexable parquet shard files, common in HF datasets.
-    Here we assume reasonably small shard files (<5Gb) and documents (<100k tokens),
+    Here we assume reasonably small shard files (<5Gb) and truncate docs to max_doclen characters,
     as we rely on parquet/pandas for efficient file reading, and tokenize entire documents
     before getting/slicing. However, this is a standard and widely-used data format.
     """
 
     def __init__(
-        self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]
+        self,
+        tokenizer_path: str,
+        col_names: List[str] = ["text", "contents", "tokens"],
+        max_doclen: int = 1_000_000,
     ):
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.col_names = col_names
+        self.max_doclen = max_doclen
 
     def is_legal(self, filepath: str):
         return "parquet" in os.path.splitext(filepath)[1]
@@ -415,7 +419,7 @@ def get(self, reader, index: int, drop_tokens: Set):
         assert (
             index < reader.length()
         ), f"Illegal index {index} in set of {reader.length()} documents"
-        doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"]
+        doc = self.tokenizer(str(reader[index])[: self.max_doclen])["input_ids"]
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
         # Recheck len for edge case where doc=[eos]
@@ -429,9 +433,12 @@ def slice(self, doc: List, index: int, n_pull: int) -> List:
 
 class AutoHandler(_ShardFileHandler):
     def __init__(
-        self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]
+        self,
+        tokenizer_path: str,
+        col_names: List[str] = ["text", "contents", "tokens"],
+        max_doclen: int = 1_000_000,
     ):
-        self.PHandler = ParquetHandler(tokenizer_path, col_names)
+        self.PHandler = ParquetHandler(tokenizer_path, col_names, max_doclen)
         self.AHandler = ArrowHandler(col_names)
         self.current = _ShardFileHandler()
 

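The doc_cutoff knob bounds tokenizer cost by truncating each raw document to a fixed number of characters before tokenization, rather than to a token count. A tiny illustration of that ordering (tokenizer and text invented):

    def truncate_then_tokenize(text: str, max_doclen: int = 1_000_000):
        # Character-level cutoff first, as in ParquetHandler.get; the resulting
        # token count depends on the tokenizer, so this is not a hard token limit
        clipped = text[:max_doclen]
        return clipped.split()  # stand-in for tokenizer(clipped)["input_ids"]

    print(len(truncate_then_tokenize("one two three " * 100_000, max_doclen=50)))
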
From c9f2bdb9ae0f96d6346631294a711fe9ae8b9beb Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Tue, 27 May 2025 11:21:07 -0400
Subject: [PATCH 13/15] Re-insert missing countfile search lines

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/utils/dataset_utils.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 78854cd8..bc5ed772 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -1129,6 +1129,10 @@ def setup(self):
                         min(max((end - tally) / shard_sizes[i], 0), 1),
                     ]
                 tally += shard_sizes[i]
+
+            # Assemble length of each owned shard file
+            doc_counts = {}
+            if len(countfiles) > 0:
                 # Count file exists, use it
                 with open(countpath, "r") as csvfile:
                     reader = csv.DictReader(csvfile)

From 6d751e57ab62532094292151b1746a0ace4e004c Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Thu, 29 May 2025 15:32:30 -0400
Subject: [PATCH 14/15] Expose doc breakpoint in cfg

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/config/training.py        | 1 +
 fms_fsdp/utils/dataloader_utils.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 5fa56793..22b9a840 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -29,6 +29,7 @@ class train_config:
     logical_shards: int = 1024
     num_workers: int = 1
     doc_cutoff: int = 1_000_000
+    doc_breakpoint: int = 65_536
 
     # fsdp policies
     sharding_strategy: str = "hsdp"
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 88840c0e..6078ae7c 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -1,4 +1,5 @@
 import torch
+from math import ceil
 
 from fms_fsdp.utils.dataset_utils import (
     ArrowHandler,
@@ -104,6 +105,7 @@ def get_data_loader(cfg, rank, world_size):
         bos_token=cfg.bos_token,
         strip_tokens=set(droplist),
         min_length=3,
+        max_consecutive_chunks=ceil(cfg.doc_breakpoint/1024),
         seed=cfg.seed,
     )
     # Add rescaling/resharding

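With the defaults introduced here, ceil(65_536 / 1024) = 64, reproducing the earlier max_consecutive_chunks default; the hard-coded 1024 divisor assumes max_chunksize stays at its default. A rough standalone sketch of the behavior this knob controls, simplified from _construct_chunk (eos value and document contents invented):

    from math import ceil

    def emit_chunks(doc, chunksize=1024, max_consec=64, eos=-1):
        # Break a long document into chunks, forcing an EOS and resetting the
        # consecutive-chunk counter once max_consec chunks have gone by
        consec = 0
        n_chunks = ceil(len(doc) / chunksize)
        for j in range(n_chunks):
            chunk = doc[j * chunksize : (j + 1) * chunksize]
            if j == n_chunks - 1 or consec == max_consec:
                chunk = chunk + [eos]
                consec = 0
            else:
                consec += 1
            yield chunk

    doc = list(range(200_000))
    breaks = [j for j, c in enumerate(emit_chunks(doc)) if c[-1] == -1]
    print(breaks)  # an EOS roughly every 64 chunks, plus one on the final chunk

In the real code the counter lives on the StreamingDocDataset instance (self.consec) and is checkpointed alongside the other state_params.
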
From 46fafd7943b4b0ed8d9320a005d89f37ea747e48 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Tue, 3 Jun 2025 15:07:44 -0400
Subject: [PATCH 15/15] Follow symlinks during walk

---
 fms_fsdp/utils/dataset_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index bc5ed772..80c55c0a 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -1080,7 +1080,7 @@ def setup(self):
             # listdir, assemble shardfraglist (ind -> shard, frag)
             shards = [
                 os.path.join(root, name)[len(datapath) + 1 :]
-                for root, dirs, files in os.walk(datapath, topdown=False)
+                for root, dirs, files in os.walk(datapath, topdown=False, followlinks=True)
                 for name in files
                 if self.filehandler.is_legal(os.path.join(root, name))
             ]