6 changes: 4 additions & 2 deletions fms_fsdp/config/training.py
@@ -15,7 +15,9 @@ class train_config:
file_type: str = "arrow"
col_name: str = "tokens"
tokenizer_path: str = "/fsx/tokenizer"
datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
datasets: str = (
"lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
)
weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
seq_length: int = 4096
vocab_size: int = 32000
@@ -26,6 +28,7 @@ class train_config:
strip_tokens: str = ""
logical_shards: int = 1024
num_workers: int = 1
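# Character cap applied to raw documents before tokenization; passed to the
# parquet/auto file handlers as max_doclen (see dataset_utils.py below).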
doc_cutoff: int = 1_000_000

# fsdp policies
sharding_strategy: str = "hsdp"
@@ -74,7 +77,6 @@ class train_config:
stage2_seq_length: int = 256

# FIM training
fim_training: bool = False
psm_rate: float = 0.0
spm_rate: float = 0.0
fim_pre: int = 1
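# Note: with the fim_training flag removed, FIM is treated as active whenever
# psm_rate + spm_rate > 0 (see get_data_loader in dataloader_utils.py).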
12 changes: 8 additions & 4 deletions fms_fsdp/utils/dataloader_utils.py
@@ -72,7 +72,9 @@ def get_data_loader(cfg, rank, world_size):
world_size : int
Number of distributed workers. Used for handling dataset sharding logic.
"""
if cfg.fim_training:

fim_training = cfg.psm_rate + cfg.spm_rate > 0
if fim_training:
assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"

datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)
@@ -87,8 +89,10 @@ def get_data_loader(cfg, rank, world_size):
cfg.file_type in _handler_map
), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols)
elif cfg.file_type == "arrow":
filehandler = _handler_map[cfg.file_type](
cfg.tokenizer_path, cols, cfg.doc_cutoff
)
else:
filehandler = _handler_map[cfg.file_type](cols)

# Base reader layer
@@ -131,7 +135,7 @@ def get_data_loader(cfg, rank, world_size):
data = PreloadBufferDataset(data, 10000)

# Apply FIM transformation if needed
if cfg.fim_training:
if fim_training:
data = FIMDataset(
data,
cfg.eos_token,
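# Illustrative usage (not part of this diff): how cfg.doc_cutoff reaches a
# file handler. The direct construction below is an assumption for
# demonstration (the loader normally dispatches through _handler_map as
# above), and the values mirror the defaults in training.py.
from fms_fsdp.utils.dataset_utils import ParquetHandler

filehandler = ParquetHandler(
    tokenizer_path="/fsx/tokenizer",
    col_names=["text", "contents", "tokens"],
    max_doclen=1_000_000,  # i.e. cfg.doc_cutoff
)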
87 changes: 54 additions & 33 deletions fms_fsdp/utils/dataset_utils.py
@@ -343,9 +343,9 @@ class ArrowHandler(_ShardFileHandler):
Non-standard data format, though.
"""

def __init__(self, col_names: List[str] = ["tokens"]):
def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]):
self.col_names = col_names

def is_legal(self, filepath: str):
return "arrow" in os.path.splitext(filepath)[1]

@@ -356,14 +356,18 @@ def length(self, path: str):
return self.open(path).num_record_batches

def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
assert (
index < reader.num_record_batches
), f"Illegal index {index} in set of {reader.num_record_batches} documents"
frame = reader.get_batch(index)

doc = None
for name in self.col_names:
if name in frame.column_names:
doc = frame[name]
break
assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}"
assert (
doc is not None
), f"None of column names {self.col_names} found in file headers {frame.column_names}"
if len(doc) > 0 and doc[0].as_py() in drop_tokens:
doc = doc.slice(1, len(doc) - 1)
# Recheck len for edge case where doc=[eos]
@@ -378,14 +382,20 @@ def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List:
class ParquetHandler(_ShardFileHandler):
"""
Reader for indexable parquet shard files, common in HF datasets.
Here we assume reasonably small shard files (<5Gb) and documents (<100k tokens),
Here we assume reasonably small shard files (<5Gb) and truncate docs to max_doclen characters,
as we rely on parquet/pandas for efficient file reading, and tokenize entire documents
before getting/slicing. However, this is a standard and widely-used data format.
"""

def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]):
def __init__(
self,
tokenizer_path: str,
col_names: List[str] = ["text", "contents", "tokens"],
max_doclen: int = 1_000_000,
):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.col_names = col_names
self.max_doclen = max_doclen

def is_legal(self, filepath: str):
return "parquet" in os.path.splitext(filepath)[1]
@@ -397,18 +407,19 @@ def open(self, path: str):
if name in names:
match = name
break
assert match is not None, f"None of column names {self.col_names} found in file headers {names}"
assert (
match is not None
), f"None of column names {self.col_names} found in file headers {names}"
return pq.read_pandas(path, columns=[match], partitioning=None)[match]

def length(self, path: str):
return pq.read_metadata(path).num_rows

def get(self, reader, index: int, drop_tokens: Set):

document_str = str(reader[index])

doc = self.tokenizer(str(reader[index]))["input_ids"]

assert (
index < reader.length()
), f"Illegal index {index} in set of {reader.length()} documents"
doc = self.tokenizer(str(reader[index])[: self.max_doclen])["input_ids"]
if len(doc) > 0 and doc[0] in drop_tokens:
doc = doc[1:]
# Recheck len for edge case where doc=[eos]
@@ -421,8 +432,13 @@ def slice(self, doc: List, index: int, n_pull: int) -> List:


class AutoHandler(_ShardFileHandler):
def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
self.PHandler = ParquetHandler(tokenizer_path, col_names)
def __init__(
self,
tokenizer_path: str,
col_names: List[str] = ["text", "contents", "tokens"],
max_doclen: int = 1_000_000,
):
self.PHandler = ParquetHandler(tokenizer_path, col_names, max_doclen)
self.AHandler = ArrowHandler(col_names)
self.current = _ShardFileHandler()

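# Hypothetical helper (not in the diff) restating the column-fallback rule the
# handlers above now share: the first configured column name present in a
# file's schema is used, otherwise the lookup fails loudly.
def _pick_column(schema_names, col_names=("text", "contents", "tokens")):
    for name in col_names:
        if name in schema_names:
            return name
    raise KeyError(f"None of column names {col_names} found in file headers {schema_names}")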
@@ -979,10 +995,10 @@ class StreamingDocDataset(_StatefulDataset):
Documents below this length are skipped
max_chunksize : int
Maximum sequence length to return. Break long docs into chunks of this size or shorter.
max_consecutive_chunks : int
Number of doc chunks to emit before manually inserting EOS and resuming later.
verbose : bool
Track setup progress?
shuffle : bool
Shuffle shard file and document orders? (Disable for simple testing)
"""

def __init__(
@@ -1095,12 +1111,11 @@ def setup(self):
for row in reader:
fullpath = row["dataset/filename"]
prefix = fullpath.find(dataset + "/")
if prefix > 0:
if prefix >= 0:
key = fullpath[prefix + len(dataset) + 1 :]
sizes[key] = int(row["size"])
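# e.g. (hypothetical row) fullpath "lang=en/dataset=c4/shard_00.arrow" with
# dataset "lang=en/dataset=c4" gives prefix 0 (hence the >= 0 check) and
# key "shard_00.arrow".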
shard_sizes = [sizes[shard] for shard in shards]
else:
# Count file does not exist, touch every owned file for length
shard_sizes = [
os.path.getsize(os.path.join(datapath, shard)) for shard in shards
]
@@ -1125,7 +1140,7 @@ def setup(self):
reader = csv.DictReader(csvfile)
for row in reader:
fullpath = row["dataset/filename"]
prefix = fullpath.find(dataset + "/")
prefix = fullpath.find(dataset)
if prefix >= 0:
key = fullpath[prefix + len(dataset) + 1 :]
doc_counts[key] = int(row["documents"])
@@ -1141,10 +1156,13 @@ def setup(self):
doccount = 0
for shard in shardset:
ndocs = doc_counts[shard]
doc_start = int(ndocs * shardset[shard][0])
doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1) # inclusive upper bound
self.docset.append([shard, doc_start, doc_end])
doccount += doc_end - doc_start + 1
if ndocs > 0:
doc_start = int(ndocs * shardset[shard][0])
doc_end = max(
doc_start, int(ndocs * shardset[shard][1]) - 1
) # inclusive upper bound
self.docset.append([shard, doc_start, doc_end])
doccount += doc_end - doc_start + 1
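# Worked example (hypothetical fractions): a worker owning [0.25, 0.75) of a
# 100-document shard gets doc_start = int(100 * 0.25) = 25 and
# doc_end = max(25, int(100 * 0.75) - 1) = 74, i.e. 50 documents inclusive.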
self._len = doccount

if self.verbose:
@@ -1253,10 +1271,8 @@ def __iter__(self):
doclcg = self._random_map_docid(docrange)
docid = doclcg + mindoc
doc = self.filehandler.get(reader, docid, self.drop)
if len(doc) == 0:
continue
doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
if doclen >= self.min_length:
if len(doc) > 0 and doclen >= self.min_length:
n_chunks = math.ceil(doclen / self.chunksize)
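# e.g. a 10,000-token doc plus EOS (no BOS) with chunksize 4096 gives
# math.ceil(10001 / 4096) = 3 chunks.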
for j in range(n_chunks):
if i == 0 and j < residual_chunks:
@@ -1283,18 +1299,18 @@ def __iter__(self):
newpath = os.path.join(self.datapath, shardid)
path, reader = self._get_reader(path, newpath, reader)
doc = self.filehandler.get(reader, docid, self.drop)
if len(doc) == 0:
continue
doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
if doclen >= self.min_length:
if len(doc) > 0 and doclen >= self.min_length:
n_chunks = math.ceil(doclen / self.chunksize)
for j in range(residual_chunks):
self.chunk_index = j
self.has_yielded = True
yield self._construct_chunk(j, doc, n_chunks)

# Check that epoch was non-empty
assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
assert (
self.has_yielded
), f"Empty logical shard detected: {self.dataset, self.docset}"

def load_state_dict(self, state_dicts, sharded_input=False):
self.setup()
@@ -1367,12 +1383,12 @@ def setup(self):
if not self.is_setup:
_StatefulDataset.setup(self)
n_logical_shards = self.total_shards
assert (
n_logical_shards % self.worldsize == 0
), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
logicals = list(range(n_logical_shards))
self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
self.n_logicals = n_logical_shards // self.worldsize
assert (
len(self.logicals_owned) == self.n_logicals
), "(world size * num workers) does not divide logical shards evenly"

# Build logical shards
for i in range(self.n_logicals):
@@ -1403,6 +1419,7 @@ def __iter__(self):
# (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
if sum(self.n_docs_remaining) == 0:
self.n_docs_remaining = [d._len for d in self.data]
self.generator.manual_seed(self.rank)
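# Re-seeding with the rank presumably keeps per-rank shard sampling
# deterministic when a fresh epoch starts.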
while True:
# Sample logical shard (or load from ckp)
if self.current_reader is not None:
@@ -1499,6 +1516,10 @@ def __init__(
]
)
assert len(self.datasets) > 0, "You must specify at least one dataset"
for d in datasets:
assert os.path.exists(
os.path.join(datapath, d)
), f"Invalid subdataset path: {os.path.join(datapath, d)}"

if weights is not None:
assert len(weights) == len(