From ac2473e7dc9aa700559a87a568dc6bf755033ff2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 15 Apr 2025 13:36:19 -0400 Subject: [PATCH 1/3] Sampling parameters --- fast_llm/data/data/abstract.py | 7 +- fast_llm/data/data/gpt/data.py | 44 ++++++------- fast_llm/data/dataset/blended.py | 2 +- fast_llm/data/dataset/config.py | 37 +++++++++-- fast_llm/data/dataset/gpt/config.py | 34 +++++++--- fast_llm/data/dataset/gpt/fim.py | 2 +- fast_llm/data/dataset/gpt/random.py | 6 +- fast_llm/data/dataset/gpt/sampled.py | 95 ++++++++++++++++------------ fast_llm/engine/training/trainer.py | 10 ++- fast_llm/models/gpt/trainer.py | 17 ++++- tests/data/common.py | 31 ++++++--- tests/data/test_fim.py | 7 +- 12 files changed, 187 insertions(+), 105 deletions(-) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index 1addf518..e24d3998 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -4,6 +4,7 @@ from fast_llm.config import Configurable from fast_llm.data.data.config import DataConfig +from fast_llm.data.dataset.config import SamplingParameters from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig @@ -13,7 +14,7 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): _distributed: "Distributed" - _samples_per_dataset: dict[str, int] + _sampling_parameters: dict[str, SamplingParameters] _cache_directory: pathlib.Path | None def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None: @@ -24,12 +25,12 @@ def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> def setup( self, distributed: "Distributed", - samples_per_dataset: dict[str, int], + sampling_parameters: dict[str, SamplingParameters], cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: self._distributed = distributed - self._samples_per_dataset = samples_per_dataset + self._sampling_parameters = sampling_parameters self._cache_directory = cache_directory @property diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6e5a519a..02c1b6c0 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -13,7 +13,7 @@ from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator @@ -34,15 +34,13 @@ class GPTBatch: sequence_lengths: list[torch.Tensor] | None = None -def gpt_data_collate_fn( - batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool -) -> GPTBatch: +def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None - if use_loss_masking_spans: + if sampling_parameters.use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] - if not cross_document_attention: + if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths @@ -57,6 +55,7 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ _datasets: dict[str, SampledDataset] + _sampling_parameters: dict[str, GPTSamplingParameters] _tokenizer: Tokenizer | None _is_setup: bool = False @@ -64,23 +63,17 @@ def __init__( self, config: GPTDataConfig, distributed_config: DistributedConfig, - vocab_size: int, - max_sequence_length: int, - cross_document_attention: bool = True, ): """ Create the data and gather some basic information on the dataset(s). Should be `setup` before use. """ super().__init__(config, distributed_config) - self._vocab_size = vocab_size - self._max_sequence_length = max_sequence_length - self._cross_document_attention = cross_document_attention def setup( self, distributed: "Distributed", - samples_per_dataset: dict[str, int], + sampling_parameters: dict[str, GPTSamplingParameters], cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: @@ -88,20 +81,21 @@ def setup( Load the datasets, and prepare or load the samplings. This may take a while and a significant amount of cpu memory. """ + super().setup(distributed, sampling_parameters, cache_directory) + # Check and raise an error if a used dataset is not defined. - for dataset_name in samples_per_dataset.keys(): + for dataset_name in self._sampling_parameters.keys(): if dataset_name not in self._config.datasets: raise ValueError(f"Dataset {dataset_name} not found.") # Check and warn if there are defined datasets that are not used. - unused_datasets = self._config.datasets.keys() - samples_per_dataset.keys() + unused_datasets = self._config.datasets.keys() - self._sampling_parameters.keys() if unused_datasets: warnings.warn( f"The following datasets are defined but not used: {', '.join(unused_datasets)}. " "Ensure this is intentional, or update the configuration accordingly." ) - super().setup(distributed, samples_per_dataset, cache_directory) log_main_rank(f"Preparing dataset. This may take several minutes.") self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer) @@ -110,19 +104,19 @@ def setup( warnings.warn(f"Using the dataset directory for the index cache.") self._datasets = {} - for dataset_name, num_samples in samples_per_dataset.items(): - if num_samples > 0: + for dataset_name, sampling_parameters in self._sampling_parameters.items(): + if self._tokenizer is not None: + # TODO: Too constraining? + Assert.eq(self._tokenizer.vocab_size, sampling_parameters.vocab_size) + if sampling_parameters.num_samples > 0: sampling = GPTSamplingData( - num_samples=num_samples, config=self._config.sampling, + parameters=sampling_parameters, cache_directory=self._cache_directory, distributed=distributed, dataset_name=dataset_name, - sequence_length=self._max_sequence_length, - vocab_size=self._vocab_size, tokenizer=self._tokenizer, truncate_documents=self._config.truncate_documents, - cross_document_attention=self._cross_document_attention, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) @@ -152,7 +146,8 @@ def get_iterator( dataset_name = dataset_name.lower() Assert.incl(dataset_name, self._datasets) - Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length) + sampling_parameters = self._sampling_parameters[dataset_name] + Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") return iter( torch.utils.data.DataLoader( @@ -169,8 +164,7 @@ def get_iterator( pin_memory=True, collate_fn=partial( gpt_data_collate_fn, - use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, - cross_document_attention=self._cross_document_attention, + sampling_parameters=sampling_parameters, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 8468397e..24b0fa76 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -30,7 +30,7 @@ def __init__( Assert.eq(len(datasets), len(weights)) self._datasets = datasets self._weights = np.array(normalize_probabilities(weights)) - self._num_samples = sampling_config.num_samples + self._num_samples = sampling_config.parameters.num_samples def __len__(self) -> int: return self._num_samples diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 85401c2e..7901d6e7 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -16,6 +16,10 @@ @config_class() class SamplingConfig(Config): + """ + A dataset-dependent configuration for sampling. + """ + seed: int = Field( default=784569, desc="Seed for random sampling.", @@ -23,11 +27,25 @@ class SamplingConfig(Config): ) +@dataclasses.dataclass(kw_only=True) +class SamplingParameters: + """ + Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. + """ + + num_samples: int + + @dataclasses.dataclass(kw_only=True) class SamplingData: + """ + Holds all the necessary information for sampling, including dataset-dependent ones (`SamplingConfig`), + usage-dependent ones (`SamplingParameters`), and others set by the `Data`. + """ + # TODO: Have a separate configuration (subset?) for `build`? config: SamplingConfig - num_samples: int + parameters: SamplingParameters cache_directory: pathlib.Path | None # TODO: This prevents the sampling config from being pickled in multiprocessing. distributed: "Distributed" @@ -213,10 +231,19 @@ def build_and_sample( # Blending is deterministic and the error will never be higher than 1. dataclasses.replace( sampling, - num_samples=( - math.ceil(weight * (sampling.num_samples + 5 * (sampling.num_samples * (1 - weight)) ** 0.5)) - if self.legacy - else math.ceil(weight * sampling.num_samples) + 1 + parameters=dataclasses.replace( + sampling.parameters, + num_samples=( + math.ceil( + weight + * ( + sampling.parameters.num_samples + + 5 * (sampling.parameters.num_samples * (1 - weight)) ** 0.5 + ) + ) + if self.legacy + else math.ceil(weight * sampling.parameters.num_samples) + 1 + ), ), # TODO: Seed may not be unique for nested blended datasets. config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}), diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index d8dfa0ce..c347a5c7 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -20,6 +20,7 @@ SampledDatasetUpdateConfig, SamplingConfig, SamplingData, + SamplingParameters, ) from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum @@ -45,6 +46,10 @@ class ShufflingType(str, enum.Enum): @config_class() class GPTSamplingConfig(SamplingConfig): + """ + A dataset-dependent configuration for sampling. + """ + gpu: bool = Field( default=True, desc="Enable fast sampling on GPU." @@ -52,11 +57,6 @@ class GPTSamplingConfig(SamplingConfig): " so the sample won't match the CPU equivalent.", hint=FieldHint.feature, ) - use_loss_masking_spans: bool = Field( - default=False, - desc="Read loss masking spans from the dataset.", - hint=FieldHint.feature, - ) shuffle: ShufflingType = Field( default=ShufflingType.epoch, desc="Shuffling strategy.", @@ -64,15 +64,29 @@ class GPTSamplingConfig(SamplingConfig): ) -@dataclasses.dataclass -class GPTSamplingData(SamplingData): - config: GPTSamplingConfig - # TODO: Sort these out +@dataclasses.dataclass(kw_only=True) +class GPTSamplingParameters(SamplingParameters): + """ + Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. + """ + sequence_length: int vocab_size: int + use_loss_masking_spans: bool = False + cross_document_attention: bool = True + + +@dataclasses.dataclass(kw_only=True) +class GPTSamplingData(SamplingData): + """ + Holds all the necessary information for sampling, including dataset-dependent ones (`GPTSamplingConfig`), + usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`. + """ + + config: GPTSamplingConfig + parameters: GPTSamplingParameters tokenizer: "Tokenizer" truncate_documents: bool = True - cross_document_attention: bool = True @config_class() diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 275505ba..63b7f437 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -18,7 +18,7 @@ def __init__( dataset: SampledDataset, sampling: GPTSamplingData, ): - if sampling.config.use_loss_masking_spans: + if sampling.parameters.use_loss_masking_spans: raise NotImplementedError("FIM is currently not compatible with loss masking.") self._config = config self._dataset = dataset diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index bb26d22e..f2d4035e 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -25,9 +25,9 @@ class GPTRandomSampledDataset(SampledDataset): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed - self._sequence_length = sampling.sequence_length - self._vocab_size = sampling.vocab_size - self._num_samples = sampling.num_samples + self._sequence_length = sampling.parameters.sequence_length + self._vocab_size = sampling.parameters.vocab_size + self._num_samples = sampling.parameters.num_samples def __len__(self) -> int: return self._num_samples diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index e8c5de11..0dac725b 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -85,10 +85,8 @@ def __init__( ): assert isinstance(sampling, GPTSamplingData) self._indexed_dataset = indexed_dataset - self._num_samples = sampling.num_samples - self._sequence_length = sampling.sequence_length - self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config + self._parameters = sampling.parameters self._truncate_documents = sampling.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") @@ -105,7 +103,8 @@ def __init__( self._sample() else: base_path = ( - sampling.cache_directory / f"{self.name}_ns_{self._num_samples}_sl_{self._sequence_length}" + sampling.cache_directory + / f"{self.name}_ns_{self._parameters.num_samples}_sl_{self._parameters.sequence_length}" f"_s_{self._config.seed}" ) # TODO: Names are confusing @@ -134,24 +133,27 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._sequence_length + 1 + long_docs_filter = document_sizes > self._parameters.sequence_length + 1 ignored_documents = sum(long_docs_filter) if ignored_documents: log_main_rank( - f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.", + f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() if tokens_per_epoch == 0: raise RuntimeError( - f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." + f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." ) # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads? # We produce sequences of length `self._sequence_length + 1` so the last token has a label, # but in case of truncations we also include that last label in the following sample, # so we need `sequence_length * num_samples + 1` tokens in total. num_epochs = math.ceil( - ((self._sequence_length + 1 - self._truncate_documents) * self._num_samples + 1 * self._truncate_documents) + ( + (self._parameters.sequence_length + 1 - self._truncate_documents) * self._parameters.num_samples + + 1 * self._truncate_documents + ) / tokens_per_epoch ) @@ -172,9 +174,9 @@ def _sample(self) -> None: "documents_per_epoch": documents_per_epoch, "tokens_per_epoch": tokens_per_epoch, }, - "num_samples": self._num_samples, + "num_samples": self._parameters.num_samples, "unshuffled_epochs": unshuffled_epochs, - "sequence_length": self._sequence_length, + "sequence_length": self._parameters.sequence_length, "truncate_documents": self._truncate_documents, "config": self._config.to_dict(), } @@ -311,7 +313,9 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - # Crop unnecessary entries. out = out[ : torch.clamp_min_( - torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"), + torch.searchsorted( + out, self._parameters.num_samples * self._parameters.sequence_length, side="right" + ), 0, ) ] @@ -319,16 +323,22 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - else: # TODO: dynamically handle int64 or int32 in CPP out = build_padded_token_cumsum( - sizes.cpu().numpy(), (self._sequence_length + 1), TOKEN_CUMSUM_RATE, offset + sizes.cpu().numpy(), (self._parameters.sequence_length + 1), TOKEN_CUMSUM_RATE, offset ) num_tokens = out[-1] out = out[:-1][ - : np.clip(np.searchsorted(out, self._num_samples * (self._sequence_length + 1), side="right"), 0, None) + : np.clip( + np.searchsorted( + out, self._parameters.num_samples * (self._parameters.sequence_length + 1), side="right" + ), + 0, + None, + ) ] return out, num_tokens def __len__(self) -> int: - return self._num_samples + return self._parameters.num_samples def __getitem__(self, index: int) -> typing.Any: """ @@ -339,8 +349,8 @@ def __getitem__(self, index: int) -> typing.Any: self._lazy_load() # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample - token_start = index * (self._sequence_length + 1 - self._truncate_documents) - token_end = token_start + self._sequence_length + 1 + token_start = index * (self._parameters.sequence_length + 1 - self._truncate_documents) + token_end = token_start + self._parameters.sequence_length + 1 if token_start < self._unshuffled_tokens: token_start_array = self._token_cumsum_unshuffled.array @@ -368,14 +378,14 @@ def __getitem__(self, index: int) -> typing.Any: document_size = self._indexed_dataset.get_document_size(document_index) if not self._truncate_documents: - if document_size > self._sequence_length + 1: + if document_size > self._parameters.sequence_length + 1: # Document too long, ignore document_sampling_index += 1 continue - tokens_in_sample = token_count % (self._sequence_length + 1) - if document_size + tokens_in_sample > self._sequence_length + 1: + tokens_in_sample = token_count % (self._parameters.sequence_length + 1) + if document_size + tokens_in_sample > self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. - padding_size = self._sequence_length + 1 - tokens_in_sample + padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: # Add padding tokens to current sample token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) @@ -394,12 +404,14 @@ def __getitem__(self, index: int) -> typing.Any: document_index, offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, - use_loss_masking_spans=self._config.use_loss_masking_spans, + use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) token_ids.append(sample.token_ids) - if self._config.use_loss_masking_spans: + if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: - span = np.clip(loss_masking_span + token_count - token_start, 0, self._sequence_length + 1) + span = np.clip( + loss_masking_span + token_count - token_start, 0, self._parameters.sequence_length + 1 + ) if span[1] > span[0]: loss_masking_spans.append(span) @@ -409,16 +421,16 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - if not self._cross_document_attention + if not self._parameters.cross_document_attention else None ) token_ids = np.concatenate(token_ids, dtype=np.int64) loss_masking_spans = ( (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) - if self._config.use_loss_masking_spans + if self._parameters.use_loss_masking_spans else None ) - Assert.eq(len(token_ids), self._sequence_length + 1) + Assert.eq(len(token_ids), self._parameters.sequence_length + 1) return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) @@ -458,14 +470,12 @@ def __init__( ): assert isinstance(sampling, GPTSamplingData) self._indexed_dataset = indexed_dataset - self._num_samples = sampling.num_samples - self._sequence_length = sampling.sequence_length if not sampling.truncate_documents: raise NotImplementedError( "Legacy sampling only supports document truncation. Please use the latest dataset format." ) - self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config + self._parameters = sampling.parameters if sampling.cache_directory is None: log_main_rank( @@ -476,7 +486,8 @@ def __init__( base_path = None else: base_path = ( - sampling.cache_directory / f"{self.name}_ns_{self._num_samples}_sl_{self._sequence_length}" + sampling.cache_directory + / f"{self.name}_ns_{self._parameters.num_samples}_sl_{self._parameters.sequence_length}" f"_s_{self._config.seed}" ) @@ -507,10 +518,10 @@ def _sample(self) -> None: num_tokens = document_sizes.sum() np_rng = np.random.RandomState(seed=self._config.seed) - num_epochs = math.ceil((self._sequence_length * self._num_samples + 1) / num_tokens) - main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // self._sequence_length - last_epoch_samples = self._num_samples - main_epochs_samples - samples_per_epoch = (num_tokens - 1) // self._sequence_length + num_epochs = math.ceil((self._parameters.sequence_length * self._parameters.num_samples + 1) / num_tokens) + main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // self._parameters.sequence_length + last_epoch_samples = self._parameters.num_samples - main_epochs_samples + samples_per_epoch = (num_tokens - 1) // self._parameters.sequence_length separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch doc_idx = np.tile(np.arange(num_documents, dtype=np.int32), num_epochs) @@ -527,7 +538,7 @@ def _sample(self) -> None: sample_idx = build_sample_idx( document_sizes, doc_idx, - self._sequence_length, + self._parameters.sequence_length, num_epochs, num_tokens, True, @@ -543,13 +554,13 @@ def _sample(self) -> None: else: np_rng.shuffle(shuffle_idx) - Assert.geq(len(shuffle_idx), self._num_samples) + Assert.geq(len(shuffle_idx), self._parameters.num_samples) self._doc_idx.save(doc_idx) self._sample_idx.save(sample_idx) - self._shuffle_idx.save(shuffle_idx[: self._num_samples]) + self._shuffle_idx.save(shuffle_idx[: self._parameters.num_samples]) def __len__(self) -> int: - return self._num_samples + return self._parameters.num_samples def __getitem__(self, idx: int) -> typing.Any: """ @@ -567,14 +578,14 @@ def __getitem__(self, idx: int) -> typing.Any: self._doc_idx[doc].item(), offset=(doc == doc_f) * offset_f, length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, - use_loss_masking_spans=self._config.use_loss_masking_spans, + use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) for doc in range(doc_f, doc_l + 1) ] token_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64) - Assert.eq(len(token_ids), self._sequence_length + 1) + Assert.eq(len(token_ids), self._parameters.sequence_length + 1) - if self._config.use_loss_masking_spans: + if self._parameters.use_loss_masking_spans: spans = [] offset = 0 for sample in sample_list: @@ -589,7 +600,7 @@ def __getitem__(self, idx: int) -> typing.Any: [sample.token_ids.size - (idx == len(sample_list) - 1) for idx, sample in enumerate(sample_list)], dtype=np.int32, ) - if not self._cross_document_attention + if not self._parameters.cross_document_attention else None ) return GPTSample(token_ids=token_ids, loss_masking_spans=spans, sequence_lengths=sequence_lengths) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 33209b95..62924317 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -11,6 +11,7 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.data.dataset.config import SamplingParameters from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType @@ -137,9 +138,9 @@ def setup(self, distributed: Distributed, run: Run) -> None: self._data.setup( distributed, { - dataset_name: steps + dataset_name: self._get_sampling_parameters({"num_samples": samples}) for datasets in self._samples_per_split.values() - for dataset_name, steps in datasets.items() + for dataset_name, samples in datasets.items() }, None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", timeout=self._config.training.timeout, @@ -150,6 +151,11 @@ def setup(self, distributed: Distributed, run: Run) -> None: def _get_data(self) -> Data: pass + def _get_sampling_parameters( + self, parameters: dict[str, typing.Any], _return_dict: bool = False + ) -> SamplingParameters | dict[str, typing.Any]: + return parameters if _return_dict else SamplingParameters(**parameters) + @property def _consumed_samples(self) -> int: assert self._is_setup diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index f9e21d1e..938666a1 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -2,6 +2,7 @@ import typing from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer @@ -36,11 +37,21 @@ def _get_data(self) -> GPTData: return GPTData( config=self._config.data, distributed_config=self._config.model.distributed, - vocab_size=self._config.model.base_model.vocab_size, - max_sequence_length=self._config.batch.sequence_length, - cross_document_attention=self._config.batch.cross_document_attention, ) + def _get_sampling_parameters( + self, parameters: dict[str, typing.Any], _return_dict: bool = False + ) -> GPTSamplingParameters | dict[str, typing.Any]: + parameters = super()._get_sampling_parameters(parameters, _return_dict=True) + parameters.update( + { + "vocab_size": self._config.model.base_model.vocab_size, + "sequence_length": self._config.batch.sequence_length, + "cross_document_attention": self._config.batch.cross_document_attention, + } + ) + return parameters if _return_dict else GPTSamplingParameters(**parameters) + def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: # TODO: Do in model, automate/generalize, get other stats """Get tflop/s/GPU from global-batch-size and elapsed-time""" diff --git a/tests/data/common.py b/tests/data/common.py index a22e6859..c6ddd3ef 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -13,6 +13,7 @@ GPTSampledDatasetConfig, GPTSamplingConfig, GPTSamplingData, + GPTSamplingParameters, ShufflingType, ) from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset @@ -46,12 +47,14 @@ def get_sampling_data( gpu=gpu, shuffle=shuffle, ), - num_samples=num_samples, + parameters=GPTSamplingParameters( + num_samples=num_samples, + sequence_length=sequence_length, + vocab_size=vocab_size, + ), cache_directory=cache_directory, distributed=distributed, dataset_name=phase.value, - sequence_length=sequence_length, - vocab_size=vocab_size, tokenizer=tokenizer, truncate_documents=truncate_documents, ) @@ -80,6 +83,16 @@ def get_test_data_and_compare_samples( distributed = Distributed(distributed_config, use_cpu=True) if isinstance(samples_per_dataset, int): samples_per_dataset = {PhaseType.training.value.lower(): samples_per_dataset} + + sampling_parameters = { + dataset_name: GPTSamplingParameters( + num_samples=num_samples, + sequence_length=sequence_length, + vocab_size=vocab_size, + ) + for dataset_name, num_samples in samples_per_dataset.items() + } + if isinstance(expected_samples, list): expected_samples = {PhaseType.training.value.lower(): expected_samples} @@ -89,8 +102,8 @@ def get_test_data_and_compare_samples( gpu=gpu, shuffle=shuffle, ) - data = GPTData(GPTDataConfig.from_dict(config), distributed_config, vocab_size, sequence_length) - data.setup(distributed, samples_per_dataset, cache_directory) + data = GPTData(GPTDataConfig.from_dict(config), distributed_config) + data.setup(distributed, sampling_parameters, cache_directory) with NoAutoValidate(): batch_config = BatchConfig(batch_size=1, sequence_length=sequence_length) batch_config.setup(distributed_config) @@ -140,8 +153,8 @@ def validate_indexed_dataset_sampling( """ Compare `GPTSampledIndexedDataset` sampling against a more basic approach """ - num_tokens = sampled._num_samples * sampled._sequence_length + 1 - all_tokens = np.full(sampled._num_samples * sampled._sequence_length + 1, -1, dtype=np.int64) + num_tokens = sampled._parameters.num_samples * sampled._parameters.sequence_length + 1 + all_tokens = np.full(sampled._parameters.num_samples * sampled._parameters.sequence_length + 1, -1, dtype=np.int64) unshuffled_epochs = div(sampled._unshuffled_documents, sampled._documents_per_epoch) document_sampling = np.tile( @@ -165,8 +178,8 @@ def validate_indexed_dataset_sampling( break validate_samples = [ - all_tokens[index * sampled._sequence_length : (index + 1) * sampled._sequence_length + 1] - for index in range(sampled._num_samples) + all_tokens[index * sampled._parameters.sequence_length : (index + 1) * sampled._parameters.sequence_length + 1] + for index in range(sampled._parameters.num_samples) ] token_ids = [sampled[i].token_ids for i in range(len(sampled))] diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 65fbf369..7b614d2f 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -37,7 +37,10 @@ def test_gpt_fim(): get_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. sampling_config = get_sampling_data( - 8, sequence_length=5, tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH})) + 8, + sequence_length=5, + tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH})), + vocab_size=49157, ) sampled = get_dataset_config( { @@ -73,6 +76,7 @@ def test_gpt_fim_data(): 8, sequence_length=5, expected_samples=GPT_FIM_SAMPLES, + vocab_size=49157, ) @@ -89,4 +93,5 @@ def test_gpt_fim_data_legacy(): sequence_length=5, expected_samples=GPT_FIM_SAMPLES_LEGACY, legacy=True, + vocab_size=49157, ) From ff56e62a5d873dbfa129dc9d57e26e5fa9b9cef1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 15 Apr 2025 15:53:17 -0400 Subject: [PATCH 2/3] Generalize batch config --- fast_llm/engine/multi_stage/stage.py | 8 +++--- fast_llm/engine/schedule/config.py | 36 +++++------------------ fast_llm/engine/schedule/runner.py | 16 ++++------- fast_llm/engine/schedule/schedule.py | 28 +++++++++--------- fast_llm/models/gpt/config.py | 43 ++++++++++++++++++++++++++-- fast_llm/models/gpt/model.py | 9 +++--- fast_llm/models/gpt/trainer.py | 1 + 7 files changed, 77 insertions(+), 64 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index a60fafd3..7ccd740e 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -191,8 +191,8 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] name = f"layer {self._layer_range[i]} fw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" - if (nms := kwargs.get("num_micro_sequences", 1)) > 1: - name = f"{name}, ms={kwargs.get('micro_sequence',0)}/{nms}" + if (nms := kwargs.get("micro_batch_splits", 1)) > 1: + name = f"{name}, ms={kwargs.get('micro_batch_split',0)}/{nms}" log_distributed_tensor( name, @@ -222,8 +222,8 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any name = f"layer {self._layer_range[i]} bw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" - if (nms := kwargs.get("num_micro_sequences", 1)) > 1: - name = f"{name}, ms={kwargs.get('micro_sequence',0)}/{nms}" + if (nms := kwargs.get("micro_batch_splits", 1)) > 1: + name = f"{name}, ms={kwargs.get('micro_batch_split',0)}/{nms}" log_distributed_grad( name, input_, diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 91256deb..141490ac 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -1,4 +1,5 @@ import enum +import functools import warnings from fast_llm.config import Config, Field, FieldHint, check_field, config_class, test_field @@ -43,29 +44,6 @@ class BatchConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - sequence_length: int = Field( - default=2048, - desc="Number of tokens in a sample.", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - micro_sequence_length: int = Field( - default=None, - desc="Number of tokens in a micro-sequence (must divide the sequence length).", - hint=FieldHint.performance, - valid=check_field(Assert.gt, 0), - ) - num_micro_sequences: int = Field( - init=False, - desc="Number of micro-sequences to split each sample (= seqence length / micro-sequence length).", - hint=FieldHint.derived, - valid=check_field(Assert.gt, 0), - ) - cross_document_attention: bool = Field( - default=True, - desc="Applies attention to tokens from other documents in the packed sequence. Set to False for masking attention to other documents.", - hint=FieldHint.feature, - ) _distributed: DistributedConfig = Field( init=False, desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", @@ -75,9 +53,13 @@ class BatchConfig(Config): def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config - @property + @functools.cached_property def num_inputs(self) -> int: - return self.sequential_micro_batches * self.num_micro_sequences + return self.sequential_micro_batches * self.micro_batch_splits + + @functools.cached_property + def micro_batch_splits(self) -> int: + return 1 def _validate(self) -> None: # Use the distributed properties to determine the batch size and its breakdown. @@ -128,10 +110,6 @@ def _validate(self) -> None: "Mixing of breadth-first and depth-first gradient accumulation is not thoroughly tested." " Use at your own risk." ) - if self.micro_sequence_length is None: - with self._set_implicit_default(): - self.micro_sequence_length = self.sequence_length - self.num_micro_sequences = div(self.sequence_length, self.micro_sequence_length) super()._validate() diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 4a5425ee..8eca4559 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -315,11 +315,7 @@ def _preprocess_data( self, context: BatchContext, data_iterator: typing.Iterator, preprocessed: bool ) -> typing.Generator[None, None, None]: batch_config = context.schedule.batch_config - grad_output = ( - (1 if self._optimizer is None else self._optimizer.grad_scale) - / batch_config.sequential_micro_batches - / batch_config.num_micro_sequences - ) + grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) / batch_config.num_inputs for micro_batch in range(batch_config.sequential_micro_batches): micro_batch_data = next(data_iterator) if not preprocessed: @@ -330,20 +326,20 @@ def _preprocess_data( iteration=context.iteration, metrics=context.metrics, ) - for micro_sequence, (input_, kwargs) in enumerate(micro_batch_data): + for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data): kwargs.update( grad_output=grad_output, micro_batch=micro_batch, - micro_sequence=micro_sequence, + micro_batch_split=micro_batch_split, num_micro_batches=batch_config.sequential_micro_batches, - num_micro_sequences=batch_config.num_micro_sequences, + micro_batch_splits=batch_config.micro_batch_splits, ) for name, tied_parameter in self._tied_parameters.items(): if tied_parameter.on_device: kwargs[name] = self._stages[tied_parameter.main_stage].get_parameter_buffer( tied_parameter.meta.tensor_name ) - data_index = context.schedule.get_data_index(micro_batch, micro_sequence) + data_index = context.schedule.get_data_index(micro_batch, micro_batch_split) if self._stages_owned[0]: context.inputs[context.schedule.get_step(StepType.forward, 0, data_index).global_index] = input_ if context.is_training and self._stages_owned[-1]: @@ -508,7 +504,7 @@ def _save_events(self, events, context: BatchContext) -> None: "step_stage": step.stage, "step_depth_first_micro_batch": step.depth_first_micro_batch, "step_breadth_first_micro_batch": step.breadth_first_micro_batch, - "step_micro_sequence": step.micro_sequence, + "step_micro_batch_split": step.micro_batch_split, } ), } diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 4c0e4371..44a5f677 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -70,12 +70,12 @@ class Step: meta_kwargs: dict | None = None @property - def micro_sequence(self) -> int: - return self.data_index % self.config.num_micro_sequences + def micro_batch_split(self) -> int: + return self.data_index % self.config.micro_batch_splits @property def micro_batch(self) -> int: - return self.data_index // self.config.num_micro_sequences + return self.data_index // self.config.micro_batch_splits @property def depth_first_micro_batch(self) -> int: @@ -108,7 +108,7 @@ def __repr__(self) -> str: f" local_idx={self.local_index}," f" stage={self.stage}{'f' if self.type_ == StepType.forward else 'b'}," f" dfmb={self.depth_first_micro_batch}, bfmb={self.breadth_first_micro_batch}," - f" ms={self.micro_sequence}{misc})" + f" ms={self.micro_batch_split}{misc})" ) def get_stage_index(self, num_stages) -> int: @@ -198,7 +198,7 @@ def _create_index(self) -> None: Assert.in_range( step.data_index, 0, - self._batch_config.sequential_micro_batches * self._batch_config.num_micro_sequences, + self._batch_config.sequential_micro_batches * self._batch_config.micro_batch_splits, ) Assert.incl(step.type_, (StepType.forward, StepType.backward)) step.global_index = i @@ -458,7 +458,7 @@ def _setup_metas(self) -> None: if step.type_ == StepType.forward: if step.prev_step is None: assert step.stage == 0 - step.meta_input, step.meta_kwargs = self._preprocessed_meta[step.micro_sequence] + step.meta_input, step.meta_kwargs = self._preprocessed_meta[step.micro_batch_split] # meta_kwargs may be modified. meta_kwargs = step.meta_kwargs.copy() step.meta_output = self._multi_stage.stages[step.stage].forward_meta(step.meta_input, meta_kwargs) @@ -466,15 +466,15 @@ def _setup_metas(self) -> None: step.next_step.meta_input = step.meta_output step.next_step.meta_kwargs = step.meta_kwargs - def get_data_index(self, micro_batch: int, micro_sequence: int) -> int: - return micro_batch * self._batch_config.num_micro_sequences + micro_sequence + def get_data_index(self, micro_batch: int, micro_batch_split: int) -> int: + return micro_batch * self._batch_config.micro_batch_splits + micro_batch_split def get_data_index_split( - self, breadth_first_micro_batch: int, depth_first_micro_batch: int, micro_sequence: int + self, breadth_first_micro_batch: int, depth_first_micro_batch: int, micro_batch_split: int ) -> int: return self.get_data_index( breadth_first_micro_batch * self._batch_config.depth_first_micro_batches + depth_first_micro_batch, - micro_sequence, + micro_batch_split, ) def _create_steps(self) -> tuple[list[Step], int]: @@ -490,13 +490,13 @@ def _create_steps(self) -> tuple[list[Step], int]: for depth_first_micro_batch in range(self._batch_config.depth_first_micro_batches): for stage in range(self._num_stages): for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches): - for micro_sequence in range(self._batch_config.num_micro_sequences): + for micro_batch_split in range(self._batch_config.micro_batch_splits): steps.append( Step( config=self._batch_config, stage=stage, data_index=self.get_data_index_split( - breadth_first_micro_batch, depth_first_micro_batch, micro_sequence + breadth_first_micro_batch, depth_first_micro_batch, micro_batch_split ), type_=StepType.forward, ) @@ -504,13 +504,13 @@ def _create_steps(self) -> tuple[list[Step], int]: if self._is_training: for stage in reversed(range(first_grad_stage, self._num_stages)): for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches): - for micro_sequence in reversed(range(self._batch_config.num_micro_sequences)): + for micro_batch_split in reversed(range(self._batch_config.micro_batch_splits)): steps.append( Step( config=self._batch_config, stage=stage, data_index=self.get_data_index_split( - breadth_first_micro_batch, depth_first_micro_batch, micro_sequence + breadth_first_micro_batch, depth_first_micro_batch, micro_batch_split ), type_=StepType.backward, ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 19c8e6ac..72f3cf70 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,13 +1,15 @@ import typing +from functools import cached_property -from fast_llm.config import Field, FieldHint, FieldUpdate, config_class +from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig +from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM @@ -66,6 +68,42 @@ def _from_dict( return super()._from_dict(default, strict, flat) +@config_class() +class GPTBatchConfig(BatchConfig): + sequence_length: int = Field( + default=2048, + desc="Number of tokens in a sample.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + micro_sequence_length: int = Field( + default=None, + desc="Number of tokens in a micro-sequence (must divide the sequence length).", + hint=FieldHint.performance, + valid=check_field(Assert.gt, 0), + ) + cross_document_attention: bool = Field( + default=True, + desc="Applies attention to tokens from other documents in the packed sequence. Set to False for masking attention to other documents.", + hint=FieldHint.feature, + ) + use_loss_masking_spans: bool = Field( + default=False, + desc="Read loss masking spans from the dataset.", + hint=FieldHint.feature, + ) + + def _validate(self) -> None: + if self.micro_sequence_length is None: + with self._set_implicit_default(): + self.micro_sequence_length = self.sequence_length + super()._validate() + + @cached_property + def micro_batch_splits(self) -> int: + return div(self.sequence_length, self.micro_sequence_length) + + @config_class() class GPTBaseModelConfig(LanguageModelBaseConfig, GPTArchitectureConfig): architecture_class = GPTArchitectureConfig @@ -130,6 +168,7 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) + batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) # TODO: Use dynamic model type? reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c672b216..2ec5ecae 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -11,7 +11,6 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead @@ -28,7 +27,7 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert @@ -121,12 +120,12 @@ def setup(self, distributed: Distributed) -> None: self._is_setup = True def preprocess_meta( - self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType + self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: # TODO: How much of this is generalizable? # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence - if isinstance(batch_meta, BatchConfig): + if isinstance(batch_meta, GPTBatchConfig): micro_batch_size = batch_meta.micro_batch_size sequence_length = batch_meta.sequence_length micro_sequence_length = batch_meta.micro_sequence_length @@ -139,7 +138,7 @@ def preprocess_meta( batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) - if isinstance(batch_meta, BatchConfig): + if isinstance(batch_meta, GPTBatchConfig): micro_sequence_length = batch_meta.micro_sequence_length if micro_sequence_length is None: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 938666a1..a269f5a6 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -47,6 +47,7 @@ def _get_sampling_parameters( { "vocab_size": self._config.model.base_model.vocab_size, "sequence_length": self._config.batch.sequence_length, + "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "cross_document_attention": self._config.batch.cross_document_attention, } ) From 84a03363113bc5078dfd9530c247ae6a03b6a633 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 15 Apr 2025 16:31:48 -0400 Subject: [PATCH 3/3] Fixes, bw compatibility --- fast_llm/engine/inference/runner.py | 3 ++- fast_llm/models/gpt/config.py | 14 ++++++++++++++ fast_llm/models/gpt/model.py | 1 + tests/data/common.py | 4 ++-- 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/inference/runner.py b/fast_llm/engine/inference/runner.py index b83a5332..30f836b7 100644 --- a/fast_llm/engine/inference/runner.py +++ b/fast_llm/engine/inference/runner.py @@ -11,6 +11,7 @@ class InferenceRunner(abc.ABC): model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel + batch_config_class: typing.ClassVar[type[BatchConfig]] = BatchConfig def __init__(self, fast_llm_model: FastLLMModel): assert isinstance(fast_llm_model, self.model_class) @@ -19,7 +20,7 @@ def __init__(self, fast_llm_model: FastLLMModel): self._schedule_config = ScheduleConfig() # TODO: Sort things out. with NoAutoValidate(): - self._batch_config = BatchConfig() + self._batch_config = self.batch_config_class() self._batch_config.setup(self._fast_llm_model.config.distributed) self._batch_config.validate() self._runner = ScheduleRunner( diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 72f3cf70..b78c3311 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -82,6 +82,7 @@ class GPTBatchConfig(BatchConfig): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) + # TODO: Find a better place for these? cross_document_attention: bool = Field( default=True, desc="Applies attention to tokens from other documents in the packed sequence. Set to False for masking attention to other documents.", @@ -182,6 +183,19 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) super()._validate() + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + # TODO v0.x: Remove backward compatibility. + cls._handle_renamed_field( + default, ("data", "sampling", "use_loss_masking_spans"), ("batch", "use_loss_masking_spans") + ) + return super()._from_dict(default, strict, flat) + @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: from fast_llm.models.gpt.trainer import GPTTrainer diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2ec5ecae..7d9e59a4 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -373,3 +373,4 @@ class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): class GPTInferenceRunner(InferenceRunner): model_class: typing.ClassVar[type[GPTModel]] = GPTModel + batch_config_class: typing.ClassVar[type[GPTBatchConfig]] = GPTBatchConfig diff --git a/tests/data/common.py b/tests/data/common.py index c6ddd3ef..47b53195 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -21,7 +21,7 @@ from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert, div from tests.common import TEST_VOCAB_SIZE @@ -105,7 +105,7 @@ def get_test_data_and_compare_samples( data = GPTData(GPTDataConfig.from_dict(config), distributed_config) data.setup(distributed, sampling_parameters, cache_directory) with NoAutoValidate(): - batch_config = BatchConfig(batch_size=1, sequence_length=sequence_length) + batch_config = GPTBatchConfig(batch_size=1, sequence_length=sequence_length) batch_config.setup(distributed_config) batch_config.validate() tokens = {