diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 1d072958..22b9a840 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -15,7 +15,9 @@ class train_config: file_type: str = "arrow" col_name: str = "tokens" tokenizer_path: str = "/fsx/tokenizer" - datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange" + datasets: str = ( + "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange" + ) weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100" seq_length: int = 4096 vocab_size: int = 32000 @@ -26,6 +28,8 @@ class train_config: strip_tokens: str = "" logical_shards: int = 1024 num_workers: int = 1 + doc_cutoff: int = 1_000_000 + doc_breakpoint: int = 65_536 # fsdp policies sharding_strategy: str = "hsdp" @@ -72,3 +76,10 @@ class train_config: stage2_prompt_length: int = 64 stage2_batch_size: int = 96 stage2_seq_length: int = 256 + + # FIM training + psm_rate: float = 0.0 + spm_rate: float = 0.0 + fim_pre: int = 1 + fim_mid: int = 2 + fim_suf: int = 3 diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 4b811d6d..6078ae7c 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -1,10 +1,12 @@ import torch +from math import ceil from fms_fsdp.utils.dataset_utils import ( ArrowHandler, AutoHandler, BufferDataset, CheckpointDataset, + FIMDataset, ParquetHandler, PreloadBufferDataset, PreprocessDataset, @@ -57,9 +59,9 @@ def __iter__(self): return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size) -def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): +def get_data_loader(cfg, rank, world_size): """ - Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training. + Pytorch dataloader for stateful, distributed, and rescalable language model training. Assumes underlying data is sequences of integer values. ... Args @@ -70,12 +72,13 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): Rank of current distributed worker. Used for handling dataset sharding logic. world_size : int Number of distributed workers. Used for handling dataset sharding logic. - postprocess : List[Callable] - Any task-specific postprocessing to apply before handing over data. Steps will apply in - the order provided by the user. For CLM training, use postprocess=[causal_lm]. """ - datasets, weights = parse_data_args(cfg.datasets, cfg.weights) + fim_training = cfg.psm_rate + cfg.spm_rate > 0 + if fim_training: + assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?" + + datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name) # Base streaming dataset. Returns doc chunks in sequence. # Implements dataset sampling and rescalability. 
@@ -87,9 +90,11 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): cfg.file_type in _handler_map ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})" if cfg.file_type == "hf_parquet" or cfg.file_type == "auto": - filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name) + filehandler = _handler_map[cfg.file_type]( + cfg.tokenizer_path, cols, cfg.doc_cutoff + ) else: - filehandler = _handler_map[cfg.file_type] + filehandler = _handler_map[cfg.file_type](cols) # Base reader layer data = StreamingDocDataset( cfg.data_path, @@ -100,6 +105,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): bos_token=cfg.bos_token, strip_tokens=set(droplist), min_length=3, + max_consecutive_chunks=ceil(cfg.doc_breakpoint/1024), seed=cfg.seed, ) # Add rescaling/resharding @@ -118,9 +124,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): verbose=(rank == 0), ) # Wrap above dataset in packing logic to form constant-length lines. + # Increment seq len to counteract CLM's one token removal. data = BufferDataset( data, - cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1, + cfg.seq_length + 1, bos_token=cfg.bol_token, eos_token=cfg.eol_token, pack_hard=True, @@ -128,10 +135,23 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average. data = PreloadBufferDataset(data, 10000) - # Apply desired postprocessing steps in sequence + # Apply FIM transformation if needed + if fim_training: + data = FIMDataset( + data, + cfg.eos_token, + cfg.psm_rate, + cfg.spm_rate, + pre_token=cfg.fim_pre, + mid_token=cfg.fim_mid, + suf_token=cfg.fim_suf, + ) + + # Transform to tensors data = PreprocessDataset(data, torch.IntTensor) - for p in postprocess: - data = PreprocessDataset(data, p) + + # Apply CLM transformation + data = PreprocessDataset(data, causal_lm) # Enable auto-saving data = CheckpointDataset( @@ -146,7 +166,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): ) -def parse_data_args(datas, weights): +def parse_data_args(datas, weights, cols): # Convert csv inputs into corresponding lists of values def splitstrip(x): if isinstance(x, str): @@ -160,4 +180,5 @@ def splitstrip(x): datas = splitstrip(datas) weights = [float(x) for x in splitstrip(weights)] - return datas, weights + cols = splitstrip(cols) + return datas, weights, cols diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index aedc5862..80c55c0a 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -343,8 +343,8 @@ class ArrowHandler(_ShardFileHandler): Non-standard data format, though. 
""" - def __init__(self, col_name: str = "tokens"): - self.col_name = col_name + def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]): + self.col_names = col_names def is_legal(self, filepath: str): return "arrow" in os.path.splitext(filepath)[1] @@ -356,7 +356,18 @@ def length(self, path: str): return self.open(path).num_record_batches def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): - doc = reader.get_batch(index)[self.col_name] + assert ( + index < reader.num_record_batches + ), f"Illegal index {index} in set of {reader.num_record_batches} documents" + frame = reader.get_batch(index) + doc = None + for name in self.col_names: + if name in frame.column_names: + doc = frame[name] + break + assert ( + doc is not None + ), f"None of column names {self.col_names} found in file headers {frame.column_names}" if len(doc) > 0 and doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) # Recheck len for edge case where doc=[eos] @@ -371,28 +382,44 @@ def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List: class ParquetHandler(_ShardFileHandler): """ Reader for indexable parquet shard files, common in HF datasets. - Here we assume reasonably small shard files (<5Gb) and documents (<100k tokens), + Here we assume reasonably small shard files (<5Gb) and truncate docs to max_doclen characters, as we rely on parquet/pandas for efficient file reading, and tokenize entire documents before getting/slicing. However, this is a standard and widely-used data format. """ - def __init__(self, tokenizer_path: str, col_name: str = "text"): + def __init__( + self, + tokenizer_path: str, + col_names: List[str] = ["text", "contents", "tokens"], + max_doclen: int = 1_000_000, + ): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - self.col_name = col_name + self.col_names = col_names + self.max_doclen = max_doclen def is_legal(self, filepath: str): return "parquet" in os.path.splitext(filepath)[1] def open(self, path: str): - return pq.read_pandas(path, columns=[self.col_name], partitioning=None)[ - self.col_name - ] + names = pq.read_metadata(path).schema.names + match = None + for name in self.col_names: + if name in names: + match = name + break + assert ( + match is not None + ), f"None of column names {self.col_names} found in file headers {names}" + return pq.read_pandas(path, columns=[match], partitioning=None)[match] def length(self, path: str): return pq.read_metadata(path).num_rows def get(self, reader, index: int, drop_tokens: Set): - doc = self.tokenizer(str(reader[index]))["input_ids"] + assert ( + index < reader.length() + ), f"Illegal index {index} in set of {reader.length()} documents" + doc = self.tokenizer(str(reader[index])[: self.max_doclen])["input_ids"] if len(doc) > 0 and doc[0] in drop_tokens: doc = doc[1:] # Recheck len for edge case where doc=[eos] @@ -405,9 +432,14 @@ def slice(self, doc: List, index: int, n_pull: int) -> List: class AutoHandler(_ShardFileHandler): - def __init__(self, tokenizer_path: str, col_name: str = "text"): - self.PHandler = ParquetHandler(tokenizer_path, col_name) - self.AHandler = ArrowHandler() + def __init__( + self, + tokenizer_path: str, + col_names: List[str] = ["text", "contents", "tokens"], + max_doclen: int = 1_000_000, + ): + self.PHandler = ParquetHandler(tokenizer_path, col_names, max_doclen) + self.AHandler = ArrowHandler(col_names) self.current = _ShardFileHandler() def is_legal(self, filepath: str): @@ -696,6 +728,128 @@ def load_state_dict(self, state_dicts, 
sharded_input=False): return sharded_dicts +class FIMDataset(_WrapperDataset): + """ + Wrapper for a StatefulDataset that implements Fill-In-the-Middle training + (https://arxiv.org/pdf/2207.14255). + Input should be a packed sequence (i.e. call BufferDataset before FIMDataset). + Breaks sequence apart into component document spans, and for each document span + of sufficient length, transforms with specified probability into: + PSM mode:
(prefix)(suffix) (middle) + SPM mode: (suffix) (prefix) (middle) + The new delimiter tokens can be omitted by passing in None. + Any extra tokens after transformation are dropped from the end of the sequence. + ... + Args + ---- + dataset : _StatefulDataset + Fully instantiated dataset + delimiter_token : any + Token used to indicate document boundaries + psm_rate : float + Chance to transform into PSM. Cannot exceed 1. + spm_rate : float + Chance to transform into SPM. Cannot exceed 1. + min_len : int + Minimum document length to perform FIM transformation + pre_token : any | none + Token used to indicate prefix section of the document + mid_token : any | none + Token used to indicate middle infill section of the document + suf_token : any | none + Token used to indicate suffix section of the document + """ + + def __init__( + self, + dataset: _StatefulDataset, + delimiter_token: Any, + psm_rate: float = 0.0, + spm_rate: float = 0.0, + min_len: int = 10, + pre_token=None, + mid_token=None, + suf_token=None, + ): + super().__init__(dataset) + assert ( + psm_rate + spm_rate > 0 + ), f"FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate." + assert ( + psm_rate + spm_rate <= 1 + ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1." + self.psm = psm_rate + self.spm = spm_rate + self.delimiter = delimiter_token + self.min_len = min_len + self.pref = pre_token + self.suff = suf_token + self.midd = mid_token + + self.g_state = None + self.generator = torch.Generator().manual_seed(self.rank) + self.state_params = ["g_state"] + + def __iter__(self): + dataset = iter(self.dataset) + while True: + inp = next(dataset) + len_ = len(inp) + i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_] + docs = [ + inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1) + ] # list[list[any]] + out = [] + for i in range(len(docs)): + doc = docs[i] + if len(docs[i]) >= self.min_len: + # decide psm, spm, or nothing + thresh = torch.rand([1], generator=self.generator).item() + if thresh < self.psm + self.spm: + # Split doc + doc = [] + if self.pref: + doc = [self.pref] + splits = torch.randint( + 0, len(docs[i]), [2], generator=self.generator + ).tolist() + pre = docs[i][: min(splits)] + mid = docs[i][min(splits) : max(splits)] + suf = docs[i][max(splits) :] + + if thresh < self.psm: + # PSM transformation + doc += pre + if self.suff: + doc.append(self.suff) + doc += suf + if self.midd: + doc.append(self.midd) + doc += mid + else: + # SPM transformation + if self.suff: + doc.append(self.suff) + doc += suf + if self.midd: + doc.append(self.midd) + doc += pre + mid + out += doc + [self.delimiter] + yield out[:len_] + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state() + return super().state_dict() + + def load_state_dict(self, state_dicts, sharded_input=False): + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + return sharded_dicts + + class BufferDataset(_WrapperDataset): """ Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them @@ -841,10 +995,10 @@ class StreamingDocDataset(_StatefulDataset): Documents below this length are skipped max_chunksize : int Maximum sequence length to return. Break long docs into chunks of this size or shorter. 
+ max_consecutive_chunks : int + Number of doc chunks to emit before manually inserting EOS and resuming later. verbose : bool Track setup progress? - shuffle : bool - Shuffle shard file and document orders? (Disable for simple testing) """ def __init__( @@ -859,6 +1013,7 @@ def __init__( seed: int = 42, min_length: int = 1, max_chunksize: int = 1024, + max_consecutive_chunks: int = 64, verbose: bool = False, ): super().__init__(datapath, rank, worldsize) @@ -871,20 +1026,22 @@ def __init__( self.eos = delimiter_token self.bos = bos_token self.drop = strip_tokens + self.max_consec = max_consecutive_chunks self.verbose = verbose - self.docset: List[ - Any - ] = [] # map of doc indices to (shardid, min docid, max docid) + # Map of doc indices to (shardid, min docid, max docid) + self.docset: List[Any] = [] # Position self.docset_index = 0 self.chunk_index = -1 + self.has_yielded = False # Stats self.epochs_seen = -1 self.tokens_seen = 0 self.docs_seen = 0 self.percent_seen = 0 + self.consec = 0 self.state_params = [ "dataset", @@ -895,6 +1052,7 @@ def __init__( "docs_seen", "percent_seen", "lcg_state", + "consec", ] # Setup flags @@ -922,22 +1080,13 @@ def setup(self): # listdir, assemble shardfraglist (ind -> shard, frag) shards = [ os.path.join(root, name)[len(datapath) + 1 :] - for root, dirs, files in os.walk(datapath, topdown=False) + for root, dirs, files in os.walk(datapath, topdown=False, followlinks=True) for name in files if self.filehandler.is_legal(os.path.join(root, name)) ] shards.sort() # Ensure consistent sharding across machines - start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize - end_frag = ( - (self.rank + 1) * self.worldsize * len(shards) - ) // self.worldsize - shardfrags = [ - (shards[i // self.worldsize], i % self.worldsize) - for i in range(start_frag, end_frag) - ] - - # Assemble length of each owned shard file + # Find metadata file countfiles = [] if os.path.exists(os.path.join(pardir, "meta")): countfiles = [ @@ -945,55 +1094,78 @@ def setup(self): for x in os.listdir(os.path.join(pardir, "meta")) if "counts" in x and "csv" in x ] - doc_counts = {} if len(countfiles) > 0: # Count file exists, use it countpath = os.path.join(pardir, "meta", countfiles[0]) + else: + countpath = "" + + # Use shard file sizes to perform partitioning + # Create shardlist of form shardid -> [start%, end%] + if len(countfiles) > 0: + sizes = {} with open(countpath, "r") as csvfile: reader = csv.DictReader(csvfile) for row in reader: fullpath = row["dataset/filename"] - prefix = fullpath.find("/" + dataset) + 1 - if prefix > 0: + prefix = fullpath.find(dataset + "/") + if prefix >= 0: + key = fullpath[prefix + len(dataset) + 1 :] + sizes[key] = int(row["size"]) + shard_sizes = [sizes[shard] for shard in shards] + else: + shard_sizes = [ + os.path.getsize(os.path.join(datapath, shard)) for shard in shards + ] + shard_sizes = [s / sum(shard_sizes) for s in shard_sizes] + start = self.rank / self.worldsize + end = (self.rank + 1) / self.worldsize + shardset = {} + tally = 0 + for i in range(len(shards)): + if tally <= end and tally + shard_sizes[i] >= start: + shardset[shards[i]] = [ + min(max((start - tally) / shard_sizes[i], 0), 1), + min(max((end - tally) / shard_sizes[i], 0), 1), + ] + tally += shard_sizes[i] + + # Assemble length of each owned shard file + doc_counts = {} + if len(countfiles) > 0: + # Count file exists, use it + with open(countpath, "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + fullpath = row["dataset/filename"] + 
prefix = fullpath.find(dataset) + if prefix >= 0: key = fullpath[prefix + len(dataset) + 1 :] doc_counts[key] = int(row["documents"]) else: # Count file does not exist, touch every owned file for length - unique_shardfiles = set(shard for shard, frag in shardfrags) doc_counts = { shard: self.filehandler.length(os.path.join(datapath, shard)) - for shard in unique_shardfiles + for shard in shardset } - # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): - ndocs = -1 - docset = {} # shardid -> (min docid, max docid) - for i, (shard, frag) in enumerate(shardfrags): - ndocs = doc_counts[shard] - doc_start = (ndocs * frag) // self.worldsize - doc_end = ( - ndocs * frag + ndocs - ) // self.worldsize - 1 # Inclusive upper bound - if shard not in docset: - docset[shard] = [doc_start, doc_end] - min_d, max_d = docset[shard] - if doc_start < min_d: - docset[shard][0] = doc_start - if doc_end > max_d: - docset[shard][1] = doc_end - - # Add shard entries to self.docset + # Assemble doc list for each file shard + # Create docset of form [shardid, min docid, max docid] doccount = 0 - for shardid in docset: - min_d = docset[shardid][0] - max_d = docset[shardid][1] - self.docset.append((shardid, min_d, max_d)) - doccount += max_d - min_d + 1 + for shard in shardset: + ndocs = doc_counts[shard] + if ndocs > 0: + doc_start = int(ndocs * shardset[shard][0]) + doc_end = max( + doc_start, int(ndocs * shardset[shard][1]) - 1 + ) # inclusive upper bound + self.docset.append([shard, doc_start, doc_end]) + doccount += doc_end - doc_start + 1 self._len = doccount if self.verbose: logging.info( - f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}" + f" Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}" ) # Shuffle shard files - guaranteed inconsistent across workers @@ -1048,8 +1220,11 @@ def _construct_chunk(self, j, doc, n_chunks): # Add bos/eos tokens if needed if self.bos is not None and j == 0: chunk = [self.bos] + chunk - if j == n_chunks - 1: + if j == n_chunks - 1 or self.consec == self.max_consec: chunk = chunk + [self.eos] + self.consec = 0 + else: + self.consec += 1 return chunk def _random_map_docid(self, size): @@ -1094,10 +1269,8 @@ def __iter__(self): doclcg = self._random_map_docid(docrange) docid = doclcg + mindoc doc = self.filehandler.get(reader, docid, self.drop) - if len(doc) == 0: - continue doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: + if len(doc) > 0 and doclen >= self.min_length: n_chunks = math.ceil(doclen / self.chunksize) for j in range(n_chunks): if i == 0 and j < residual_chunks: @@ -1110,6 +1283,7 @@ def __iter__(self): self.percent_seen = ( self.docs_seen * 100 / (self._len + 1e-9) ) + self.has_yielded = True yield self._construct_chunk(j, doc, n_chunks) # Advance RNG state @@ -1123,15 +1297,19 @@ def __iter__(self): newpath = os.path.join(self.datapath, shardid) path, reader = self._get_reader(path, newpath, reader) doc = self.filehandler.get(reader, docid, self.drop) - if len(doc) == 0: - continue doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: + if len(doc) > 0 and doclen >= self.min_length: n_chunks = math.ceil(doclen / self.chunksize) for j in range(residual_chunks): self.chunk_index = j + self.has_yielded = True yield self._construct_chunk(j, doc, n_chunks) + # Check that epoch was non-empty + assert ( + self.has_yielded + ), f"Empty logical shard detected: {self.dataset, self.docset}" + def 
load_state_dict(self, state_dicts, sharded_input=False): self.setup() assert ( @@ -1203,12 +1381,12 @@ def setup(self): if not self.is_setup: _StatefulDataset.setup(self) n_logical_shards = self.total_shards + assert ( + n_logical_shards % self.worldsize == 0 + ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" logicals = list(range(n_logical_shards)) self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) self.n_logicals = n_logical_shards // self.worldsize - assert ( - len(self.logicals_owned) == self.n_logicals - ), "(world size * num workers) does not divide logical shards evenly" # Build logical shards for i in range(self.n_logicals): @@ -1225,6 +1403,9 @@ def setup(self): ) [d.setup() for d in self.data] self.n_docs_remaining = [d._len for d in self.data] + assert ( + sum(self.n_docs_remaining) > 0 + ), f"No documents detected in shard {self.rank} of {self.datapath}" self.generator = torch.Generator().manual_seed(self.rank) @@ -1232,14 +1413,16 @@ def __iter__(self): self.setup() # Grab one doc at a time in random order data = [iter(d) for d in self.data] + # Reset if we're rescaling into a prematurely finished epoch + # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] ) + if sum(self.n_docs_remaining) == 0: + self.n_docs_remaining = [d._len for d in self.data] + self.generator.manual_seed(self.rank) while True: # Sample logical shard (or load from ckp) if self.current_reader is not None: ind = self.current_reader else: - assert ( - sum(self.n_docs_remaining) > 0 - ), f"No documents detected in {self.datapath}" ind = torch.multinomial( torch.tensor(self.n_docs_remaining, dtype=torch.float), 1, @@ -1331,6 +1514,10 @@ def __init__( ] ) assert len(self.datasets) > 0, "You must specify at least one dataset" + for d in datasets: + assert os.path.exists( + os.path.join(datapath, d) + ), f"Invalid subdataset path: {os.path.join(datapath, d)}" if weights is not None: assert len(weights) == len( diff --git a/main_training_llama.py b/main_training_llama.py index 67cccee2..a7e1020f 100644 --- a/main_training_llama.py +++ b/main_training_llama.py @@ -122,9 +122,11 @@ def main(**kwargs): model, optimizer, None, - path=os.path.join(cfg.ckpt_load_path, "checkpoints/") - if not os.path.isfile(cfg.ckpt_load_path) - else cfg.ckpt_load_path, + path=( + os.path.join(cfg.ckpt_load_path, "checkpoints/") + if not os.path.isfile(cfg.ckpt_load_path) + else cfg.ckpt_load_path + ), strict=False, ) if not is_resuming: diff --git a/main_training_mamba.py b/main_training_mamba.py index 3619ea25..68a3c830 100644 --- a/main_training_mamba.py +++ b/main_training_mamba.py @@ -119,9 +119,11 @@ def main(**kwargs): model, optimizer, None, - path=os.path.join(cfg.ckpt_load_path, "checkpoints/") - if not os.path.isfile(cfg.ckpt_load_path) - else cfg.ckpt_load_path, + path=( + os.path.join(cfg.ckpt_load_path, "checkpoints/") + if not os.path.isfile(cfg.ckpt_load_path) + else cfg.ckpt_load_path + ), strict=False, ) if not is_resuming: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 83b2426b..40bef481 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -632,6 +632,10 @@ def test_multi_reload_stress(): # preload / sample / scale / doc pipeline multi_reload_stress_check(lambda: d6(d5(d4()))) + # Add FIM dataset + d7 = lambda x: [FIMDataset(d, -1, 0.25, 0.25, 10, -2, -3, -4) for d in x] + multi_reload_stress_check(lambda: d7(d6(d5(d4())))) + # SCALABLEDATASET TESTS
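Reviewer notes: the sketches below are illustrations only, not part of the patch.

FIMDataset splits each packed line back into its component documents at the delimiter token and, for each document of at least min_len tokens, applies a PSM or SPM rearrangement with probability psm_rate / spm_rate. A minimal standalone sketch of the per-document step, mirroring the branch logic in FIMDataset.__iter__; the sentinel ids -2/-3/-4 are arbitrary placeholders:

import torch

def fim_transform(doc, psm_rate, spm_rate, generator, pre=-2, mid=-3, suf=-4, min_len=10):
    # PSM: [pre] prefix [suf] suffix [mid] middle
    # SPM: [pre] [suf] suffix [mid] prefix middle
    if len(doc) < min_len:
        return doc
    thresh = torch.rand([1], generator=generator).item()
    if thresh >= psm_rate + spm_rate:
        return doc  # leave this document as ordinary causal text
    # Two random cut points split the doc into prefix / middle / suffix spans
    splits = torch.randint(0, len(doc), [2], generator=generator).tolist()
    prefix = doc[: min(splits)]
    middle = doc[min(splits) : max(splits)]
    suffix = doc[max(splits) :]
    if thresh < psm_rate:
        return [pre] + prefix + [suf] + suffix + [mid] + middle
    return [pre, suf] + suffix + [mid] + prefix + middle

g = torch.Generator().manual_seed(0)
print(fim_transform(list(range(20)), psm_rate=0.5, spm_rate=0.5, generator=g))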
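As far as the diff shows, FIM training is switched on simply by making psm_rate + spm_rate nonzero; bos_token must stay None, and get_data_loader no longer accepts a postprocess argument because causal_lm is applied internally. A hedged usage sketch, assuming train_config is the dataclass defined in fms_fsdp/config/training.py and that data_path points at real shard files; the path, dataset name, and sentinel ids below are placeholders:

from math import ceil

from fms_fsdp.config.training import train_config
from fms_fsdp.utils.dataloader_utils import get_data_loader

cfg = train_config(
    data_path="/path/to/data",        # placeholder, must contain the shard files
    datasets="lang=en/dataset=github_clean",
    weights="1",
    psm_rate=0.25,                    # 25% of eligible docs -> PSM
    spm_rate=0.25,                    # 25% -> SPM, remaining 50% stay causal
    fim_pre=1, fim_mid=2, fim_suf=3,  # sentinel ids reserved in the tokenizer vocab
    doc_breakpoint=65_536,            # force an EOS roughly every 64k tokens of one doc
)

# bos_token must stay None: the loader asserts this when psm_rate + spm_rate > 0.
loader = get_data_loader(cfg, rank=0, world_size=1)

# doc_breakpoint is translated into StreamingDocDataset.max_consecutive_chunks:
assert ceil(cfg.doc_breakpoint / 1024) == 64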
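The main behavioural change in StreamingDocDataset.setup() is the move from per-shard fragment assignment to size-weighted partitioning: each worker owns the interval [rank/worldsize, (rank+1)/worldsize) of the cumulative shard-size line (sizes taken from the meta counts file, or from file sizes on disk), expressed per shard as fractional start/end bounds. A standalone sketch of that interval arithmetic, with made-up shard sizes:

def partition_shards(shard_sizes, rank, worldsize):
    # Normalize sizes, then intersect each shard's span of the [0, 1) size line
    # with this worker's interval, yielding per-shard fractional (start%, end%).
    total = sum(shard_sizes)
    fracs = [s / total for s in shard_sizes]
    start = rank / worldsize
    end = (rank + 1) / worldsize
    owned = {}
    tally = 0.0
    for i, frac in enumerate(fracs):
        if tally <= end and tally + frac >= start:
            owned[i] = (
                min(max((start - tally) / frac, 0.0), 1.0),
                min(max((end - tally) / frac, 0.0), 1.0),
            )
        tally += frac
    return owned

# Two workers over three shards of unequal size: the large middle shard is split
# fractionally, so both workers see a comparable share of the data.
print(partition_shards([1, 2, 1], rank=0, worldsize=2))  # {0: (0.0, 1.0), 1: (0.0, 0.5)}
print(partition_shards([1, 2, 1], rank=1, worldsize=2))  # {1: (0.5, 1.0), 2: (0.0, 1.0)}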
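The new max_consecutive_chunks logic in _construct_chunk inserts a forced EOS (and resets the counter) every max_consec chunks, so very long documents are broken up roughly every doc_breakpoint tokens rather than only at their true end. A simplified sketch that ignores the persistent self.consec state and BOS handling:

def chunk_doc(doc, chunksize, eos, max_consec):
    # Emit chunks of up to chunksize tokens; append EOS at the true end of the
    # doc or whenever max_consec chunks have gone by without one.
    n_chunks = -(-len(doc) // chunksize)  # ceil division
    consec = 0
    out = []
    for j in range(n_chunks):
        chunk = doc[j * chunksize : (j + 1) * chunksize]
        if j == n_chunks - 1 or consec == max_consec:
            chunk = chunk + [eos]
            consec = 0
        else:
            consec += 1
        out.append(chunk)
    return out

chunks = chunk_doc(list(range(10)), chunksize=2, eos=-1, max_consec=2)
print(chunks)  # [[0, 1], [2, 3], [4, 5, -1], [6, 7], [8, 9, -1]]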