diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index eadef50f..5fa56793 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -15,7 +15,9 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    datasets: str = (
+        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    )
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000
@@ -26,6 +28,7 @@ class train_config:
     strip_tokens: str = ""
     logical_shards: int = 1024
     num_workers: int = 1
+    doc_cutoff: int = 1_000_000
 
     # fsdp policies
     sharding_strategy: str = "hsdp"
@@ -74,7 +77,6 @@ class train_config:
     stage2_seq_length: int = 256
 
     # FIM training
-    fim_training: bool = False
     psm_rate: float = 0.0
     spm_rate: float = 0.0
     fim_pre: int = 1
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 5152e532..022ca4b5 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -72,7 +72,9 @@ def get_data_loader(cfg, rank, world_size):
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
     """
-    if cfg.fim_training:
+
+    fim_training = cfg.psm_rate + cfg.spm_rate > 0
+    if fim_training:
         assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
 
     datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)
@@ -87,8 +89,10 @@ def get_data_loader(cfg, rank, world_size):
         cfg.file_type in _handler_map
     ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
     if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
-        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols)
-    elif cfg.file_type == "arrow":
+        filehandler = _handler_map[cfg.file_type](
+            cfg.tokenizer_path, cols, cfg.doc_cutoff
+        )
+    else:
         filehandler = _handler_map[cfg.file_type](cols)
 
     # Base reader layer
@@ -131,7 +135,7 @@ def get_data_loader(cfg, rank, world_size):
     data = PreloadBufferDataset(data, 10000)
 
     # Apply FIM transformation if needed
-    if cfg.fim_training:
+    if fim_training:
         data = FIMDataset(
             data,
             cfg.eos_token,
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index e713a87c..ece5b343 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -343,9 +343,9 @@ class ArrowHandler(_ShardFileHandler):
     Non-standard data format, though.
""" - def __init__(self, col_names: List[str] = ["tokens"]): + def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]): self.col_names = col_names - + def is_legal(self, filepath: str): return "arrow" in os.path.splitext(filepath)[1] @@ -356,14 +356,18 @@ def length(self, path: str): return self.open(path).num_record_batches def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): + assert ( + index < reader.num_record_batches + ), f"Illegal index {index} in set of {reader.num_record_batches} documents" frame = reader.get_batch(index) - doc = None for name in self.col_names: if name in frame.column_names: doc = frame[name] break - assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}" + assert ( + doc is not None + ), f"None of column names {self.col_names} found in file headers {frame.column_names}" if len(doc) > 0 and doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) # Recheck len for edge case where doc=[eos] @@ -378,14 +382,20 @@ def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List: class ParquetHandler(_ShardFileHandler): """ Reader for indexable parquet shard files, common in HF datasets. - Here we assume reasonably small shard files (<5Gb) and documents (<100k tokens), + Here we assume reasonably small shard files (<5Gb) and truncate docs to max_doclen characters, as we rely on parquet/pandas for efficient file reading, and tokenize entire documents before getting/slicing. However, this is a standard and widely-used data format. """ - def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]): + def __init__( + self, + tokenizer_path: str, + col_names: List[str] = ["text", "contents", "tokens"], + max_doclen: int = 1_000_000, + ): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.col_names = col_names + self.max_doclen = max_doclen def is_legal(self, filepath: str): return "parquet" in os.path.splitext(filepath)[1] @@ -397,18 +407,19 @@ def open(self, path: str): if name in names: match = name break - assert match is not None, f"None of column names {self.col_names} found in file headers {names}" + assert ( + match is not None + ), f"None of column names {self.col_names} found in file headers {names}" return pq.read_pandas(path, columns=[match], partitioning=None)[match] def length(self, path: str): return pq.read_metadata(path).num_rows def get(self, reader, index: int, drop_tokens: Set): - - document_str = str(reader[index]) - - doc = self.tokenizer(str(reader[index]))["input_ids"] - + assert ( + index < reader.length() + ), f"Illegal index {index} in set of {reader.length()} documents" + doc = self.tokenizer(str(reader[index])[: self.max_doclen])["input_ids"] if len(doc) > 0 and doc[0] in drop_tokens: doc = doc[1:] # Recheck len for edge case where doc=[eos] @@ -421,8 +432,13 @@ def slice(self, doc: List, index: int, n_pull: int) -> List: class AutoHandler(_ShardFileHandler): - def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]): - self.PHandler = ParquetHandler(tokenizer_path, col_names) + def __init__( + self, + tokenizer_path: str, + col_names: List[str] = ["text", "contents", "tokens"], + max_doclen: int = 1_000_000, + ): + self.PHandler = ParquetHandler(tokenizer_path, col_names, max_doclen) self.AHandler = ArrowHandler(col_names) self.current = _ShardFileHandler() @@ -979,10 +995,10 @@ class StreamingDocDataset(_StatefulDataset): Documents below this length are skipped 
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
+    max_consecutive_chunks : int
+        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
-    shuffle : bool
-        Shuffle shard file and document orders? (Disable for simple testing)
     """
 
     def __init__(
@@ -1095,12 +1111,11 @@ def setup(self):
                 for row in reader:
                     fullpath = row["dataset/filename"]
                     prefix = fullpath.find(dataset + "/")
-                    if prefix > 0:
+                    if prefix >= 0:
                         key = fullpath[prefix + len(dataset) + 1 :]
                         sizes[key] = int(row["size"])
             shard_sizes = [sizes[shard] for shard in shards]
         else:
-            # Count file does not exist, touch every owned file for length
             shard_sizes = [
                 os.path.getsize(os.path.join(datapath, shard)) for shard in shards
             ]
@@ -1125,7 +1140,7 @@ def setup(self):
             reader = csv.DictReader(csvfile)
             for row in reader:
                 fullpath = row["dataset/filename"]
-                prefix = fullpath.find(dataset + "/")
+                prefix = fullpath.find(dataset)
                 if prefix >= 0:
                     key = fullpath[prefix + len(dataset) + 1 :]
                     doc_counts[key] = int(row["documents"])
@@ -1141,10 +1156,13 @@ def setup(self):
         doccount = 0
         for shard in shardset:
             ndocs = doc_counts[shard]
-            doc_start = int(ndocs * shardset[shard][0])
-            doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
-            self.docset.append([shard, doc_start, doc_end])
-            doccount += doc_end - doc_start + 1
+            if ndocs > 0:
+                doc_start = int(ndocs * shardset[shard][0])
+                doc_end = max(
+                    doc_start, int(ndocs * shardset[shard][1]) - 1
+                )  # inclusive upper bound
+                self.docset.append([shard, doc_start, doc_end])
+                doccount += doc_end - doc_start + 1
         self._len = doccount
 
         if self.verbose:
@@ -1253,10 +1271,8 @@ def __iter__(self):
                     doclcg = self._random_map_docid(docrange)
                     docid = doclcg + mindoc
                    doc = self.filehandler.get(reader, docid, self.drop)
-                    if len(doc) == 0:
-                        continue
                     doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                    if doclen >= self.min_length:
+                    if len(doc) > 0 and doclen >= self.min_length:
                         n_chunks = math.ceil(doclen / self.chunksize)
                         for j in range(n_chunks):
                             if i == 0 and j < residual_chunks:
@@ -1283,10 +1299,8 @@ def __iter__(self):
                 newpath = os.path.join(self.datapath, shardid)
                 path, reader = self._get_reader(path, newpath, reader)
                 doc = self.filehandler.get(reader, docid, self.drop)
-                if len(doc) == 0:
-                    continue
                 doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                if doclen >= self.min_length:
+                if len(doc) > 0 and doclen >= self.min_length:
                     n_chunks = math.ceil(doclen / self.chunksize)
                     for j in range(residual_chunks):
                         self.chunk_index = j
@@ -1294,7 +1308,9 @@
                         yield self._construct_chunk(j, doc, n_chunks)
 
             # Check that epoch was non-empty
-            assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
+            assert (
+                self.has_yielded
+            ), f"Empty logical shard detected: {self.dataset, self.docset}"
 
     def load_state_dict(self, state_dicts, sharded_input=False):
         self.setup()
@@ -1367,12 +1383,12 @@ def setup(self):
         if not self.is_setup:
             _StatefulDataset.setup(self)
             n_logical_shards = self.total_shards
+            assert (
+                n_logical_shards % self.worldsize == 0
+            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
             logicals = list(range(n_logical_shards))
             self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
             self.n_logicals = n_logical_shards // self.worldsize
-            assert (
-                len(self.logicals_owned) == self.n_logicals
-            ), "(world size * num workers) does not divide logical shards evenly"
 
             # Build logical shards
             for i in range(self.n_logicals):
@@ -1403,6 +1419,7 @@ def __iter__(self):
         # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
         if sum(self.n_docs_remaining) == 0:
             self.n_docs_remaining = [d._len for d in self.data]
+            self.generator.manual_seed(self.rank)
         while True:
             # Sample logical shard (or load from ckp)
             if self.current_reader is not None:
@@ -1499,6 +1516,10 @@ def __init__(
             ]
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
+        for d in datasets:
+            assert os.path.exists(
+                os.path.join(datapath, d)
+            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"
        if weights is not None:
            assert len(weights) == len(
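
Review note, not part of the patch: the two behavioral changes threaded through the diff are (1) FIM training is now inferred from nonzero psm_rate/spm_rate instead of the removed fim_training flag, and (2) raw documents are clipped to doc_cutoff characters before tokenization. Below is a minimal sketch of that logic for illustration only; TrainCfg, is_fim_training, and truncate_doc are hypothetical stand-ins rather than fms_fsdp APIs, and only the psm_rate + spm_rate > 0 check and the [: doc_cutoff] slice mirror the diff.

from dataclasses import dataclass


@dataclass
class TrainCfg:
    # Hypothetical stand-in for the train_config fields exercised here.
    psm_rate: float = 0.0
    spm_rate: float = 0.0
    doc_cutoff: int = 1_000_000  # max characters passed to the tokenizer per document


def is_fim_training(cfg: TrainCfg) -> bool:
    # FIM is implied by nonzero PSM/SPM rates, replacing the old boolean flag.
    return cfg.psm_rate + cfg.spm_rate > 0


def truncate_doc(raw_text: str, cfg: TrainCfg) -> str:
    # Clip raw text to doc_cutoff characters *before* tokenization,
    # rather than assuming documents stay under ~100k tokens.
    return raw_text[: cfg.doc_cutoff]


cfg = TrainCfg(psm_rate=0.5, spm_rate=0.25)
assert is_fim_training(cfg)
assert len(truncate_doc("x" * 2_000_000, cfg)) == cfg.doc_cutoff

With psm_rate and spm_rate left at their 0.0 defaults, is_fim_training returns False, matching the behavior of the previous fim_training=False default.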