Skip to content
1 change: 1 addition & 0 deletions fms_fsdp/config/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class train_config:
logical_shards: int = 1024
num_workers: int = 1
doc_cutoff: int = 1_000_000
doc_breakpoint: int = 65_536

# fsdp policies
sharding_strategy: str = "hsdp"
Expand Down
4 changes: 3 additions & 1 deletion fms_fsdp/utils/dataloader_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from math import ceil

from fms_fsdp.utils.dataset_utils import (
ArrowHandler,
Expand Down Expand Up @@ -94,7 +95,7 @@ def get_data_loader(cfg, rank, world_size):
)
else:
filehandler = _handler_map[cfg.file_type](cols)

# Base reader layer
data = StreamingDocDataset(
cfg.data_path,
Expand All @@ -105,6 +106,7 @@ def get_data_loader(cfg, rank, world_size):
bos_token=cfg.bos_token,
strip_tokens=set(droplist),
min_length=3,
max_consecutive_chunks=ceil(cfg.doc_breakpoint/1024),
seed=cfg.seed,
)
# Add rescaling/resharding
Expand Down
Loading