7 changes: 7 additions & 0 deletions fms_fsdp/config/training.py
@@ -72,3 +72,10 @@ class train_config:
     stage2_prompt_length: int = 64
     stage2_batch_size: int = 96
     stage2_seq_length: int = 256
+
+    # FIM training
+    psm_rate: float = 0.0
+    spm_rate: float = 0.0
+    fim_pre: int = 1
+    fim_mid: int = 2
+    fim_suf: int = 3
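
The new knobs above control fill-in-the-middle (FIM) training: `psm_rate` and `spm_rate` are the probabilities of rearranging a document into prefix-suffix-middle (PSM) or suffix-prefix-middle (SPM) order, and `fim_pre`/`fim_mid`/`fim_suf` are the sentinel token ids spliced in around the pieces. A minimal sketch of one plausible rearrangement, assuming the usual PSM/SPM conventions — `fim_permute` is a hypothetical helper for illustration, not the `FIMDataset` this PR actually wires in:

import random

def fim_permute(tokens, psm_rate=0.5, spm_rate=0.25,
                fim_pre=1, fim_mid=2, fim_suf=3, rng=random):
    # Leave short documents, and a (1 - psm_rate - spm_rate) fraction of the
    # rest, in plain left-to-right order.
    draw = rng.random()
    if len(tokens) < 3 or draw >= psm_rate + spm_rate:
        return list(tokens)
    # Two random cut points carve the document into prefix / middle / suffix.
    i, j = sorted(rng.sample(range(1, len(tokens)), 2))
    pre, mid, suf = tokens[:i], tokens[i:j], tokens[j:]
    if draw < psm_rate:
        # PSM layout: [PRE] prefix [SUF] suffix [MID] middle
        return [fim_pre] + pre + [fim_suf] + suf + [fim_mid] + mid
    # SPM layout (one common convention): [PRE] [SUF] suffix [MID] prefix middle
    return [fim_pre, fim_suf] + suf + [fim_mid] + pre + mid

With `psm_rate + spm_rate < 1`, the remaining fraction of documents stays in ordinary left-to-right order, so the model retains plain causal LM behavior. The defaults of 0.0 keep FIM disabled unless explicitly configured.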
45 changes: 31 additions & 14 deletions fms_fsdp/utils/dataloader_utils.py
@@ -5,6 +5,7 @@
     AutoHandler,
     BufferDataset,
     CheckpointDataset,
+    FIMDataset,
     ParquetHandler,
     PreloadBufferDataset,
     PreprocessDataset,
@@ -57,9 +58,9 @@ def __iter__(self):
     return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)


-def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
+def get_data_loader(cfg, rank, world_size):
     """
-    Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
+    Pytorch dataloader for stateful, distributed, and rescalable language model training.
     Assumes underlying data is sequences of integer values.
     ...
     Args
@@ -70,12 +71,13 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         Rank of current distributed worker. Used for handling dataset sharding logic.
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
-    postprocess : List[Callable]
-        Any task-specific postprocessing to apply before handing over data. Steps will apply in
-        the order provided by the user. For CLM training, use postprocess=[causal_lm].
     """

-    datasets, weights = parse_data_args(cfg.datasets, cfg.weights)
+    fim_training = cfg.psm_rate + cfg.spm_rate > 0
+    if fim_training:
+        assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
+
+    datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)

     # Base streaming dataset. Returns doc chunks in sequence.
     # Implements dataset sampling and rescalability.
@@ -87,9 +89,9 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         cfg.file_type in _handler_map
     ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
     if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
-        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name)
+        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols)
     else:
-        filehandler = _handler_map[cfg.file_type]
+        filehandler = _handler_map[cfg.file_type](cols)
     # Base reader layer
     data = StreamingDocDataset(
         cfg.data_path,
@@ -118,20 +120,34 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         verbose=(rank == 0),
     )
     # Wrap above dataset in packing logic to form constant-length lines.
+    # Increment seq len to counteract CLM's one token removal.
     data = BufferDataset(
         data,
-        cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
+        cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
     )
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
     data = PreloadBufferDataset(data, 10000)

-    # Apply desired postprocessing steps in sequence
+    # Apply FIM transformation if needed
+    if fim_training:
+        data = FIMDataset(
+            data,
+            cfg.eos_token,
+            cfg.psm_rate,
+            cfg.spm_rate,
+            pre_token=cfg.fim_pre,
+            mid_token=cfg.fim_mid,
+            suf_token=cfg.fim_suf,
+        )
+
+    # Transform to tensors
     data = PreprocessDataset(data, torch.IntTensor)
-    for p in postprocess:
-        data = PreprocessDataset(data, p)
+
+    # Apply CLM transformation
+    data = PreprocessDataset(data, causal_lm)

     # Enable auto-saving
     data = CheckpointDataset(
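
The `cfg.seq_length + 1` above is the packing-side half of a contract with the CLM step later in the pipeline: the causal LM transform consumes one token when it shifts inputs against labels, so packing lines one token longer yields training pairs of exactly `cfg.seq_length`. A one-line sketch of that shift, assuming the conventional input/label split (`causal_lm_sketch` is a hypothetical stand-in, not this repo's `causal_lm`):

def causal_lm_sketch(packed):
    # One packed line of cfg.seq_length + 1 tokens becomes an (input, label)
    # pair of cfg.seq_length tokens each, shifted by one position.
    return packed[:-1], packed[1:]

# e.g. with cfg.seq_length = 4, a packed line of 5 tokens:
# causal_lm_sketch([10, 11, 12, 13, 14]) -> ([10, 11, 12, 13], [11, 12, 13, 14])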
@@ -146,7 +162,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
     )


-def parse_data_args(datas, weights):
+def parse_data_args(datas, weights, cols):
     # Convert csv inputs into corresponding lists of values
     def splitstrip(x):
         if isinstance(x, str):
@@ -160,4 +176,5 @@ def splitstrip(x):

     datas = splitstrip(datas)
     weights = [float(x) for x in splitstrip(weights)]
-    return datas, weights
+    cols = splitstrip(cols)
+    return datas, weights, cols
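
`parse_data_args` now normalizes a third csv field alongside the dataset names and sampling weights, so each configured dataset can name its own text column. A hypothetical usage sketch (the dataset names, weights, and column names are made up for illustration):

# Each csv string is split and stripped into a parallel list.
datasets, weights, cols = parse_data_args(
    "dataset_a,dataset_b", "0.7,0.3", "text,contents"
)
# -> ["dataset_a", "dataset_b"], [0.7, 0.3], ["text", "contents"]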