From 147e33b14659cd4bb3890c7bc9489f094b864508 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 6 Jan 2025 16:06:19 -0500 Subject: [PATCH 01/19] Modular dataset configuration --- fast_llm/data/config.py | 9 - fast_llm/data/data/abstract.py | 29 ++- fast_llm/data/data/config.py | 5 +- fast_llm/data/data/gpt/config.py | 57 ++---- fast_llm/data/data/gpt/data.py | 178 +++-------------- fast_llm/data/dataset/blended.py | 43 +---- fast_llm/data/dataset/config.py | 221 ++++++++++++++++++++++ fast_llm/data/dataset/gpt/abstract.py | 2 +- fast_llm/data/dataset/gpt/config.py | 262 ++++++++++++++++++++++++++ fast_llm/data/dataset/gpt/dummy.py | 18 +- fast_llm/data/dataset/gpt/memmap.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/data/dataset/gpt/slice.py | 5 +- tests/common.py | 7 +- 14 files changed, 583 insertions(+), 257 deletions(-) create mode 100644 fast_llm/data/dataset/config.py create mode 100644 fast_llm/data/dataset/gpt/config.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 7080c30e..32675749 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -12,15 +12,6 @@ class MultiprocessingContext(str, enum.Enum): spawn = "spawn" -def _validate_split(value): - Assert.leq(len(value), 3) - return value + [0] * (len(value) - 3) - - -def _validate_path(value): - return [value] if isinstance(value, str) else value - - TokenizerFromFile = "TokenizerFromFile" diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index f9284a71..a2e419d5 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -1,15 +1,34 @@ import abc +import typing -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.data.data.config import DataConfig +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig +if typing.TYPE_CHECKING: + from fast_llm.engine.distributed.distributed import Distributed + class Data(abc.ABC): + _distributed: "Distributed" + _samples_per_phase: dict[PhaseType, int] + + def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None: + self._config = config + self._distributed_config = distributed_config + # TODO: Improve interface - @abc.abstractmethod - def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]): - pass + def setup(self, distributed: "Distributed", samples_per_phase: dict[PhaseType, int]): + self._distributed = distributed + self._samples_per_phase = samples_per_phase + + @property + def config(self): + return self._config + + @property + def distributed(self): + return self._distributed @abc.abstractmethod def get_iterator( diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 401342d5..3485c2e0 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -1,12 +1,13 @@ import pathlib import typing -from fast_llm.config import Config, Field, config_class +from fast_llm.config import Config, Field, check_field, config_class +from fast_llm.utils import Assert @config_class class SamplingConfig(Config): - num_samples: int = Field(default=1, desc="Number of samples to generate.") + num_samples: int = Field(default=1, desc="Number of samples to generate.", valid=check_field(Assert.gt, 0)) seed: int = Field(default=0, desc="Random seed.") cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache 
directory.") verbose: bool = Field(default=True, desc="Log sampling progress.") diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index adeffbcc..bb4d78b9 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,27 +1,17 @@ -import enum +import logging from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.data.config import MultiprocessingContext, TokenizerConfig, _validate_path, _validate_split -from fast_llm.data.data.config import DataConfig, SamplingConfig +from fast_llm.data.config import MultiprocessingContext, TokenizerConfig +from fast_llm.data.data.config import DataConfig +from fast_llm.data.dataset.gpt.config import GPTLegacyConfig, GPTLegacyDatasetConfig, GPTSampledSplitDatasetConfig from fast_llm.data.dataset.gpt.fim.config import FimConfig from fast_llm.utils import Assert - -class DatasetSource(str, enum.Enum): - """ - An enum for the different ways to load datasets. - TODO: Reduce the diversity? - TODO: Is this specific to GPT data? - """ - - list = "list" - file = "file" - sample = "sample" - random = "random" +logger = logging.getLogger(__name__) @config_class() -class GPTDataConfig(DataConfig): +class GPTDataConfig(DataConfig, GPTLegacyConfig): """ Configuration for the dataset(s), split and sampling. Currently hard-coded to a GPT dataset. @@ -35,29 +25,16 @@ class GPTDataConfig(DataConfig): desc="Configuration for the tokenizer (for FIM).", hint=FieldHint.feature, ) + dataset: GPTSampledSplitDatasetConfig = Field( + default=None, + desc="Configuration for the dataset(s).", + hint=FieldHint.core, + ) fim: FimConfig = Field( default_factory=FimConfig, desc="Configuration for Fill In the Middle (FIM).", hint=FieldHint.feature, ) - # TODO: set default to [1,0,0]? - split: list[float] = Field( - default_factory=lambda: [969, 30, 1], - desc="Split ratio for train, valid and test datasets.", - hint=FieldHint.core, - valid=_validate_split, - ) - format: DatasetSource = Field( - default=DatasetSource.list, - desc="Format for the dataset definition.", - hint=FieldHint.core, - ) - path: list[str] = Field( - default_factory=list, - desc="Path or list of paths and weights.", - hint=FieldHint.core, - valid=_validate_path, - ) data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", @@ -70,7 +47,11 @@ class GPTDataConfig(DataConfig): hint=FieldHint.expert, ) - -@config_class -class GPTSamplingConfig(SamplingConfig): - sequence_length: int = Field(default=None, desc="Number of token in each sample.") + def __post_init__(self): + if self.dataset is None: + logger.warning("Using the legacy dataset definition format." 
" Specify it through `data.dataset` instead.") + self.dataset = GPTLegacyDatasetConfig( + split=self.split, + format=self.format, + path=self.path, + ) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 36165da7..0632f94d 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,27 +1,21 @@ -import json import logging -import math import pathlib -import typing import warnings import torch import torch.utils.data from fast_llm.data.data.abstract import Data -from fast_llm.data.data.gpt.config import DatasetSource, GPTDataConfig, GPTSamplingConfig -from fast_llm.data.dataset.abstract import CopySplitDataset, PhaseSplits, SampledSplitDataset -from fast_llm.data.dataset.blended import BlendedDataset -from fast_llm.data.dataset.gpt.dummy import DummyGPTDataset -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.slice import GPTDatasetSlice +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.dataset.abstract import PhaseSplits, SampledSplitDataset +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.run import get_run, log_main_rank from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.utils import Assert, normalize_probabilities +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -34,11 +28,10 @@ class GPTData(Data): """ _datasets: SampledSplitDataset + _config: GPTDataConfig _tokenizer: Tokenizer | None _distributed: Distributed _cache_directory: pathlib.Path | None - _samples_per_phase: dict[PhaseType, int] - _phases: typing.ClassVar[tuple[PhaseType, ...]] = (PhaseType.training, PhaseType.validation, PhaseType.test) _is_setup: bool = False def __init__( @@ -52,154 +45,53 @@ def __init__( Create the data and gather some basic information on the dataset(s). Should be `setup` before use. 
""" - self._config = config - self._distributed_config = distributed_config + super().__init__(config, distributed_config) self._vocab_size = vocab_size self._max_sequence_length = max_sequence_length - Assert.eq(len(self._config.split), len(self._phases)) - self._phase_split = { - phase: ratio - for phase, ratio in zip(self._phases, normalize_probabilities(self._config.split)) - if ratio > 0 - } - data_base_path = None - if self._config.format == DatasetSource.file: - Assert.eq(len(self._config.path), 1) - data_path = pathlib.Path(self._config.path[0]) - dataset_defs = json.load(data_path.open("r")) - data_base_path = data_path.parent - dataset_prefixes = [dataset_def["prefix"] for dataset_def in dataset_defs["datasets"]] - dataset_weights = normalize_probabilities( - [dataset_def["weight"] for dataset_def in dataset_defs["datasets"]] - ) - self._build_and_sample_dataset = self._build_and_sample_gpt_dataset - elif self._config.format == DatasetSource.list: - Assert.geq(len(self._config.path), 1) - if len(self._config.path) == 1: - dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0] - else: - Assert.custom(lambda x: x % 2 == 0, len(self._config.path)) - dataset_prefixes = [x.strip() for x in self._config.path[1::2]] - assert len(dataset_prefixes) == len(set(dataset_prefixes)) - dataset_weights = normalize_probabilities([float(x) for x in self._config.path[::2]]) - self._build_and_sample_dataset = self._build_and_sample_gpt_dataset - elif self._config.format == DatasetSource.sample: - Assert.eq(len(self._config.path), 1) - dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0] - self._build_and_sample_dataset = self._build_and_sample_dummy_dataset - elif self._config.format == DatasetSource.random: - Assert.eq(len(self._config.path), 0) - dataset_prefixes, dataset_weights = [None], [1.0] - self._build_and_sample_dataset = self._build_and_sample_dummy_dataset - else: - raise NotImplementedError(self._config.format) + @property + def vocab_size(self) -> int: + return self._vocab_size - dataset_names = [ - f"dataset_{i}_{'dummy' if prefix is None else prefix.replace('/','__')}" - for i, prefix in enumerate(dataset_prefixes) - ] - self._num_datasets = len(dataset_names) - self._dataset_prefixes = { - name: ( - None - if prefix is None - else ( - pathlib.Path(prefix).resolve() - if data_base_path is None - else (pathlib.Path(data_base_path) / prefix).resolve() - ) - ) - for name, prefix in zip(dataset_names, dataset_prefixes) - } - self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)} + @property + def max_sequence_length(self) -> int: + return self._max_sequence_length - def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]): + def setup(self, distributed: Distributed, samples_per_phase: PhaseSplits[int]): """ Load the datasets, and prepare or load the samplings. This may take a while and a significant amount of cpu memory. """ + super().setup(distributed, samples_per_phase) run = get_run() - Assert.leq(set(samples_per_phase), set(self._phase_split)) - log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.") + log_main_rank(f"Preparing dataset. 
This may take several minutes.") self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None - self._distributed = distributed - self._samples_per_phase = samples_per_phase + if run.experiment_directory is None: warnings.warn(f"Using the dataset directory for the index cache.") self._cache_directory = None else: self._cache_directory = run.experiment_directory / "dataset_cache" - - datasets_and_weights = [] - for i, (name, weight) in enumerate(self._dataset_weights.items()): - if i % 100 == 0 and i > 0: - log_main_rank(f"Prepared {i} of {self._num_datasets} datasets.") - dataset_samples_per_phase = {} - for phase, samples_per_phase in self._samples_per_phase.items(): - expected_samples = self._dataset_weights[name] * samples_per_phase - # Add 5 times the standard deviation (of a binomial distribution) - # so the probability of sampling more than this amount during blending is negligible. - dataset_samples_per_phase[phase] = math.ceil( - expected_samples - + 5 * math.sqrt(expected_samples * self._dataset_weights[name] * (1 - self._dataset_weights[name])) + sampling_config = PhaseSplits[GPTSamplingConfig]( + { + phase: GPTSamplingConfig( + num_samples=samples_per_phase[phase], + sequence_length=self._max_sequence_length, + seed=self._distributed_config.seed, + cache_directory=self._cache_directory, + verbose=True, ) - sampling_configs = PhaseSplits[GPTSamplingConfig]( - { - phase: GPTSamplingConfig( - num_samples=dataset_samples_per_phase[phase], - sequence_length=self._max_sequence_length, - seed=self._distributed_config.seed, - cache_directory=( - self._dataset_prefixes[name].parent - if self._cache_directory is None and isinstance(self._dataset_prefixes[name], pathlib.Path) - else self._cache_directory - ), - verbose=self._num_datasets <= 5, - ) - for phase, num_samples in dataset_samples_per_phase.items() - if num_samples > 0 - } - ) - datasets_and_weights.append( - (self._build_and_sample_dataset(name, sampling_configs), self._dataset_weights[name]) - ) - - if len(datasets_and_weights) == 1: - self._datasets = datasets_and_weights[0][0] - else: - self._datasets = BlendedDataset.apply( - "blended", - datasets_and_weights, - PhaseSplits[GPTSamplingConfig]( - { - phase: GPTSamplingConfig( - num_samples=samples_per_phase, - sequence_length=self._max_sequence_length, - seed=self._distributed_config.seed, - cache_directory=None if self._cache_directory is None else self._cache_directory, - verbose=self._num_datasets <= 5, - ) - for phase, samples_per_phase in self._samples_per_phase.items() - } - ), - self, - ) - self._is_setup = True - - @property - def config(self): - return self._config + for phase, num_samples in samples_per_phase.items() + if num_samples > 0 + } + ) + self._datasets = self._config.dataset.build_split_sample(self, sampling_config) @property def tokenizer(self): assert self._is_setup return self._tokenizer - @property - def distributed(self): - return self._distributed - def get_iterator( self, batch_config: BatchConfig, @@ -229,15 +121,3 @@ def get_iterator( multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) ) - - def _build_and_sample_gpt_dataset(self, name: str, sampling_configs: PhaseSplits[GPTSamplingConfig]): - return GPTDatasetSlice.from_splits( - GPTMemmapDataset(name, self._dataset_prefixes[name]), self._phase_split - ).sample(sampling_configs, self) - - def _build_and_sample_dummy_dataset(self, name: str, sampling_configs: PhaseSplits[GPTSamplingConfig]): - return CopySplitDataset( - 
f"{name}_split", - DummyGPTDataset(name, self._max_sequence_length, self._vocab_size), - list(sampling_configs), - ).sample(sampling_configs, self) diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 4d5711eb..970e7c62 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -1,19 +1,16 @@ import logging import pathlib import time -import typing import numpy as np from fast_llm.core.distributed import safe_barrier +from fast_llm.data.data.abstract import Data from fast_llm.data.data.config import SamplingConfig -from fast_llm.data.dataset.abstract import PhaseSplits, SampledDataset, SplitDataset +from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert, normalize_probabilities -if typing.TYPE_CHECKING: - from fast_llm.data.data.gpt.data import GPTData - try: from fast_llm.csrc.data import build_blending_indices # noqa @@ -35,14 +32,16 @@ class BlendedDataset(SampledDataset): def __init__( self, name: str, - datasets_and_weights: list[tuple[SampledDataset, float]], + datasets: list[SampledDataset], + weights: list[float], sampling_config: SamplingConfig, # TODO: Generalize - data: "GPTData", + data: Data, ): self._name = name - assert len(datasets_and_weights) > 0 - self._datasets, weights = zip(*datasets_and_weights) + assert len(datasets) > 0 + Assert.eq(len(datasets), len(weights)) + self._datasets = datasets self._weights = normalize_probabilities(weights) self._num_samples = sampling_config.num_samples self._data_sample_warn_time_ms = data.config.data_sample_warn_time_ms @@ -72,32 +71,6 @@ def __init__( safe_barrier(group, self._name) self._load_mappings(sampling_config.verbose) - @classmethod - def apply( - cls, - name: str, - datasets_and_weights: list[(SplitDataset[SampledDataset], float)], - sampling_configs: PhaseSplits[SamplingConfig], - data: "GPTData", - ): - Assert.leq(set(sampling_configs), set.union(*[set(dataset) for dataset, _ in datasets_and_weights])) - return SplitDataset[BlendedDataset]( - name, - { - phase: BlendedDataset( - f"{name}_{phase.value}", - [ - (dataset[phase], weight) - for dataset, weight in datasets_and_weights - if phase in dataset and weight > 0 - ], - sampling_config, - data, - ) - for phase, sampling_config in sampling_configs.items() - }, - ) - def __getstate__(self): return ( self._datasets, diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py new file mode 100644 index 00000000..2c29f543 --- /dev/null +++ b/fast_llm/data/dataset/config.py @@ -0,0 +1,221 @@ +import math +import typing + +from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.data.data.abstract import Data +from fast_llm.data.data.config import SamplingConfig +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.data.dataset.abstract import ( + PhaseSplits, + SamplableDataset, + SamplableSplitDataset, + SampledDataset, + SampledSplitDataset, + ) + + +@config_class() +class DatasetConfig(Config): + _abstract = True + + +class SampledSplitDatasetConfig(DatasetConfig): + + def build_split_sample( + self, + data: Data, + config: PhaseSplits[SamplingConfig], + default_phase: PhaseType = PhaseType.training, + ) -> SampledSplitDataset: + raise NotImplementedError() + + @property + def sampled(self): + # Generally hard-coded, but some classes allow for more flexible values. 
+ return True + + @property + def split(self): + # Generally hard-coded, but some classes allow for more flexible values. + return True + + +class SampledDatasetConfig(SampledSplitDatasetConfig): + """ + A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. + (See `fast_llm.data.sampler.Sampler`.) + """ + + def build_sample(self, data: Data, config: SamplingConfig) -> SampledDataset: + raise NotImplementedError() + + def build_split_sample( + self, + data: Data, + config: PhaseSplits[SamplingConfig], + default_phase: PhaseType = PhaseType.training, + ) -> SampledSplitDataset: + dataset = self.build_sample(data, config[default_phase]) + return SampledSplitDataset(dataset.name, {default_phase: dataset}) + + @property + def sampled(self): + return True + + @property + def split(self): + return False + + +class SamplableSplitDatasetConfig(SampledSplitDatasetConfig): + + def build_split( + self, + data: Data, + default_phase: PhaseType = PhaseType.training, + ) -> SamplableSplitDataset: + raise NotImplementedError() + + def build_split_sample( + self, + data: Data, + config: PhaseSplits[SamplingConfig], + default_phase: PhaseType = PhaseType.training, + ) -> SampledSplitDataset: + split_dataset = self.build_split(data) + # TODO: Name + # TODO: Arg order not matching with dataset + return SampledSplitDataset( + "dataset", + {phase: split_dataset[phase].sample(phase_config, data) for phase, phase_config in config.items()}, + ) + + @property + def sampled(self): + return False + + @property + def split(self): + return True + + +class SamplableDatasetConfig(SampledDatasetConfig, SamplableSplitDatasetConfig): + def build(self, data: Data) -> SamplableDataset: + raise NotImplementedError() + + def build_sample(self, data: Data, config: SamplingConfig) -> SampledDataset: + return self.build(data).sample(config, data) + + def build_split( + self, + data: Data, + default_phase: PhaseType = PhaseType.training, + ) -> SamplableSplitDataset: + dataset = self.build(data) + return SamplableSplitDataset(dataset.name, {default_phase: dataset}) + + @property + def sampled(self): + return False + + @property + def split(self): + return False + + +@config_class() +class BlendedDatasetConfig(SampledDatasetConfig): + # [(?sampled, ?split)] -> (sampled, ?split) + name: str = Field( + default="blended", + desc="The name of the dataset.", + hint=FieldHint.core, + ) + datasets: list[SampledDatasetConfig] = Field( + desc="The datasets to blend.", + hint=FieldHint.core, + ) + weights: list[float] = Field( + desc="The blending weight of each dataset.", + hint=FieldHint.core, + ) + + def __post_init__(self): + Assert.eq(len(self.datasets), len(self.weights)) + + @property + def split(self): + return any(dataset.split for dataset in self.datasets) + + def build_sample( + self, + data: "Data", + config: SamplingConfig, + ) -> SampledDataset: + from fast_llm.data.dataset.blended import BlendedDataset + + assert not self.split + + # Build and sample the datasets. + sampled_datasets = [ + dataset.build_sample( + data, + # Blending is deterministic and the error will never be higher than 1. + config.to_copy({"num_samples": math.ceil(weight * config.num_samples) + 1}), + ) + for dataset, weight in zip(self.datasets, self.weights, strict=True) + ] + # Blend the datasets. 
+ return BlendedDataset( + self.name, + sampled_datasets, + self.weights, + config, + data, + ) + + def build_split_sample( + self, + data: "Data", + config: PhaseSplits[SamplingConfig], + default_phase: PhaseType = PhaseType.training, + ) -> SampledSplitDataset: + from fast_llm.data.dataset.blended import BlendedDataset + + if not self.split: + # Take the base class shortcut using build_sample if it's available. + return super().build_split_sample(data, config, default_phase) + + # Build, sample and split the datasets. + sampled_datasets = [ + dataset.build_split_sample( + data, + # Blending is deterministic and the error will never be higher than 1. + PhaseSplits[SamplingConfig]( + { + phase: phase_config.to_copy({"num_samples": math.ceil(weight * phase_config.num_samples) + 1}) + for phase, phase_config in config.items() + } + ), + default_phase, + ) + for dataset, weight in zip(self.datasets, self.weights, strict=True) + ] + + # Blend the datasets for each phase. + return SampledSplitDataset[BlendedDataset]( + self.name, + { + phase: BlendedDataset( + f"{self.name}_{phase.value}", + [dataset[phase] for dataset in sampled_datasets], + self.weights, + phase_config, + data, + ) + for phase, phase_config in config.items() + }, + ) diff --git a/fast_llm/data/dataset/gpt/abstract.py b/fast_llm/data/dataset/gpt/abstract.py index e2c7f093..40f1532a 100644 --- a/fast_llm/data/dataset/gpt/abstract.py +++ b/fast_llm/data/dataset/gpt/abstract.py @@ -3,8 +3,8 @@ import numpy as np -from fast_llm.data.data.gpt.config import GPTSamplingConfig from fast_llm.data.dataset.abstract import SamplableDataset +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig if typing.TYPE_CHECKING: from fast_llm.data.data.gpt.data import GPTData diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py new file mode 100644 index 00000000..29795c64 --- /dev/null +++ b/fast_llm/data/dataset/gpt/config.py @@ -0,0 +1,262 @@ +import enum +import json +import pathlib +import typing + +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class +from fast_llm.data.data.config import SamplingConfig +from fast_llm.data.dataset.abstract import PhaseSplits, SamplableSplitDataset, SampledSplitDataset +from fast_llm.data.dataset.config import ( + BlendedDatasetConfig, + DatasetConfig, + SamplableDatasetConfig, + SamplableSplitDatasetConfig, + SampledDatasetConfig, + SampledSplitDatasetConfig, +) +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.utils import Assert, Registry, normalize_probabilities + +if typing.TYPE_CHECKING: + from fast_llm.data.data.gpt.data import GPTData + from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset + from fast_llm.data.dataset.gpt.concatenated import GPTConcatenatedDataset + from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset + from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset + from fast_llm.data.dataset.gpt.slice import GPTDatasetSlice + + +@config_class +class GPTSamplingConfig(SamplingConfig): + sequence_length: int = Field(default=None, desc="Number of token in each sample.") + + +@config_class() +class GPTDatasetConfig(DatasetConfig): + _registry: typing.ClassVar[Registry[str, type["GPTDatasetConfig"]]] = Registry[str, type["GPTDatasetConfig"]]( + "gpt_dataset_class", {} + ) + type_: typing.ClassVar[type["GPTDatasetConfig"] | None] = None + type: str | None = Field( + default=None, + desc="The type of dataset.", + hint=FieldHint.core, + ) + + def __new__(cls, *args, **kwargs): + # 
Find and instantiate the actual class. + type_ = kwargs.get("type") + if type_ is not None: + actual_cls = GPTDatasetConfig._registry[type_] + Assert.custom(issubclass, actual_cls, cls) + if actual_cls != cls: + return actual_cls(*args, **kwargs) + return super().__new__() + + def __init_subclass__(cls, type_: str | None = None, **kwargs): + if type_ is not None: + GPTDatasetConfig._registry[type_] = cls + cls.type = type_ + + +class GPTSampledSplitDatasetConfig(SampledSplitDatasetConfig, GPTDatasetConfig): + pass + + +class GPTSampledDatasetConfig(SampledDatasetConfig, GPTSampledSplitDatasetConfig): + pass + + +class GPTSamplableSplitDatasetConfig(SamplableSplitDatasetConfig, GPTSampledSplitDatasetConfig): + pass + + +class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig, GPTSamplableSplitDatasetConfig): + pass + + +class GPTIndexedDatasetConfig(GPTSamplableDatasetConfig): + def build(self, data: "GPTData") -> "GPTIndexedDataset": + raise NotImplementedError() + + +class GPTDummyDatasetConfig(GPTSamplableDatasetConfig): + name: str = Field( + default="dummy", + desc="The name of the dataset.", + hint=FieldHint.core, + ) + + def build(self, data: "GPTData") -> "GPTDummyDataset": + return GPTDummyDataset(self.name, data.max_sequence_length, data.vocab_size) + + +@config_class() +class GPTMemmapDatasetConfig(GPTDatasetConfig, SamplableDatasetConfig, type_="memmap"): + # Path -> (unsampled, unsplit) + _abstract = False + path: pathlib.Path = Field( + desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", + hint=FieldHint.core, + ) + + def build(self, data: "GPTData") -> "GPTMemmapDataset": + from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset + + return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path) + + +@config_class() +class GPTConcatenatedDatasetConfig(GPTDatasetConfig, SamplableDatasetConfig, type_="concatenated"): + """ + Concatenate multiple datasets as if they were one. + Must be done before sampling and splitting. + TODO: OK after sampling (staged training?) or splitting (Equal split for each sub-dataset, probably better? + [(unsampled, unsplit)] -> (unsampled, unsplit) + """ + + _abstract = False + name: str = Field( + default="concatenated", + desc="The name of the dataset.", + hint=FieldHint.core, + ) + datasets: list[GPTIndexedDatasetConfig] = Field( + desc="The datasets to concatenate.", + hint=FieldHint.core, + ) + + def build(self, data: "GPTData") -> "GPTConcatenatedDataset": + from fast_llm.data.dataset.gpt.concatenated import GPTConcatenatedDataset + + return GPTConcatenatedDataset(self.name, [dataset.build(data) for dataset in self.datasets]) + + +@config_class() +class GPTSplitDatasetConfig(SamplableSplitDatasetConfig, type_="split"): + """ + Split a single dataset into multiple phases. + Must be done before sampling. + TODO: Ok after sampling? 
+ (unsampled, unsplit) -> (unsampled, split) + """ + + _abstract = False + dataset: GPTIndexedDatasetConfig = Field( + desc="The dataset to split.", + hint=FieldHint.core, + ) + ratios: PhaseSplits[float] = Field( + desc="The split ratio for each phase", + hint=FieldHint.core, + ) + + def build(self, data: "GPTData") -> SamplableSplitDataset["GPTDatasetSlice"]: + from fast_llm.data.dataset.gpt.slice import GPTDatasetSlice + + return GPTDatasetSlice.from_splits(self.dataset.build(data), self.ratios) + + +@config_class() +class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig, type_="blended"): + _abstract = False + datasets: list[GPTSampledDatasetConfig] = FieldUpdate() + + +class LegacyDatasetSource(str, enum.Enum): + """ + An enum for the different ways to load datasets. + """ + + list = "list" + file = "file" + random = "random" + + +def _validate_split(value): + Assert.leq(len(value), 3) + return value + [0] * (len(value) - 3) + + +def _validate_path(value): + return [value] if isinstance(value, str) else value + + +@config_class() +class GPTLegacyConfig(Config): + split: list[float] = Field( + default_factory=lambda: [969, 30, 1], + desc="Split ratio for train, valid and test datasets.", + hint=FieldHint.deprecated, + valid=_validate_split, + ) + format: LegacyDatasetSource = Field( + default=LegacyDatasetSource.list, + desc="Format for the dataset definition.", + hint=FieldHint.deprecated, + ) + path: list[str] = Field( + default_factory=list, + desc="Path or list of paths and weights.", + hint=FieldHint.deprecated, + valid=_validate_path, + ) + + +@config_class() +class GPTLegacyDatasetConfig(GPTSampledSplitDatasetConfig, GPTLegacyConfig, type_="legacy"): + _abstract = False + + def build_split_sample( + self, + data: "GPTData", + config: PhaseSplits[GPTSamplingConfig], + default_phase: PhaseType = PhaseType.training, + ) -> SampledSplitDataset: + + if self.format == LegacyDatasetSource.random: + Assert.eq(len(self.path), 0) + # TODO: Multiple phase. 
+ dataset_config = GPTDummyDatasetConfig() + else: + if self.format == LegacyDatasetSource.file: + Assert.eq(len(self.path), 1) + data_path = pathlib.Path(self.path[0]) + dataset_defs = json.load(data_path.open("r")) + data_base_path = data_path.parent + dataset_prefixes = [ + (data_base_path / dataset_def["prefix"]).resolve() for dataset_def in dataset_defs["datasets"] + ] + dataset_weights = normalize_probabilities( + [dataset_def["weight"] for dataset_def in dataset_defs["datasets"]] + ) + elif self.format == LegacyDatasetSource.list: + Assert.geq(len(self.path), 1) + if len(self.path) == 1: + dataset_prefixes, dataset_weights = [self.path[0].strip()], [1.0] + else: + Assert.custom(lambda x: x % 2 == 0, len(self.path)) + dataset_prefixes = [pathlib.Path(x.strip()).resolve() for x in self.path[1::2]] + assert len(dataset_prefixes) == len(set(dataset_prefixes)) + dataset_weights = normalize_probabilities([float(x) for x in self.path[::2]]) + else: + raise NotImplementedError(self.format) + + dataset_configs = [ + GPTSplitDatasetConfig( + dataset=GPTMemmapDatasetConfig(path=prefix), + ratios=self.split, + ) + for prefix in dataset_prefixes + ] + dataset_config = ( + GPTBlendedDatasetConfig( + name="blended", + datasets=dataset_configs, + weights=dataset_weights, + ) + if len(dataset_configs) > 1 + else dataset_configs[0] + ) + + return dataset_config.build_split_sample(data, config, default_phase) diff --git a/fast_llm/data/dataset/gpt/dummy.py b/fast_llm/data/dataset/gpt/dummy.py index dd04b6e6..68a864c8 100644 --- a/fast_llm/data/dataset/gpt/dummy.py +++ b/fast_llm/data/dataset/gpt/dummy.py @@ -2,17 +2,16 @@ import numpy as np -from fast_llm.data.data.gpt.config import GPTSamplingConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig if typing.TYPE_CHECKING: from fast_llm.data.data.gpt.data import GPTData -class DummyGPTDataset(SamplableDataset): +class GPTDummyDataset(SamplableDataset): """ - A dummy dataset that always returns the same sample, for debugging purposes. - The sample can be purely random, or read from a file to allow reproducing in other runs. + A dummy dataset that always returns the same random sample, for debugging purposes. """ def __init__(self, name: str, sequence_length: int, vocab_size: int): @@ -20,7 +19,7 @@ def __init__(self, name: str, sequence_length: int, vocab_size: int): self._name = name def sample(self, config: GPTSamplingConfig, data: "GPTData"): - return DummyGPTSampledDataset(self, config) + return GPTDummySampledDataset(self, config) def get(self): return self._dummy_sample @@ -30,13 +29,8 @@ def name(self): return self._name -class DummyGPTSampledDataset(SampledDataset): - """ - A dummy dataset that always returns the same sample, for debugging purposes. - The sample can be purely random, or read from a file to allow reproducing in other runs. 
- """ - - def __init__(self, dataset: DummyGPTDataset, config: GPTSamplingConfig): +class GPTDummySampledDataset(SampledDataset): + def __init__(self, dataset: GPTDummyDataset, config: GPTSamplingConfig): self._config = config self._dataset = dataset diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f24536f8..30184415 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -83,7 +83,7 @@ def num_documents(self) -> int: def num_tokens(self) -> int: return div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - def get_document_sizes(self) -> "np.ndarray": + def get_document_sizes(self) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index beebc5ed..785c7aac 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -3,10 +3,10 @@ import numpy as np from fast_llm.core.distributed import safe_barrier -from fast_llm.data.data.gpt.config import GPTSamplingConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.data.dataset.gpt.fim.fim import Fim from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import MAX_SEED diff --git a/fast_llm/data/dataset/gpt/slice.py b/fast_llm/data/dataset/gpt/slice.py index 87684e01..df1a511a 100644 --- a/fast_llm/data/dataset/gpt/slice.py +++ b/fast_llm/data/dataset/gpt/slice.py @@ -1,6 +1,5 @@ -from fast_llm.data.dataset.abstract import SamplableSplitDataset +from fast_llm.data.dataset.abstract import PhaseSplits, SamplableSplitDataset from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset -from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -56,7 +55,7 @@ def name(self): return self._name @classmethod - def from_splits(cls, dataset: GPTIndexedDataset, phase_split: dict[PhaseType, float]): + def from_splits(cls, dataset: GPTIndexedDataset, phase_split: PhaseSplits[float]): """ Create a set of GPT datasets from a MMapIndexedDataset, each containing approximately the requested proportion of the total tokens. 
diff --git a/tests/common.py b/tests/common.py index 6c9d11d7..8c83b2ab 100644 --- a/tests/common.py +++ b/tests/common.py @@ -58,7 +58,12 @@ "training.num_workers=0", "batch.batch_size=8", "batch.sequence_length=512", - f"data.path={DATASET_PREFIX}", + "data.dataset.type=split", + "data.dataset.dataset.type=memmap", + f"data.dataset.dataset.path={DATASET_PREFIX}", + f"data.dataset.ratios.training=969", + f"data.dataset.ratios.validation=30", + f"data.dataset.ratios.test=1", "optimizer.learning_rate.base=0.0001", ] CONFIG_BASE_MEGATRON = [ From c41a2c588d3e7d6cdcaa31d538e8c6fb18fc6b66 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 6 Jan 2025 19:13:26 -0500 Subject: [PATCH 02/19] fixes --- examples/mistral-4-node-benchmark.yaml | 10 +-- fast_llm/config.py | 33 ++++++---- fast_llm/data/data/gpt/config.py | 9 +-- fast_llm/data/data/gpt/data.py | 1 + fast_llm/data/dataset/config.py | 27 +++++---- fast_llm/data/dataset/gpt/config.py | 84 ++++++++++++++++++++------ fast_llm/data/dataset/gpt/slice.py | 5 +- tests/common.py | 6 +- 8 files changed, 119 insertions(+), 56 deletions(-) diff --git a/examples/mistral-4-node-benchmark.yaml b/examples/mistral-4-node-benchmark.yaml index 99dd0ee7..0a71d392 100644 --- a/examples/mistral-4-node-benchmark.yaml +++ b/examples/mistral-4-node-benchmark.yaml @@ -11,8 +11,8 @@ batch: micro_batch_size: 1 batch_size: 32 data: - format: random - split: [1, 0, 0] + dataset: + type: dummy optimizer: learning_rate: base: 1.0e-05 @@ -27,18 +27,18 @@ model: normalization: type: rms_norm epsilon: 1.0e-05 + rotary: + type: default + theta: 10000 num_layers: 32 hidden_size: 4096 ffn_hidden_size: 14336 num_attention_heads: 32 head_groups: 8 add_linear_biases: false - use_rotary_embeddings: true gated: true activation_type: silu - triton_rotary: true kv_channels: 128 - rotary_embedding_scale: -9.210340371976184 window_size: 4096 init_method_std: 0.009021 attention_dropout: 0.0 diff --git a/fast_llm/config.py b/fast_llm/config.py index 1934caf2..d3deda23 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -287,7 +287,6 @@ def __post_init__(self): In general this should not be overridden in derived classes, and all post-processing should be done in `_validate` """ - self._check_abstract() self._validated = False if _AUTO_VALIDATE: self.validate() @@ -343,6 +342,7 @@ def _validate(self): Can be extended to add custom post-processing (typically before the super() call) and validation (typically after) """ + self._check_abstract() errors = [] for name, field in self.fields(): if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa @@ -604,17 +604,12 @@ def _add_field_to_args( else: field_value = value if serializable: - if hasattr(value, "__fast_llm_serialize__"): - field_value = field_value.__fast_llm_serialize__() - if isinstance(value, enum.Enum): - field_value = field_value.value - # Tag is not actually serializable, but needs to be kept as-is for config processing, - # and should be absent for valid configs. 
- elif not isinstance(value, int | float | bool | str | Tag | None): - field_value = str(field_value) + field_value = cls._serialize_value(value) if format_ == _ConfigDictFormat.tuple: field_value = {(): field_value} + if serializable: + name = cls._serialize_value(name) if format_ == _ConfigDictFormat.tuple: args.update({(name,) + name_: value_ for name_, value_ in field_value.items()}) elif format_ == _ConfigDictFormat.nested: @@ -626,6 +621,19 @@ def _add_field_to_args( else: raise NotImplementedError(format_) + @classmethod + def _serialize_value(cls, value): + value = value + if hasattr(value, "__fast_llm_serialize__"): + value = value.__fast_llm_serialize__() + if isinstance(value, enum.Enum): + value = value.value + # Tag is not actually serializable, but needs to be kept as-is for config processing, + # and should be absent for valid configs. + elif not isinstance(value, int | float | bool | str | Tag | None): + value = str(value) + return value + def to_copy( self, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], @@ -690,7 +698,6 @@ def _from_dict( strict: bool = True, flat: bool = False, ): - cls._check_abstract() # TODO v0.3: Remove flat format out_arg_dict = {} @@ -841,9 +848,11 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ @classmethod def _check_abstract(cls): if cls._abstract: - raise RuntimeError(f"{cls.__name__} is abstract") + raise ValidationError(f"{cls.__name__} is abstract") if not cls.__class_validated__: - raise RuntimeError(f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator.") + raise ValidationError( + f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator." + ) def __init_subclass__(cls): """ diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index bb4d78b9..37bf8752 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -26,7 +26,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): hint=FieldHint.feature, ) dataset: GPTSampledSplitDatasetConfig = Field( - default=None, + default_factory=GPTSampledSplitDatasetConfig, desc="Configuration for the dataset(s).", hint=FieldHint.core, ) @@ -47,11 +47,12 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): hint=FieldHint.expert, ) - def __post_init__(self): - if self.dataset is None: + def _validate(self): + if self.dataset.type is None: logger.warning("Using the legacy dataset definition format." 
" Specify it through `data.dataset` instead.") self.dataset = GPTLegacyDatasetConfig( - split=self.split, + ratio=self.ratio, format=self.format, path=self.path, ) + super()._validate() diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 0632f94d..12396c97 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -86,6 +86,7 @@ def setup(self, distributed: Distributed, samples_per_phase: PhaseSplits[int]): } ) self._datasets = self._config.dataset.build_split_sample(self, sampling_config) + self._is_setup = True @property def tokenizer(self): diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 2c29f543..6fe9314f 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -1,27 +1,26 @@ +import functools import math -import typing -from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.data.data.abstract import Data from fast_llm.data.data.config import SamplingConfig +from fast_llm.data.dataset.abstract import ( + PhaseSplits, + SamplableDataset, + SamplableSplitDataset, + SampledDataset, + SampledSplitDataset, +) from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - from fast_llm.data.dataset.abstract import ( - PhaseSplits, - SamplableDataset, - SamplableSplitDataset, - SampledDataset, - SampledSplitDataset, - ) - @config_class() class DatasetConfig(Config): _abstract = True +@config_class() class SampledSplitDatasetConfig(DatasetConfig): def build_split_sample( @@ -43,6 +42,7 @@ def split(self): return True +@config_class() class SampledDatasetConfig(SampledSplitDatasetConfig): """ A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. 
@@ -70,6 +70,7 @@ def split(self): return False +@config_class() class SamplableSplitDatasetConfig(SampledSplitDatasetConfig): def build_split( @@ -102,6 +103,7 @@ def split(self): return True +@config_class() class SamplableDatasetConfig(SampledDatasetConfig, SamplableSplitDatasetConfig): def build(self, data: Data) -> SamplableDataset: raise NotImplementedError() @@ -135,10 +137,13 @@ class BlendedDatasetConfig(SampledDatasetConfig): hint=FieldHint.core, ) datasets: list[SampledDatasetConfig] = Field( + default_factory=list, desc="The datasets to blend.", hint=FieldHint.core, + valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)), ) weights: list[float] = Field( + default_factory=list, desc="The blending weight of each dataset.", hint=FieldHint.core, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 29795c64..f9c78d44 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -1,9 +1,10 @@ import enum +import functools import json import pathlib import typing -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.config import SamplingConfig from fast_llm.data.dataset.abstract import PhaseSplits, SamplableSplitDataset, SampledSplitDataset from fast_llm.data.dataset.config import ( @@ -23,7 +24,6 @@ from fast_llm.data.dataset.gpt.concatenated import GPTConcatenatedDataset from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - from fast_llm.data.dataset.gpt.slice import GPTDatasetSlice @config_class @@ -33,6 +33,7 @@ class GPTSamplingConfig(SamplingConfig): @config_class() class GPTDatasetConfig(DatasetConfig): + # TODO: Generalize dynamic types? _registry: typing.ClassVar[Registry[str, type["GPTDatasetConfig"]]] = Registry[str, type["GPTDatasetConfig"]]( "gpt_dataset_class", {} ) @@ -43,44 +44,65 @@ class GPTDatasetConfig(DatasetConfig): hint=FieldHint.core, ) - def __new__(cls, *args, **kwargs): - # Find and instantiate the actual class. - type_ = kwargs.get("type") - if type_ is not None: - actual_cls = GPTDatasetConfig._registry[type_] + def _validate(self): + if self.type is not None: + # Should be handled in `from_dict`, but can fail if instantiating directly. 
+ Assert.eq(self.type, self.type_) + super()._validate() + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ): + type_ = default.get("type") + if type_ is None: + actual_cls = cls + else: + actual_cls = cls._registry[type_] Assert.custom(issubclass, actual_cls, cls) - if actual_cls != cls: - return actual_cls(*args, **kwargs) - return super().__new__() + if actual_cls == cls: + return super()._from_dict(default, strict=strict, flat=flat) + else: + return actual_cls._from_dict(default, strict=strict, flat=flat) def __init_subclass__(cls, type_: str | None = None, **kwargs): if type_ is not None: GPTDatasetConfig._registry[type_] = cls - cls.type = type_ + cls.type_ = type_ + super().__init_subclass__() +@config_class() class GPTSampledSplitDatasetConfig(SampledSplitDatasetConfig, GPTDatasetConfig): pass +@config_class() class GPTSampledDatasetConfig(SampledDatasetConfig, GPTSampledSplitDatasetConfig): pass +@config_class() class GPTSamplableSplitDatasetConfig(SamplableSplitDatasetConfig, GPTSampledSplitDatasetConfig): pass +@config_class() class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig, GPTSamplableSplitDatasetConfig): pass +@config_class() class GPTIndexedDatasetConfig(GPTSamplableDatasetConfig): def build(self, data: "GPTData") -> "GPTIndexedDataset": raise NotImplementedError() -class GPTDummyDatasetConfig(GPTSamplableDatasetConfig): +@config_class() +class GPTDummyDatasetConfig(GPTSamplableDatasetConfig, type_="dummy"): name: str = Field( default="dummy", desc="The name of the dataset.", @@ -92,10 +114,11 @@ def build(self, data: "GPTData") -> "GPTDummyDataset": @config_class() -class GPTMemmapDatasetConfig(GPTDatasetConfig, SamplableDatasetConfig, type_="memmap"): +class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig, type_="memmap"): # Path -> (unsampled, unsplit) _abstract = False path: pathlib.Path = Field( + default=None, desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", hint=FieldHint.core, ) @@ -122,8 +145,10 @@ class GPTConcatenatedDatasetConfig(GPTDatasetConfig, SamplableDatasetConfig, typ hint=FieldHint.core, ) datasets: list[GPTIndexedDatasetConfig] = Field( + default_factory=list, desc="The datasets to concatenate.", hint=FieldHint.core, + valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)), ) def build(self, data: "GPTData") -> "GPTConcatenatedDataset": @@ -133,7 +158,7 @@ def build(self, data: "GPTData") -> "GPTConcatenatedDataset": @config_class() -class GPTSplitDatasetConfig(SamplableSplitDatasetConfig, type_="split"): +class GPTSplitDatasetConfig(GPTSamplableSplitDatasetConfig, type_="split"): """ Split a single dataset into multiple phases. Must be done before sampling. 
@@ -143,15 +168,21 @@ class GPTSplitDatasetConfig(SamplableSplitDatasetConfig, type_="split"): _abstract = False dataset: GPTIndexedDatasetConfig = Field( + default=None, desc="The dataset to split.", hint=FieldHint.core, ) - ratios: PhaseSplits[float] = Field( + ratios: dict[PhaseType, float] = Field( + default=None, desc="The split ratio for each phase", hint=FieldHint.core, ) - def build(self, data: "GPTData") -> SamplableSplitDataset["GPTDatasetSlice"]: + def build_split( + self, + data: "GPTData", + default_phase: PhaseType = PhaseType.training, + ) -> SamplableSplitDataset: from fast_llm.data.dataset.gpt.slice import GPTDatasetSlice return GPTDatasetSlice.from_splits(self.dataset.build(data), self.ratios) @@ -160,7 +191,7 @@ def build(self, data: "GPTData") -> SamplableSplitDataset["GPTDatasetSlice"]: @config_class() class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig, type_="blended"): _abstract = False - datasets: list[GPTSampledDatasetConfig] = FieldUpdate() + datasets: list[GPTSampledDatasetConfig] = FieldUpdate(desc="UINGBRI") class LegacyDatasetSource(str, enum.Enum): @@ -184,7 +215,18 @@ def _validate_path(value): @config_class() class GPTLegacyConfig(Config): - split: list[float] = Field( + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ): + # TODO v0.3: Remove. + cls._handle_renamed_field(default, "split", "ratio") + return super()._from_dict(default, strict, flat) + + ratio: list[float] = Field( default_factory=lambda: [969, 30, 1], desc="Split ratio for train, valid and test datasets.", hint=FieldHint.deprecated, @@ -245,7 +287,11 @@ def build_split_sample( dataset_configs = [ GPTSplitDatasetConfig( dataset=GPTMemmapDatasetConfig(path=prefix), - ratios=self.split, + ratios={ + PhaseType.training: self.ratio[0], + PhaseType.validation: self.ratio[1], + PhaseType.test: self.ratio[2], + }, ) for prefix in dataset_prefixes ] diff --git a/fast_llm/data/dataset/gpt/slice.py b/fast_llm/data/dataset/gpt/slice.py index df1a511a..87684e01 100644 --- a/fast_llm/data/dataset/gpt/slice.py +++ b/fast_llm/data/dataset/gpt/slice.py @@ -1,5 +1,6 @@ -from fast_llm.data.dataset.abstract import PhaseSplits, SamplableSplitDataset +from fast_llm.data.dataset.abstract import SamplableSplitDataset from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset +from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -55,7 +56,7 @@ def name(self): return self._name @classmethod - def from_splits(cls, dataset: GPTIndexedDataset, phase_split: PhaseSplits[float]): + def from_splits(cls, dataset: GPTIndexedDataset, phase_split: dict[PhaseType, float]): """ Create a set of GPT datasets from a MMapIndexedDataset, each containing approximately the requested proportion of the total tokens. 
diff --git a/tests/common.py b/tests/common.py index 8c83b2ab..e7b8906f 100644 --- a/tests/common.py +++ b/tests/common.py @@ -61,9 +61,9 @@ "data.dataset.type=split", "data.dataset.dataset.type=memmap", f"data.dataset.dataset.path={DATASET_PREFIX}", - f"data.dataset.ratios.training=969", - f"data.dataset.ratios.validation=30", - f"data.dataset.ratios.test=1", + f"data.dataset.ratios.Training=969", + f"data.dataset.ratios.Validation=30", + f"data.dataset.ratios.Test=1", "optimizer.learning_rate.base=0.0001", ] CONFIG_BASE_MEGATRON = [ From e013ba2d3ac1ebf63d10b343676bfe39b511edd9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Jan 2025 15:21:38 -0500 Subject: [PATCH 03/19] fix --- fast_llm/data/dataset/gpt/config.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index f9c78d44..d475f74d 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -22,7 +22,7 @@ from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset from fast_llm.data.dataset.gpt.concatenated import GPTConcatenatedDataset - from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset + from fast_llm.data.dataset.gpt.dummy import GPTDummySampledDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset @@ -102,15 +102,32 @@ def build(self, data: "GPTData") -> "GPTIndexedDataset": @config_class() -class GPTDummyDatasetConfig(GPTSamplableDatasetConfig, type_="dummy"): +class GPTDummyDatasetConfig(GPTSampledSplitDatasetConfig, type_="dummy"): + # NA -> (unsampled, unsplit) + _abstract = False name: str = Field( default="dummy", desc="The name of the dataset.", hint=FieldHint.core, ) - def build(self, data: "GPTData") -> "GPTDummyDataset": - return GPTDummyDataset(self.name, data.max_sequence_length, data.vocab_size) + def build_split_sample( + self, + data: "GPTData", + config: PhaseSplits[GPTSamplingConfig], + default_phase: PhaseType = PhaseType.training, + ) -> "SampledSplitDataset[GPTDummySampledDataset]": + from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset, GPTDummySampledDataset + + return SampledSplitDataset[GPTDummySampledDataset]( + self.name, + { + phase: GPTDummyDataset(f"{self.name}_{phase.value}", data.max_sequence_length, data.vocab_size).sample( + phase_config, data + ) + for phase, phase_config in config.items() + }, + ) @config_class() @@ -191,7 +208,7 @@ def build_split( @config_class() class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig, type_="blended"): _abstract = False - datasets: list[GPTSampledDatasetConfig] = FieldUpdate(desc="UINGBRI") + datasets: list[GPTSampledDatasetConfig] = FieldUpdate() class LegacyDatasetSource(str, enum.Enum): From 06eaaa9faaeada1fa9e276638814b2a5ec3339c8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 8 Jan 2025 19:15:16 -0500 Subject: [PATCH 04/19] Dataset improvements --- fast_llm/csrc/data.cpp | 64 +++----- fast_llm/data/data/config.py | 2 +- fast_llm/data/data/gpt/config.py | 11 ++ fast_llm/data/data/gpt/data.py | 16 ++ fast_llm/data/dataset/blended.py | 23 ++- fast_llm/data/dataset/gpt/abstract.py | 58 -------- fast_llm/data/dataset/gpt/concatenated.py | 42 ------ fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/dataset/gpt/indexed.py | 60 ++++++++ fast_llm/data/dataset/gpt/memmap.py | 4 +- fast_llm/data/dataset/gpt/sampled.py | 169 +++++++++++++++------- 
fast_llm/data/dataset/gpt/slice.py | 72 --------- fast_llm/data/dataset/indexed.py | 103 +++++++++++++ tools/concatenate_dataset.py | 2 +- 14 files changed, 342 insertions(+), 289 deletions(-) delete mode 100644 fast_llm/data/dataset/gpt/abstract.py delete mode 100644 fast_llm/data/dataset/gpt/concatenated.py create mode 100644 fast_llm/data/dataset/gpt/indexed.py delete mode 100644 fast_llm/data/dataset/gpt/slice.py create mode 100644 fast_llm/data/dataset/indexed.py diff --git a/fast_llm/csrc/data.cpp b/fast_llm/csrc/data.cpp index b7e52924..67ae946d 100644 --- a/fast_llm/csrc/data.cpp +++ b/fast_llm/csrc/data.cpp @@ -104,9 +104,7 @@ void build_blending_indices(py::array_t& dataset_index, py::array build_sample_idx(const py::array_t& sizes_, const py::array_t& doc_idx_, const int32_t seq_length, - const int32_t num_epochs, - const int64_t tokens_per_epoch, - const bool verbose) { + const int64_t num_samples) { /* Sample index (sample_idx) is used for gpt2 like dataset for which the documents are flattened and the samples are built based on this 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] @@ -115,29 +113,14 @@ py::array build_sample_idx(const py::array_t& sizes_, // Consistency checks. assert(seq_length > 1); - assert(num_epochs > 0); - assert(tokens_per_epoch > 1); // Remove bound checks. auto sizes = sizes_.unchecked<1>(); auto doc_idx = doc_idx_.unchecked<1>(); // Mapping and it's length (1D). - int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; int32_t* sample_idx = new int32_t[2*(num_samples+1)]; - if (verbose) { - cout << " using:" << endl << std::flush; - cout << " number of documents: " << - doc_idx_.shape(0) / num_epochs << endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " sequence length: " << seq_length << - endl << std::flush; - cout << " total number of samples: " << num_samples << - endl << std::flush; - } - // Index into sample_idx. int64_t sample_index = 0; // Index into doc_idx. @@ -151,30 +134,29 @@ py::array build_sample_idx(const py::array_t& sizes_, while (sample_index <= num_samples) { // Start with a fresh sequence. - int32_t remaining_seq_length = seq_length + 1; - while (remaining_seq_length != 0) { + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length > 0) { // Get the document length. - auto doc_id = doc_idx[doc_idx_index]; - auto doc_length = sizes[doc_id] - doc_offset; - // And add it to the current sequence. - remaining_seq_length -= doc_length; - // If we have more than a full sequence, adjust offset and set - // remaining length to zero so we return from the while loop. - // Note that -1 here is for the same reason we have -1 in - // `_num_epochs` calculations. - if (remaining_seq_length <= 0) { - doc_offset += (remaining_seq_length + doc_length - 1); - remaining_seq_length = 0; - } else { - // Otherwise, start from the beginning of the next document. - ++doc_idx_index; - doc_offset = 0; - } - } - // Record the sequence. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. 
+ if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + } else { + // Otherwise, start from the beginning of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; } // Method to deallocate memory. diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 3485c2e0..b6b1585e 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -10,7 +10,7 @@ class SamplingConfig(Config): num_samples: int = Field(default=1, desc="Number of samples to generate.", valid=check_field(Assert.gt, 0)) seed: int = Field(default=0, desc="Random seed.") cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache directory.") - verbose: bool = Field(default=True, desc="Log sampling progress.") + # verbose: bool = Field(default=True, desc="Log sampling progress.") @config_class() diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 37bf8752..55d0c28a 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -41,6 +41,17 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): hint=FieldHint.feature, valid=check_field(Assert.gt, 0), ) + shuffle_epochs: bool = Field( + default=True, + desc="Shuffle all epochs together. Adds extra randomness," + " but makes it harder to change the training length after training is started.", + hint=FieldHint.feature, + ) + distributed_data_sampling: bool = Field( + default=True, + desc="When possible, distribute data sampling across all available processes to speed it up.", + hint=FieldHint.performance, + ) multiprocessing_context: MultiprocessingContext = Field( default=MultiprocessingContext.spawn, desc="Multiprocessing context. 
Do not touch.", diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 12396c97..ec627968 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,10 +1,12 @@ import logging import pathlib +import time import warnings import torch import torch.utils.data +from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import PhaseSplits, SampledSplitDataset @@ -48,6 +50,8 @@ def __init__( super().__init__(config, distributed_config) self._vocab_size = vocab_size self._max_sequence_length = max_sequence_length + self._sampling_rank = -1 + self._sampling_time = None @property def vocab_size(self) -> int: @@ -86,8 +90,20 @@ def setup(self, distributed: Distributed, samples_per_phase: PhaseSplits[int]): } ) self._datasets = self._config.dataset.build_split_sample(self, sampling_config) + safe_barrier(self._distributed.world_group, "build_split_sample") self._is_setup = True + def get_next_sampling_rank_and_verbose(self) -> tuple[int, bool]: + sampling_time = time.perf_counter() + verbose = self._sampling_time is None or sampling_time - self._sampling_time > 60 + if verbose: + self._sampling_time = sampling_time + if self._config.distributed_data_sampling: + self._sampling_rank += 1 + return self._sampling_rank % self._distributed_config.world_size, verbose + else: + return 0, verbose + @property def tokenizer(self): assert self._is_setup diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 970e7c62..69b45d32 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -4,7 +4,6 @@ import numpy as np -from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.data.config import SamplingConfig from fast_llm.data.dataset.abstract import SampledDataset @@ -48,9 +47,7 @@ def __init__( if sampling_config.cache_directory is None: self._dataset_idx_filename, self._sample_idx_filename = None, None - self._dataset_index, self._sample_index = self._build_blending_indices( - sampling_config.verbose and len(self._datasets) <= 20 - ) + self._dataset_index, self._sample_index = self._build_blending_indices() else: group = data.distributed.world_group self._dataset_idx_filename = sampling_config.cache_directory / (self._name + "_blending_dataset_idx.npy") @@ -61,16 +58,11 @@ def __init__( if (group is None or group.rank() == 0) and not ( self._dataset_idx_filename.is_file() and self._sample_idx_filename.is_file() ): - dataset_index, sample_index = self._build_blending_indices( - sampling_config.verbose and len(self._datasets) <= 20 - ) + dataset_index, sample_index = self._build_blending_indices() sampling_config.cache_directory.mkdir(exist_ok=True, parents=True) np.save(self._dataset_idx_filename, dataset_index) np.save(self._sample_idx_filename, sample_index) - safe_barrier(group, self._name) - self._load_mappings(sampling_config.verbose) - def __getstate__(self): return ( self._datasets, @@ -92,12 +84,13 @@ def __setstate__(self, state): ) = state if isinstance(dataset_index, pathlib.Path): self._dataset_idx_filename, self._sample_idx_filename = dataset_index, sample_index - self._load_mappings(False) else: self._dataset_idx_filename, self._sample_idx_filename = None, None self._dataset_index, self._sample_index = dataset_index, sample_index - def _load_mappings(self, verbose): + def _load_mappings(self, verbose: 
bool = False): + if hasattr(self, "_dataset_index"): + return if verbose: log_main_rank(lambda: f" > loading blending dataset index mapping from {self._dataset_idx_filename}") self._dataset_index = np.load(self._dataset_idx_filename, mmap_mode="r") @@ -108,7 +101,7 @@ def _load_mappings(self, verbose): def __len__(self): return self._num_samples - def _build_blending_indices(self, verbose: bool): + def _build_blending_indices(self): assert _extension_available, ( "The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly." ) @@ -121,7 +114,8 @@ def _build_blending_indices(self, verbose: bool): self._weights, len(self._datasets), self._num_samples, - verbose, + # TODO: Verbose option? + True, # verbose ) available_samples_per_dataset = np.array([len(dataset) for dataset in self._datasets]) sampled_per_dataset = np.bincount(dataset_index) @@ -141,6 +135,7 @@ def _build_blending_indices(self, verbose: bool): return dataset_index, dataset_sample_index def __getitem__(self, idx): + self._load_mappings() start_time = time.perf_counter() dataset_index = self._dataset_index[idx] dataset = self._datasets[dataset_index] diff --git a/fast_llm/data/dataset/gpt/abstract.py b/fast_llm/data/dataset/gpt/abstract.py deleted file mode 100644 index 40f1532a..00000000 --- a/fast_llm/data/dataset/gpt/abstract.py +++ /dev/null @@ -1,58 +0,0 @@ -import abc -import typing - -import numpy as np - -from fast_llm.data.dataset.abstract import SamplableDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig - -if typing.TYPE_CHECKING: - from fast_llm.data.data.gpt.data import GPTData - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - -try: - from fast_llm.csrc.data import build_sample_idx # noqa - - _extension_available = True -except ImportError: - _extension_available = False - - -class GPTIndexedDataset(SamplableDataset): - """ - A GPT dataset containing a list of unsampled, unprocessed samples. - TODO: Move sampling responsibility here? - """ - - def get(self, document: int, offset: int = 0, length: int | None = None): - pass - - @property - def num_documents(self) -> int: - """ - Number of documents in the dataset. - Can be calculated from document sizes but may be overridden if there is a better method. - """ - return len(self.get_document_sizes()) - - @property - def num_tokens(self) -> int: - """ - Number of tokens in the dataset. - Can be calculated from document sizes but may be overridden if there is a better method. - """ - return self.get_document_sizes().sum() - - @abc.abstractmethod - def get_document_sizes(self) -> np.ndarray: - """ - The size of each document in the dataset. - The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. 
- """ - - def sample(self, config: GPTSamplingConfig, data: "GPTData") -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - return GPTSampledIndexedDataset(self, config, data) diff --git a/fast_llm/data/dataset/gpt/concatenated.py b/fast_llm/data/dataset/gpt/concatenated.py deleted file mode 100644 index 2b133dba..00000000 --- a/fast_llm/data/dataset/gpt/concatenated.py +++ /dev/null @@ -1,42 +0,0 @@ -import numpy as np - -from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset -from fast_llm.utils import padded_cumsum - - -class GPTConcatenatedDataset(GPTIndexedDataset): - - def __init__( - self, - name: str, - datasets: list[GPTIndexedDataset], - ): - self._name = name - self._datasets = datasets - sizes = [dataset.num_documents for dataset in self._datasets] - self._dataset_splits = padded_cumsum(sizes) - self._num_documents = sum(sizes) - - @property - def num_tokens(self) -> int: - return sum(dataset.num_tokens for dataset in self._datasets) - - def num_documents(self) -> int: - return sum(dataset.num_documents for dataset in self._datasets) - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. - return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) - - def get(self, document: int, offset: int = 0, length: int | None = None): - """ - Get the sample (document) with the given index (in the dataset slice), - optionally sub-sampled to a specific offset (starting point) and maximum length - (end = min(offset + length, sample_length). - """ - dataset = np.searchsorted(self._dataset_splits[1:], document, side="right") - return self._datasets[dataset].get(document - self._dataset_splits[dataset], offset, length) - - @property - def name(self) -> str: - return self._name diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index d475f74d..48c704c8 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -20,9 +20,8 @@ if typing.TYPE_CHECKING: from fast_llm.data.data.gpt.data import GPTData - from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset - from fast_llm.data.dataset.gpt.concatenated import GPTConcatenatedDataset from fast_llm.data.dataset.gpt.dummy import GPTDummySampledDataset + from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset @@ -200,8 +199,6 @@ def build_split( data: "GPTData", default_phase: PhaseType = PhaseType.training, ) -> SamplableSplitDataset: - from fast_llm.data.dataset.gpt.slice import GPTDatasetSlice - return GPTDatasetSlice.from_splits(self.dataset.build(data), self.ratios) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py new file mode 100644 index 00000000..bf1fface --- /dev/null +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -0,0 +1,60 @@ +import abc +import typing + +import numpy as np + +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.indexed import ConcatenatedIndexedDataset, IndexedDataset, IndexedDatasetSlice + +if typing.TYPE_CHECKING: + from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset + + +class GPTIndexedDataset(IndexedDataset): + """ + A GPT dataset containing a list of samples. 
+ """ + + # def get(self, index: int, offset: int = 0, length: int | None = None): + # pass + + # def __len__(self) -> int: + # """ + # Number of documents in the dataset. + # Can be calculated from document sizes but may be overridden if there is a better method. + # """ + # return len(self.get_document_sizes()) + + @abc.abstractmethod + def get_document_sizes(self) -> np.ndarray: + """ + The size of each document in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + + def sample(self, config: GPTSamplingConfig, data: GPTData) -> "GPTSampledIndexedDataset": + from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset + + return GPTSampledIndexedDataset(self, config, data) + + +class GPTDatasetSlice(IndexedDatasetSlice, GPTIndexedDataset): + """ + A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. + """ + + _dataset: GPTIndexedDataset + + def get_document_sizes(self): + # TODO: This can be really big. + return self._dataset.get_document_sizes()[self._begin : self._end] + + +class GPTConcatenatedDataset(ConcatenatedIndexedDataset, GPTIndexedDataset): + _datasets: list[GPTIndexedDataset] + + def get_document_sizes(self) -> np.ndarray: + # TODO: This can be really big. + return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 30184415..c1b0c562 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -4,7 +4,7 @@ import numpy as np -from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset +from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div @@ -76,7 +76,7 @@ def name(self): return self._name @property - def num_documents(self) -> int: + def __len__(self) -> int: return self._num_documents @property diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 785c7aac..fcbd5066 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -1,16 +1,17 @@ +import logging import math import numpy as np -from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.data.dataset.gpt.fim.fim import Fim +from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import MAX_SEED -from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) try: from fast_llm.csrc.data import build_sample_idx # noqa @@ -38,6 +39,7 @@ def __init__( assert isinstance(data, GPTData) self._indexed_dataset = indexed_dataset self._sampling_config = sampling_config + self._shuffle_epochs = data.config.shuffle_epochs if data.config.fim.rate > 0: assert data.tokenizer is not None @@ -54,79 +56,109 @@ def __init__( group = data.distributed.world_group # Build the indexed mapping if it doesn't exist. 
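+        # The doc/sample/shuffle index arrays are cached as `.npy` files: one rank builds and saves them,
+        # and every rank later memory-maps them lazily on first use.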
# TODO: This only works if the dataset location is accessible by all job. - if (group is None or group.rank() == 0) and not ( + + rank, verbose = data.get_next_sampling_rank_and_verbose() + if (group is None or group.rank() == rank) and not ( self._doc_idx_filename.is_file() and self._sample_idx_filename.is_file() and self._shuffle_idx_filename.is_file() ): - if self._sampling_config.verbose: - log_main_rank(" > Building the index map on rank 0 ...") + if verbose: + logger.info(f" > Building the index map on rank {rank} ...") doc_idx, sample_idx, shuffle_idx = self._sample() self._sampling_config.cache_directory.mkdir(parents=True, exist_ok=True) np.save(self._doc_idx_filename, doc_idx) np.save(self._sample_idx_filename, sample_idx) np.save(self._shuffle_idx_filename, shuffle_idx) - safe_barrier(group, self._indexed_dataset.name) - self._load_mappings(self._sampling_config.verbose) - - def _sample(self): + def _sample(self, verbose: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Create a `GPTSampledDataset` with the requested parameters. """ document_sizes = self._indexed_dataset.get_document_sizes() - num_documents = len(document_sizes) - num_tokens = document_sizes.sum() + documents_per_epoch = len(document_sizes) + tokens_per_epoch = document_sizes.sum() np_rng = np.random.RandomState(seed=self._sampling_config.seed) - num_epochs = math.ceil( - (self._sampling_config.sequence_length * self._sampling_config.num_samples + 1) / num_tokens - ) - # For the last epoch, decide whether include the entire epoch - # in the global shuffle or not. - # Get the number of samples for the last epoch - main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // self._sampling_config.sequence_length - last_epoch_samples = self._sampling_config.num_samples - main_epochs_samples - samples_per_epoch = (num_tokens - 1) // self._sampling_config.sequence_length - # If we have less than 80% of the samples for the last epoch, separate out the epoch and treat it differently. - # Note: the 80% number is just based on common sense and can be adjusted if needed. + total_tokens = self._sampling_config.sequence_length * self._sampling_config.num_samples + num_epochs = math.ceil((total_tokens + 1) / tokens_per_epoch) + epoch_begins_in_sample_index = ( + np.arange(num_epochs) * tokens_per_epoch - 1 + ) // self._sampling_config.sequence_length + + # Treat the last epoch differently if we use less than 80% of it. + # Necessary to match the behavior of Megatron-LM. + last_epoch_samples = self._sampling_config.num_samples - epoch_begins_in_sample_index[-1] + samples_per_epoch = (tokens_per_epoch - 1) // self._sampling_config.sequence_length separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch - doc_idx = np.tile(np.arange(num_documents, dtype=np.int32), num_epochs) - if separate_last_epoch: - np_rng.shuffle(doc_idx[:-num_documents]) - np_rng.shuffle(doc_idx[-num_documents:]) + # Shuffle documents. + doc_idx = np.tile(np.arange(documents_per_epoch, dtype=np.int32), num_epochs) + if self._shuffle_epochs: + if separate_last_epoch: + np_rng.shuffle(doc_idx[:-documents_per_epoch]) + np_rng.shuffle(doc_idx[-documents_per_epoch:]) + else: + np_rng.shuffle(doc_idx) else: - np_rng.shuffle(doc_idx) + for epoch in range(num_epochs): + # Reseed each epoch to make sampling reproducible with a different number of epochs. 
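+                    # (The large constant offsets below are arbitrary; they only give each epoch its own seed.)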
+                    np.random.RandomState(seed=self._sampling_config.seed + 738741 * epoch + 90823).shuffle(
+                        doc_idx[epoch * documents_per_epoch : (epoch + 1) * documents_per_epoch]
+                    )
 
-        assert _extension_available, (
-            "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly."
-        )
+        if self._fast_sampling:
+            # TODO: Crop token_cumsum and doc_idx to num samples
+            token_cumsum = document_sizes[doc_idx].cumsum(dtype=np.int64)
+            # TODO: Verify.
+            # Trim the unused part of the last epoch.
+            total_documents = np.searchsorted(token_cumsum, total_tokens)
+            token_cumsum = token_cumsum[:total_documents]
+            doc_idx = doc_idx[:total_documents]
+
+            # Shuffle samples.
+            sample_idx = np.arange(self._sampling_config.num_samples)
+            if self._shuffle_epochs:
+                if separate_last_epoch:
+                    np_rng.shuffle(sample_idx[: epoch_begins_in_sample_index[-1]])
+                    np_rng.shuffle(sample_idx[epoch_begins_in_sample_index[-1] :])
+                else:
+                    np_rng.shuffle(sample_idx)
+            else:
+                for epoch in range(num_epochs):
+                    # Shuffle samples within an epoch, excluding the first one which may span two epochs.
+                    # TODO: Include the first one if it's entirely in the epoch.
+                    # Reseed each epoch to make sampling reproducible with a different number of epochs.
+                    np.random.RandomState(seed=self._sampling_config.seed + 36478921 * epoch + 587469).shuffle(
+                        sample_idx[epoch_begins_in_sample_index[epoch] + 1 : epoch_begins_in_sample_index[epoch + 1]]
+                    )
+            return doc_idx, sample_idx, token_cumsum
+
+        assert (
+            _extension_available
+        ), "The C++ extension for dataset sampling is missing. Please make sure Fast-LLM is installed correctly."
+        if verbose:
+            logger.info(f" > Building sample index for {self._sampling_config.num_samples} samples ...")
         sample_idx = build_sample_idx(
             document_sizes,
             doc_idx,
             self._sampling_config.sequence_length,
-            num_epochs,
-            num_tokens,
-            self._sampling_config.verbose,
+            self._sampling_config.num_samples,
         )
-        # shuffle-idx.
         # -1 is due to data structure used to retrieve the index:
         # sample i --> [sample_idx[i], sample_idx[i+1])
         total_size = sample_idx.shape[0] - 1
-        # TODO: Isn't the dataset already shuffled above?
         shuffle_idx = np.arange(
             0, total_size, dtype=np.int64 if total_size >= (np.iinfo(np.uint32).max - 1) else np.uint32
         )
         if separate_last_epoch:
-            np_rng.shuffle(shuffle_idx[:main_epochs_samples])
-            np_rng.shuffle(shuffle_idx[main_epochs_samples:])
+            np_rng.shuffle(shuffle_idx[: epoch_begins_in_sample_index[-1]])
+            np_rng.shuffle(shuffle_idx[epoch_begins_in_sample_index[-1] :])
         else:
             np_rng.shuffle(shuffle_idx)
-        Assert.geq(len(shuffle_idx), self._sampling_config.num_samples)
         # TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch.
         return doc_idx, sample_idx, shuffle_idx[: self._sampling_config.num_samples]
@@ -150,9 +182,10 @@ def __setstate__(self, state):
             self._shuffle_idx_filename,
         ) = state
         self._sampling_config = GPTSamplingConfig.from_dict(sampling_config)
-        self._load_mappings(False)
 
-    def _load_mappings(self, verbose):
+    def _load_mappings(self, verbose=False):
+        if hasattr(self, "_doc_idx"):
+            return
         if verbose:
             log_main_rank(lambda: f" > loading doc-idx mapping from {self._doc_idx_filename}")
         self._doc_idx = np.load(self._doc_idx_filename, mmap_mode="r")
@@ -176,19 +209,47 @@ def __getitem__(self, idx):
         with the requested sampling index.
         The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
         """
-        # Get the shuffled index.
- shuffled_idx = self._shuffle_idx[idx] - # Start and end documents and offsets. - doc_f, offset_f = self._sample_idx[shuffled_idx] - doc_l, offset_l = self._sample_idx[shuffled_idx + 1] - sample_list = [ - self._indexed_dataset.get( - self._doc_idx[doc], - offset=(doc == doc_f) * offset_f, - length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, - ) - for doc in range(doc_f, doc_l + 1) - ] + # Lazy load mappings + self._load_mappings() + + if self._fast_sampling: + token_cumsum = self._shuffle_idx + token_start_index = idx * self._sampling_config.sequence_length + doc_begin = np.searchsorted(token_cumsum, token_start_index) + + sample_list = [] + current_doc = doc_begin + remaining_tokens = self._sampling_config.sequence_length + 1 + while remaining_tokens > 0: + offset = token_start_index - token_cumsum[current_doc] if current_doc == doc_begin else 0 + # TODO: Boundary + document_size = token_cumsum[current_doc] - token_cumsum[current_doc - 1] + length = min(document_size - offset, remaining_tokens) + sample_list.append( + self._indexed_dataset.get( + self._doc_idx[current_doc], + offset=offset, + length=length, + ) + ) + remaining_tokens -= length + current_doc += 1 + + else: + # Get the shuffled index. + shuffled_idx = self._shuffle_idx[idx] + # Start and end documents and offsets. + doc_f, offset_f = self._sample_idx[shuffled_idx] + doc_l, offset_l = self._sample_idx[shuffled_idx + 1] + sample_list = [ + self._indexed_dataset.get( + self._doc_idx[doc], + offset=(doc == doc_f) * offset_f, + length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, + ) + for doc in range(doc_f, doc_l + 1) + ] + sample = np.concatenate( sample_list, dtype=np.int64, diff --git a/fast_llm/data/dataset/gpt/slice.py b/fast_llm/data/dataset/gpt/slice.py deleted file mode 100644 index 87684e01..00000000 --- a/fast_llm/data/dataset/gpt/slice.py +++ /dev/null @@ -1,72 +0,0 @@ -from fast_llm.data.dataset.abstract import SamplableSplitDataset -from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum - - -class GPTDatasetSlice(GPTIndexedDataset): - """ - A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. - """ - - def __init__( - self, - name: str, - dataset: GPTIndexedDataset, - begin: int | None = None, - end: int | None = None, - ): - self._name = name - self._dataset = dataset - self._begin = 0 if begin is None else begin - dataset_documents = dataset.num_documents - self._end = dataset_documents if end is None else end - - # Checks - try: - Assert.geq(self._begin, 0) - Assert.in_range_incl(self._end, self._begin + 1, dataset_documents) - except Exception as e: - raise AssertionError(f"Invalid document indices for dataset {name} with length {dataset_documents}") from e - - def __getitem__(self, index: int): - """ - Get the sample (document) with the given index (in the split dataset). - """ - return self.get(index) - - def get(self, document: int, offset: int = 0, length: int | None = None): - """ - Get the sample (document) with the given index (in the dataset slice), - optionally sub-sampled to a specific offset (starting point) and maximum length - (end = min(offset + length, sample_length). 
- """ - return self._dataset.get(document + self._begin, offset, length) - - @property - def num_documents(self): - return self._end - self._begin - - def get_document_sizes(self): - # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] - - @property - def name(self): - return self._name - - @classmethod - def from_splits(cls, dataset: GPTIndexedDataset, phase_split: dict[PhaseType, float]): - """ - Create a set of GPT datasets from a MMapIndexedDataset, - each containing approximately the requested proportion of the total tokens. - """ - probabilities = normalize_probabilities(list(phase_split.values())) - splits = [round(x) for x in padded_cumsum(probabilities) * dataset.num_documents] - return SamplableSplitDataset[GPTDatasetSlice]( - f"{dataset.name}_split", - { - phase: GPTDatasetSlice(f"{dataset.name}_{phase.value}", dataset, split_begin, split_end) - for phase, split_begin, split_end in zip(phase_split, splits[:-1], splits[1:]) - }, - ) diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py new file mode 100644 index 00000000..9ec73663 --- /dev/null +++ b/fast_llm/data/dataset/indexed.py @@ -0,0 +1,103 @@ +import abc + +import numpy as np + +from fast_llm.data.data.abstract import Data +from fast_llm.data.data.config import SamplingConfig +from fast_llm.data.dataset.abstract import SamplableDataset, SamplableSplitDataset, SampledDataset +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum + + +class IndexedDataset(SamplableDataset): + """ + A dataset containing a list of samples. + TODO: Move sampling responsibility here? + """ + + @abc.abstractmethod + def get(self, index: int, *args, **kwargs): + pass + + @abc.abstractmethod + def __len__(self) -> int: + """ + Number of samples in the dataset. + """ + + @abc.abstractmethod + def sample(self, config: SamplingConfig, data: Data) -> SampledDataset: + pass + + +class IndexedDatasetSlice(IndexedDataset): + + def __init__( + self, + name: str, + dataset: IndexedDataset, + begin: int | None = None, + end: int | None = None, + ): + self._name = name + self._dataset = dataset + self._begin = 0 if begin is None else begin + num_samples = len(dataset) + self._end = num_samples if end is None else end + + # Checks + try: + Assert.geq(self._begin, 0) + Assert.in_range_incl(self._end, self._begin + 1, num_samples) + except Exception as e: + raise AssertionError(f"Invalid document indices for dataset {name} with length {num_samples}") from e + + def __getitem__(self, index: int): + """ + Get the sample (document) with the given index (in the split dataset). + """ + return self.get(index) + + @property + def __len__(self): + return self._end - self._begin + + @classmethod + def from_splits(cls, dataset: IndexedDataset, phase_split: dict[PhaseType, float]): + """ + Create a set of GPT datasets from a MMapIndexedDataset, + each containing approximately the requested proportion of the total tokens. 
+ """ + probabilities = normalize_probabilities(list(phase_split.values())) + splits = [round(x) for x in padded_cumsum(probabilities) * len(dataset)] + return SamplableSplitDataset[cls]( + f"{dataset.name}_split", + { + phase: cls(f"{dataset.name}_{phase.value}", dataset, split_begin, split_end) + for phase, split_begin, split_end in zip(phase_split, splits[:-1], splits[1:]) + }, + ) + + +class ConcatenatedIndexedDataset(IndexedDataset): + + def __init__( + self, + name: str, + datasets: list[IndexedDataset], + ): + self._name = name + self._datasets = datasets + sizes = [len(dataset) for dataset in self._datasets] + self._dataset_splits = padded_cumsum(sizes) + + def __len__(self) -> int: + return self._dataset_splits[-1] + + def get(self, index: int, *args, **kwargs): + dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get(index - self._dataset_splits[dataset], *args, **kwargs) + + @property + def name(self) -> str: + return self._name diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py index 48e7041e..3c595a4d 100644 --- a/tools/concatenate_dataset.py +++ b/tools/concatenate_dataset.py @@ -34,7 +34,7 @@ def run(self): dataset = GPTMemmapDataset("dataset", prefix) dataset_dict = { "prefix": str(prefix.relative_to(self.directory)), - "num_documents": dataset.num_documents, + "num_documents": dataset.__len__, "num_tokens": dataset.num_tokens, } if self.min_tokens is not None and dataset_dict["num_tokens"] < self.min_tokens: From 3992df7ffc53f1f8282e649f1f447e9ebbdf4117 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Jan 2025 13:23:22 -0500 Subject: [PATCH 05/19] fix --- tools/concatenate_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py index 3c595a4d..bbfa4b21 100644 --- a/tools/concatenate_dataset.py +++ b/tools/concatenate_dataset.py @@ -34,7 +34,7 @@ def run(self): dataset = GPTMemmapDataset("dataset", prefix) dataset_dict = { "prefix": str(prefix.relative_to(self.directory)), - "num_documents": dataset.__len__, + "num_documents": len(dataset), "num_tokens": dataset.num_tokens, } if self.min_tokens is not None and dataset_dict["num_tokens"] < self.min_tokens: From 82285aeb229dc363115b4bce0635d9734b1a56d5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Jan 2025 13:38:15 -0500 Subject: [PATCH 06/19] Generalize indexed --- fast_llm/data/dataset/gpt/abstract.py | 58 ------------ fast_llm/data/dataset/gpt/concatenated.py | 42 --------- fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/dataset/gpt/indexed.py | 60 +++++++++++++ fast_llm/data/dataset/gpt/memmap.py | 4 +- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/data/dataset/gpt/slice.py | 72 --------------- fast_llm/data/dataset/indexed.py | 103 ++++++++++++++++++++++ tools/concatenate_dataset.py | 2 +- 9 files changed, 168 insertions(+), 180 deletions(-) delete mode 100644 fast_llm/data/dataset/gpt/abstract.py delete mode 100644 fast_llm/data/dataset/gpt/concatenated.py create mode 100644 fast_llm/data/dataset/gpt/indexed.py delete mode 100644 fast_llm/data/dataset/gpt/slice.py create mode 100644 fast_llm/data/dataset/indexed.py diff --git a/fast_llm/data/dataset/gpt/abstract.py b/fast_llm/data/dataset/gpt/abstract.py deleted file mode 100644 index 40f1532a..00000000 --- a/fast_llm/data/dataset/gpt/abstract.py +++ /dev/null @@ -1,58 +0,0 @@ -import abc -import typing - -import numpy as np - -from fast_llm.data.dataset.abstract import 
SamplableDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig - -if typing.TYPE_CHECKING: - from fast_llm.data.data.gpt.data import GPTData - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - -try: - from fast_llm.csrc.data import build_sample_idx # noqa - - _extension_available = True -except ImportError: - _extension_available = False - - -class GPTIndexedDataset(SamplableDataset): - """ - A GPT dataset containing a list of unsampled, unprocessed samples. - TODO: Move sampling responsibility here? - """ - - def get(self, document: int, offset: int = 0, length: int | None = None): - pass - - @property - def num_documents(self) -> int: - """ - Number of documents in the dataset. - Can be calculated from document sizes but may be overridden if there is a better method. - """ - return len(self.get_document_sizes()) - - @property - def num_tokens(self) -> int: - """ - Number of tokens in the dataset. - Can be calculated from document sizes but may be overridden if there is a better method. - """ - return self.get_document_sizes().sum() - - @abc.abstractmethod - def get_document_sizes(self) -> np.ndarray: - """ - The size of each document in the dataset. - The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. - """ - - def sample(self, config: GPTSamplingConfig, data: "GPTData") -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - return GPTSampledIndexedDataset(self, config, data) diff --git a/fast_llm/data/dataset/gpt/concatenated.py b/fast_llm/data/dataset/gpt/concatenated.py deleted file mode 100644 index 2b133dba..00000000 --- a/fast_llm/data/dataset/gpt/concatenated.py +++ /dev/null @@ -1,42 +0,0 @@ -import numpy as np - -from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset -from fast_llm.utils import padded_cumsum - - -class GPTConcatenatedDataset(GPTIndexedDataset): - - def __init__( - self, - name: str, - datasets: list[GPTIndexedDataset], - ): - self._name = name - self._datasets = datasets - sizes = [dataset.num_documents for dataset in self._datasets] - self._dataset_splits = padded_cumsum(sizes) - self._num_documents = sum(sizes) - - @property - def num_tokens(self) -> int: - return sum(dataset.num_tokens for dataset in self._datasets) - - def num_documents(self) -> int: - return sum(dataset.num_documents for dataset in self._datasets) - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. - return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) - - def get(self, document: int, offset: int = 0, length: int | None = None): - """ - Get the sample (document) with the given index (in the dataset slice), - optionally sub-sampled to a specific offset (starting point) and maximum length - (end = min(offset + length, sample_length). 
- """ - dataset = np.searchsorted(self._dataset_splits[1:], document, side="right") - return self._datasets[dataset].get(document - self._dataset_splits[dataset], offset, length) - - @property - def name(self) -> str: - return self._name diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index d475f74d..48c704c8 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -20,9 +20,8 @@ if typing.TYPE_CHECKING: from fast_llm.data.data.gpt.data import GPTData - from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset - from fast_llm.data.dataset.gpt.concatenated import GPTConcatenatedDataset from fast_llm.data.dataset.gpt.dummy import GPTDummySampledDataset + from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset @@ -200,8 +199,6 @@ def build_split( data: "GPTData", default_phase: PhaseType = PhaseType.training, ) -> SamplableSplitDataset: - from fast_llm.data.dataset.gpt.slice import GPTDatasetSlice - return GPTDatasetSlice.from_splits(self.dataset.build(data), self.ratios) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py new file mode 100644 index 00000000..bf1fface --- /dev/null +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -0,0 +1,60 @@ +import abc +import typing + +import numpy as np + +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.indexed import ConcatenatedIndexedDataset, IndexedDataset, IndexedDatasetSlice + +if typing.TYPE_CHECKING: + from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset + + +class GPTIndexedDataset(IndexedDataset): + """ + A GPT dataset containing a list of samples. + """ + + # def get(self, index: int, offset: int = 0, length: int | None = None): + # pass + + # def __len__(self) -> int: + # """ + # Number of documents in the dataset. + # Can be calculated from document sizes but may be overridden if there is a better method. + # """ + # return len(self.get_document_sizes()) + + @abc.abstractmethod + def get_document_sizes(self) -> np.ndarray: + """ + The size of each document in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + + def sample(self, config: GPTSamplingConfig, data: GPTData) -> "GPTSampledIndexedDataset": + from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset + + return GPTSampledIndexedDataset(self, config, data) + + +class GPTDatasetSlice(IndexedDatasetSlice, GPTIndexedDataset): + """ + A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. + """ + + _dataset: GPTIndexedDataset + + def get_document_sizes(self): + # TODO: This can be really big. + return self._dataset.get_document_sizes()[self._begin : self._end] + + +class GPTConcatenatedDataset(ConcatenatedIndexedDataset, GPTIndexedDataset): + _datasets: list[GPTIndexedDataset] + + def get_document_sizes(self) -> np.ndarray: + # TODO: This can be really big. 
+ return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 30184415..c1b0c562 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -4,7 +4,7 @@ import numpy as np -from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset +from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div @@ -76,7 +76,7 @@ def name(self): return self._name @property - def num_documents(self) -> int: + def __len__(self) -> int: return self._num_documents @property diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 785c7aac..c47923d3 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -5,9 +5,9 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.data.dataset.gpt.fim.fim import Fim +from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import MAX_SEED from fast_llm.utils import Assert diff --git a/fast_llm/data/dataset/gpt/slice.py b/fast_llm/data/dataset/gpt/slice.py deleted file mode 100644 index 87684e01..00000000 --- a/fast_llm/data/dataset/gpt/slice.py +++ /dev/null @@ -1,72 +0,0 @@ -from fast_llm.data.dataset.abstract import SamplableSplitDataset -from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum - - -class GPTDatasetSlice(GPTIndexedDataset): - """ - A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. - """ - - def __init__( - self, - name: str, - dataset: GPTIndexedDataset, - begin: int | None = None, - end: int | None = None, - ): - self._name = name - self._dataset = dataset - self._begin = 0 if begin is None else begin - dataset_documents = dataset.num_documents - self._end = dataset_documents if end is None else end - - # Checks - try: - Assert.geq(self._begin, 0) - Assert.in_range_incl(self._end, self._begin + 1, dataset_documents) - except Exception as e: - raise AssertionError(f"Invalid document indices for dataset {name} with length {dataset_documents}") from e - - def __getitem__(self, index: int): - """ - Get the sample (document) with the given index (in the split dataset). - """ - return self.get(index) - - def get(self, document: int, offset: int = 0, length: int | None = None): - """ - Get the sample (document) with the given index (in the dataset slice), - optionally sub-sampled to a specific offset (starting point) and maximum length - (end = min(offset + length, sample_length). - """ - return self._dataset.get(document + self._begin, offset, length) - - @property - def num_documents(self): - return self._end - self._begin - - def get_document_sizes(self): - # TODO: This can be really big. 
- return self._dataset.get_document_sizes()[self._begin : self._end] - - @property - def name(self): - return self._name - - @classmethod - def from_splits(cls, dataset: GPTIndexedDataset, phase_split: dict[PhaseType, float]): - """ - Create a set of GPT datasets from a MMapIndexedDataset, - each containing approximately the requested proportion of the total tokens. - """ - probabilities = normalize_probabilities(list(phase_split.values())) - splits = [round(x) for x in padded_cumsum(probabilities) * dataset.num_documents] - return SamplableSplitDataset[GPTDatasetSlice]( - f"{dataset.name}_split", - { - phase: GPTDatasetSlice(f"{dataset.name}_{phase.value}", dataset, split_begin, split_end) - for phase, split_begin, split_end in zip(phase_split, splits[:-1], splits[1:]) - }, - ) diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py new file mode 100644 index 00000000..9ec73663 --- /dev/null +++ b/fast_llm/data/dataset/indexed.py @@ -0,0 +1,103 @@ +import abc + +import numpy as np + +from fast_llm.data.data.abstract import Data +from fast_llm.data.data.config import SamplingConfig +from fast_llm.data.dataset.abstract import SamplableDataset, SamplableSplitDataset, SampledDataset +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum + + +class IndexedDataset(SamplableDataset): + """ + A dataset containing a list of samples. + TODO: Move sampling responsibility here? + """ + + @abc.abstractmethod + def get(self, index: int, *args, **kwargs): + pass + + @abc.abstractmethod + def __len__(self) -> int: + """ + Number of samples in the dataset. + """ + + @abc.abstractmethod + def sample(self, config: SamplingConfig, data: Data) -> SampledDataset: + pass + + +class IndexedDatasetSlice(IndexedDataset): + + def __init__( + self, + name: str, + dataset: IndexedDataset, + begin: int | None = None, + end: int | None = None, + ): + self._name = name + self._dataset = dataset + self._begin = 0 if begin is None else begin + num_samples = len(dataset) + self._end = num_samples if end is None else end + + # Checks + try: + Assert.geq(self._begin, 0) + Assert.in_range_incl(self._end, self._begin + 1, num_samples) + except Exception as e: + raise AssertionError(f"Invalid document indices for dataset {name} with length {num_samples}") from e + + def __getitem__(self, index: int): + """ + Get the sample (document) with the given index (in the split dataset). + """ + return self.get(index) + + @property + def __len__(self): + return self._end - self._begin + + @classmethod + def from_splits(cls, dataset: IndexedDataset, phase_split: dict[PhaseType, float]): + """ + Create a set of GPT datasets from a MMapIndexedDataset, + each containing approximately the requested proportion of the total tokens. 
+ """ + probabilities = normalize_probabilities(list(phase_split.values())) + splits = [round(x) for x in padded_cumsum(probabilities) * len(dataset)] + return SamplableSplitDataset[cls]( + f"{dataset.name}_split", + { + phase: cls(f"{dataset.name}_{phase.value}", dataset, split_begin, split_end) + for phase, split_begin, split_end in zip(phase_split, splits[:-1], splits[1:]) + }, + ) + + +class ConcatenatedIndexedDataset(IndexedDataset): + + def __init__( + self, + name: str, + datasets: list[IndexedDataset], + ): + self._name = name + self._datasets = datasets + sizes = [len(dataset) for dataset in self._datasets] + self._dataset_splits = padded_cumsum(sizes) + + def __len__(self) -> int: + return self._dataset_splits[-1] + + def get(self, index: int, *args, **kwargs): + dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get(index - self._dataset_splits[dataset], *args, **kwargs) + + @property + def name(self) -> str: + return self._name diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py index 48e7041e..bbfa4b21 100644 --- a/tools/concatenate_dataset.py +++ b/tools/concatenate_dataset.py @@ -34,7 +34,7 @@ def run(self): dataset = GPTMemmapDataset("dataset", prefix) dataset_dict = { "prefix": str(prefix.relative_to(self.directory)), - "num_documents": dataset.num_documents, + "num_documents": len(dataset), "num_tokens": dataset.num_tokens, } if self.min_tokens is not None and dataset_dict["num_tokens"] < self.min_tokens: From 7011ca3b64f1c078248f6905574ae976d5d9c8b2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 9 Jan 2025 15:32:14 -0500 Subject: [PATCH 07/19] fix --- fast_llm/data/dataset/gpt/config.py | 6 ++++-- fast_llm/data/dataset/gpt/memmap.py | 1 - fast_llm/data/dataset/indexed.py | 13 +++++++++---- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 48c704c8..e1c3144c 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -168,7 +168,7 @@ class GPTConcatenatedDatasetConfig(GPTDatasetConfig, SamplableDatasetConfig, typ ) def build(self, data: "GPTData") -> "GPTConcatenatedDataset": - from fast_llm.data.dataset.gpt.concatenated import GPTConcatenatedDataset + from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset return GPTConcatenatedDataset(self.name, [dataset.build(data) for dataset in self.datasets]) @@ -198,7 +198,9 @@ def build_split( self, data: "GPTData", default_phase: PhaseType = PhaseType.training, - ) -> SamplableSplitDataset: + ) -> "SamplableSplitDataset[GPTDatasetSlice]": + from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice + return GPTDatasetSlice.from_splits(self.dataset.build(data), self.ratios) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c1b0c562..0ee0aa50 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -75,7 +75,6 @@ def get(self, idx, offset=0, length=None): def name(self): return self._name - @property def __len__(self) -> int: return self._num_documents diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 9ec73663..eea22877 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -52,16 +52,21 @@ def __init__( except Exception as e: raise AssertionError(f"Invalid document indices for dataset {name} with length {num_samples}") from e - def __getitem__(self, index: 
int): + def get(self, document: int, offset: int = 0, length: int | None = None): """ - Get the sample (document) with the given index (in the split dataset). + Get the sample (document) with the given index (in the dataset slice), + optionally sub-sampled to a specific offset (starting point) and maximum length + (end = min(offset + length, sample_length). """ - return self.get(index) + return self._dataset.get(document + self._begin, offset, length) - @property def __len__(self): return self._end - self._begin + @property + def name(self): + return self._name + @classmethod def from_splits(cls, dataset: IndexedDataset, phase_split: dict[PhaseType, float]): """ From 95747150db7d560b5c29e857b616d5ecd7635df7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 10 Jan 2025 13:18:18 -0500 Subject: [PATCH 08/19] Modularize fim, decouple data from dataset, basic tests, misc --- fast_llm/data/data/abstract.py | 10 +- fast_llm/data/data/config.py | 13 +- fast_llm/data/data/gpt/config.py | 6 - fast_llm/data/data/gpt/data.py | 36 ++++-- fast_llm/data/dataset/abstract.py | 12 +- fast_llm/data/dataset/blended.py | 27 +--- fast_llm/data/dataset/config.py | 44 +++---- fast_llm/data/dataset/gpt/config.py | 144 ++++++++++++++++++--- fast_llm/data/dataset/gpt/dummy.py | 7 +- fast_llm/data/dataset/gpt/{fim => }/fim.py | 40 +++++- fast_llm/data/dataset/gpt/fim/__init__.py | 0 fast_llm/data/dataset/gpt/fim/config.py | 58 --------- fast_llm/data/dataset/gpt/indexed.py | 5 +- fast_llm/data/dataset/gpt/sampled.py | 63 ++++----- fast_llm/data/dataset/indexed.py | 8 +- fast_llm/data/dataset/monitor.py | 52 ++++++++ fast_llm/engine/training/trainer.py | 6 +- fast_llm/models/custom/data.py | 13 +- tests/common.py | 31 +++-- tests/test_dataset.py | 61 +++++++++ 20 files changed, 396 insertions(+), 240 deletions(-) rename fast_llm/data/dataset/gpt/{fim => }/fim.py (87%) delete mode 100644 fast_llm/data/dataset/gpt/fim/__init__.py delete mode 100644 fast_llm/data/dataset/gpt/fim/config.py create mode 100644 fast_llm/data/dataset/monitor.py create mode 100644 tests/test_dataset.py diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index a2e419d5..1866efbb 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -1,4 +1,5 @@ import abc +import pathlib import typing from fast_llm.data.data.config import DataConfig @@ -12,15 +13,22 @@ class Data(abc.ABC): _distributed: "Distributed" _samples_per_phase: dict[PhaseType, int] + _cache_directory: pathlib.Path | None def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None: self._config = config self._distributed_config = distributed_config # TODO: Improve interface - def setup(self, distributed: "Distributed", samples_per_phase: dict[PhaseType, int]): + def setup( + self, + distributed: "Distributed", + samples_per_phase: dict[PhaseType, int], + cache_directory: pathlib.Path, + ): self._distributed = distributed self._samples_per_phase = samples_per_phase + self._cache_directory = cache_directory @property def config(self): diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 3485c2e0..b10017f2 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -1,16 +1,7 @@ -import pathlib import typing -from fast_llm.config import Config, Field, check_field, config_class -from fast_llm.utils import Assert - - -@config_class -class SamplingConfig(Config): - num_samples: int = Field(default=1, desc="Number of samples to generate.", 
valid=check_field(Assert.gt, 0)) - seed: int = Field(default=0, desc="Random seed.") - cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache directory.") - verbose: bool = Field(default=True, desc="Log sampling progress.") +from fast_llm.config import Config, config_class +from fast_llm.data.dataset.config import SamplingConfig @config_class() diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 37bf8752..0b120281 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -4,7 +4,6 @@ from fast_llm.data.config import MultiprocessingContext, TokenizerConfig from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.gpt.config import GPTLegacyConfig, GPTLegacyDatasetConfig, GPTSampledSplitDatasetConfig -from fast_llm.data.dataset.gpt.fim.config import FimConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -30,11 +29,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - fim: FimConfig = Field( - default_factory=FimConfig, - desc="Configuration for Fill In the Middle (FIM).", - hint=FieldHint.feature, - ) data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 12396c97..b8aff67c 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -9,9 +9,10 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import PhaseSplits, SampledSplitDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.tokenizer import Tokenizer -from fast_llm.engine.config_utils.run import get_run, log_main_rank +from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.schedule.config import BatchConfig @@ -31,7 +32,6 @@ class GPTData(Data): _config: GPTDataConfig _tokenizer: Tokenizer | None _distributed: Distributed - _cache_directory: pathlib.Path | None _is_setup: bool = False def __init__( @@ -57,35 +57,47 @@ def vocab_size(self) -> int: def max_sequence_length(self) -> int: return self._max_sequence_length - def setup(self, distributed: Distributed, samples_per_phase: PhaseSplits[int]): + def setup( + self, + distributed: Distributed, + samples_per_phase: PhaseSplits[int], + cache_directory: pathlib.Path, + ): """ Load the datasets, and prepare or load the samplings. This may take a while and a significant amount of cpu memory. """ - super().setup(distributed, samples_per_phase) - run = get_run() + super().setup(distributed, samples_per_phase, cache_directory) log_main_rank(f"Preparing dataset. 
This may take several minutes.") - self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None + self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer) - if run.experiment_directory is None: + if self._cache_directory is None: + # TODO: Avoid this warnings.warn(f"Using the dataset directory for the index cache.") - self._cache_directory = None - else: - self._cache_directory = run.experiment_directory / "dataset_cache" sampling_config = PhaseSplits[GPTSamplingConfig]( { phase: GPTSamplingConfig( num_samples=samples_per_phase[phase], - sequence_length=self._max_sequence_length, seed=self._distributed_config.seed, cache_directory=self._cache_directory, verbose=True, + distributed=distributed, + sequence_length=self._max_sequence_length, + vocab_size=self._vocab_size, + tokenizer=self._tokenizer, ) for phase, num_samples in samples_per_phase.items() if num_samples > 0 } ) - self._datasets = self._config.dataset.build_split_sample(self, sampling_config) + datasets = self._config.dataset.build_split_sample(sampling_config) + self._datasets = SampledSplitDataset( + datasets.name, + { + phase: DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) + for phase, dataset in datasets.items() + }, + ) self._is_setup = True @property diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index 19cb384e..c6448bc6 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,8 +1,7 @@ import abc import typing -from fast_llm.data.data.abstract import Data -from fast_llm.data.data.config import DataConfig, SamplingConfig +from fast_llm.data.dataset.config import SamplingConfig from fast_llm.engine.distributed.config import PhaseType @@ -42,10 +41,9 @@ def as_split(self, default_phase: PhaseType = PhaseType.training): class SamplableDataset(Dataset): - # TODO: Move to dataset config? 
- _data_config_class: typing.ClassVar[type[DataConfig]] - def sample(self, config: SamplingConfig, data: Data) -> SampledDataset: + @abc.abstractmethod + def sample(self, config: SamplingConfig) -> SampledDataset: pass def as_split(self, default_phase: PhaseType = PhaseType.training) -> "SplitDataset": @@ -80,10 +78,10 @@ class SampledSplitDataset(SplitDataset[_SampledDatasetType], typing.Generic[_Sam class SamplableSplitDataset(SplitDataset[_SamplableDatasetType], typing.Generic[_SamplableDatasetType]): - def sample(self, sampling_configs: PhaseSplits[SamplingConfig], data: Data): + def sample(self, sampling_configs: PhaseSplits[SamplingConfig]): return SampledSplitDataset( f"{self.name}_sampled", - {phase: self[phase].sample(sampling_config, data) for phase, sampling_config in sampling_configs.items()}, + {phase: self[phase].sample(sampling_config) for phase, sampling_config in sampling_configs.items()}, ) diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 970e7c62..9fce891c 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -1,13 +1,11 @@ import logging import pathlib -import time import numpy as np from fast_llm.core.distributed import safe_barrier -from fast_llm.data.data.abstract import Data -from fast_llm.data.data.config import SamplingConfig from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.config import SamplingConfig from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert, normalize_probabilities @@ -35,8 +33,6 @@ def __init__( datasets: list[SampledDataset], weights: list[float], sampling_config: SamplingConfig, - # TODO: Generalize - data: Data, ): self._name = name assert len(datasets) > 0 @@ -44,7 +40,6 @@ def __init__( self._datasets = datasets self._weights = normalize_probabilities(weights) self._num_samples = sampling_config.num_samples - self._data_sample_warn_time_ms = data.config.data_sample_warn_time_ms if sampling_config.cache_directory is None: self._dataset_idx_filename, self._sample_idx_filename = None, None @@ -52,7 +47,7 @@ def __init__( sampling_config.verbose and len(self._datasets) <= 20 ) else: - group = data.distributed.world_group + group = sampling_config.distributed.world_group self._dataset_idx_filename = sampling_config.cache_directory / (self._name + "_blending_dataset_idx.npy") self._sample_idx_filename = sampling_config.cache_directory / (self._name + "_blending_sample_idx.npy") @@ -141,23 +136,7 @@ def _build_blending_indices(self, verbose: bool): return dataset_index, dataset_sample_index def __getitem__(self, idx): - start_time = time.perf_counter() - dataset_index = self._dataset_index[idx] - dataset = self._datasets[dataset_index] - sample_index = self._sample_index[idx] - try: - sample = dataset[sample_index] - sample_time = (time.perf_counter() - start_time) * 1000 - if sample_time > self._data_sample_warn_time_ms: - logger.warning( - f"Sample {sample_index} from dataset {dataset_index} ({dataset.name})" - f" took {sample_time:,.2f} ms to load" - ) - return sample - - except Exception: - logger.error(f"Failed to get sample {sample_index} from dataset {dataset_index} ({dataset.name})") - raise + return self._datasets[self._dataset_index[idx]][self._sample_index[idx]] @property def name(self): diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 6fe9314f..682f67ca 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -1,9 +1,10 @@ 
+import dataclasses import functools import math +import pathlib +import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.data.abstract import Data -from fast_llm.data.data.config import SamplingConfig from fast_llm.data.dataset.abstract import ( PhaseSplits, SamplableDataset, @@ -20,12 +21,20 @@ class DatasetConfig(Config): _abstract = True +@dataclasses.dataclass +class SamplingConfig: + num_samples: int = 1 + seed: int = 0 + cache_directory: pathlib.Path | None = None + verbose: bool = True + distributed: typing.Any = None + + @config_class() class SampledSplitDatasetConfig(DatasetConfig): def build_split_sample( self, - data: Data, config: PhaseSplits[SamplingConfig], default_phase: PhaseType = PhaseType.training, ) -> SampledSplitDataset: @@ -49,16 +58,15 @@ class SampledDatasetConfig(SampledSplitDatasetConfig): (See `fast_llm.data.sampler.Sampler`.) """ - def build_sample(self, data: Data, config: SamplingConfig) -> SampledDataset: + def build_sample(self, config: SamplingConfig) -> SampledDataset: raise NotImplementedError() def build_split_sample( self, - data: Data, config: PhaseSplits[SamplingConfig], default_phase: PhaseType = PhaseType.training, ) -> SampledSplitDataset: - dataset = self.build_sample(data, config[default_phase]) + dataset = self.build_sample(config[default_phase]) return SampledSplitDataset(dataset.name, {default_phase: dataset}) @property @@ -75,23 +83,20 @@ class SamplableSplitDatasetConfig(SampledSplitDatasetConfig): def build_split( self, - data: Data, default_phase: PhaseType = PhaseType.training, ) -> SamplableSplitDataset: raise NotImplementedError() def build_split_sample( self, - data: Data, config: PhaseSplits[SamplingConfig], default_phase: PhaseType = PhaseType.training, ) -> SampledSplitDataset: - split_dataset = self.build_split(data) + split_dataset = self.build_split(default_phase) # TODO: Name - # TODO: Arg order not matching with dataset return SampledSplitDataset( "dataset", - {phase: split_dataset[phase].sample(phase_config, data) for phase, phase_config in config.items()}, + {phase: split_dataset[phase].sample(phase_config) for phase, phase_config in config.items()}, ) @property @@ -105,18 +110,17 @@ def split(self): @config_class() class SamplableDatasetConfig(SampledDatasetConfig, SamplableSplitDatasetConfig): - def build(self, data: Data) -> SamplableDataset: + def build(self) -> SamplableDataset: raise NotImplementedError() - def build_sample(self, data: Data, config: SamplingConfig) -> SampledDataset: - return self.build(data).sample(config, data) + def build_sample(self, config: SamplingConfig) -> SampledDataset: + return self.build().sample(config) def build_split( self, - data: Data, default_phase: PhaseType = PhaseType.training, ) -> SamplableSplitDataset: - dataset = self.build(data) + dataset = self.build() return SamplableSplitDataset(dataset.name, {default_phase: dataset}) @property @@ -157,7 +161,6 @@ def split(self): def build_sample( self, - data: "Data", config: SamplingConfig, ) -> SampledDataset: from fast_llm.data.dataset.blended import BlendedDataset @@ -167,7 +170,6 @@ def build_sample( # Build and sample the datasets. sampled_datasets = [ dataset.build_sample( - data, # Blending is deterministic and the error will never be higher than 1. 
config.to_copy({"num_samples": math.ceil(weight * config.num_samples) + 1}), ) @@ -179,12 +181,10 @@ def build_sample( sampled_datasets, self.weights, config, - data, ) def build_split_sample( self, - data: "Data", config: PhaseSplits[SamplingConfig], default_phase: PhaseType = PhaseType.training, ) -> SampledSplitDataset: @@ -192,12 +192,11 @@ def build_split_sample( if not self.split: # Take the base class shortcut using build_sample if it's available. - return super().build_split_sample(data, config, default_phase) + return super().build_split_sample(config, default_phase) # Build, sample and split the datasets. sampled_datasets = [ dataset.build_split_sample( - data, # Blending is deterministic and the error will never be higher than 1. PhaseSplits[SamplingConfig]( { @@ -219,7 +218,6 @@ def build_split_sample( [dataset[phase] for dataset in sampled_datasets], self.weights, phase_config, - data, ) for phase, phase_config in config.items() }, diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index e1c3144c..247145b8 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -1,12 +1,12 @@ +import dataclasses import enum import functools import json import pathlib import typing -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class -from fast_llm.data.data.config import SamplingConfig -from fast_llm.data.dataset.abstract import PhaseSplits, SamplableSplitDataset, SampledSplitDataset +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.data.dataset.abstract import PhaseSplits, SamplableSplitDataset, SampledDataset, SampledSplitDataset from fast_llm.data.dataset.config import ( BlendedDatasetConfig, DatasetConfig, @@ -14,20 +14,23 @@ SamplableSplitDatasetConfig, SampledDatasetConfig, SampledSplitDatasetConfig, + SamplingConfig, ) from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, Registry, normalize_probabilities if typing.TYPE_CHECKING: - from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.dummy import GPTDummySampledDataset from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -@config_class +@dataclasses.dataclass class GPTSamplingConfig(SamplingConfig): - sequence_length: int = Field(default=None, desc="Number of token in each sample.") + # TODO: Sort these out + sequence_length: int | None = None + vocab_size: int | None = None + tokenizer: typing.Any = None @config_class() @@ -96,7 +99,7 @@ class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig, @config_class() class GPTIndexedDatasetConfig(GPTSamplableDatasetConfig): - def build(self, data: "GPTData") -> "GPTIndexedDataset": + def build(self) -> "GPTIndexedDataset": raise NotImplementedError() @@ -112,7 +115,6 @@ class GPTDummyDatasetConfig(GPTSampledSplitDatasetConfig, type_="dummy"): def build_split_sample( self, - data: "GPTData", config: PhaseSplits[GPTSamplingConfig], default_phase: PhaseType = PhaseType.training, ) -> "SampledSplitDataset[GPTDummySampledDataset]": @@ -121,9 +123,9 @@ def build_split_sample( return SampledSplitDataset[GPTDummySampledDataset]( self.name, { - phase: GPTDummyDataset(f"{self.name}_{phase.value}", data.max_sequence_length, data.vocab_size).sample( - phase_config, data - ) + phase: GPTDummyDataset( + 
f"{self.name}_{phase.value}", phase_config.sequence_length, phase_config.vocab_size + ).sample(phase_config) for phase, phase_config in config.items() }, ) @@ -139,7 +141,7 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig, type_="memmap"): hint=FieldHint.core, ) - def build(self, data: "GPTData") -> "GPTMemmapDataset": + def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path) @@ -167,10 +169,10 @@ class GPTConcatenatedDatasetConfig(GPTDatasetConfig, SamplableDatasetConfig, typ valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)), ) - def build(self, data: "GPTData") -> "GPTConcatenatedDataset": + def build(self) -> "GPTConcatenatedDataset": from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset - return GPTConcatenatedDataset(self.name, [dataset.build(data) for dataset in self.datasets]) + return GPTConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets]) @config_class() @@ -196,12 +198,11 @@ class GPTSplitDatasetConfig(GPTSamplableSplitDatasetConfig, type_="split"): def build_split( self, - data: "GPTData", default_phase: PhaseType = PhaseType.training, ) -> "SamplableSplitDataset[GPTDatasetSlice]": from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice - return GPTDatasetSlice.from_splits(self.dataset.build(data), self.ratios) + return GPTDatasetSlice.from_splits(self.dataset.build(), self.ratios) @config_class() @@ -229,6 +230,109 @@ def _validate_path(value): return [value] if isinstance(value, str) else value +@config_class() +class FimConfig(Config): + """ + Configuration for FIM. + """ + + dataset: GPTSampledDatasetConfig = Field( + default=None, + desc="The dataset to wrap with fim.", + hint=FieldHint.core, + ) + rate: float = Field( + default=0.5, + desc="FIM rate for each sample.", + hint=FieldHint.core, + valid=check_field(Assert.in_range_incl, 0, 1), + ) + max_middle_len: int | None = Field( + default=None, + desc="Maximum length of the middle segment in FIM.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + split_sample: str | None = Field( + default=None, + desc="Split samples on this token and permute each fragment separately.", + hint=FieldHint.feature, + ) + fragment_rate: float = Field( + default=0.0, + desc="FIM rate for each fragment when using fim_split_sample.", + hint=FieldHint.feature, + valid=check_field(Assert.in_range_incl, 0, 1), + ) + ignore_prefix: str | None = Field( + default=None, + desc="Do not apply FIM to fragments that start with this prefix.", + hint=FieldHint.feature, + ) + spm_rate: float = Field( + default=0.5, + desc="TODO.", + hint=FieldHint.feature, + valid=check_field(Assert.in_range_incl, 0, 1), + ) + truncate_or_pad: bool = Field( + default=False, + desc="TODO.", + hint=FieldHint.feature, + ) + + +@config_class() +class FimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig, type_="fim"): + """ + Configuration for FIM. 
+ """ + + dataset: GPTSampledDatasetConfig = Field( + default=None, + desc="The dataset to wrap with fim.", + hint=FieldHint.core, + ) + + @property + def split(self): + return self.dataset.split + + def build_sample( + self, + config: GPTSamplingConfig, + ) -> SampledDataset: + from fast_llm.data.dataset.gpt.fim import FimDataset + + assert not self.split + return FimDataset(self, self.dataset.build_sample(config), config) + + def build_split_sample( + self, + config: PhaseSplits[GPTSamplingConfig], + default_phase: PhaseType = PhaseType.training, + ) -> SampledSplitDataset: + from fast_llm.data.dataset.gpt.fim import FimDataset + + if not self.split: + # Take the base class shortcut using build_sample if it's available. + return super().build_split_sample(config, default_phase) + + # Build, sample and split the datasets. + sampled_datasets = self.dataset.build_split_sample( + # Blending is deterministic and the error will never be higher than 1. + PhaseSplits[SamplingConfig]({phase: phase_config for phase, phase_config in config.items()}), + default_phase, + ) + + # Blend the datasets for each phase. + return SampledSplitDataset[FimDataset]( + # TODO: Name + "fim", + {phase: FimDataset(self, sampled_datasets[phase], phase_config) for phase, phase_config in config.items()}, + ) + + @config_class() class GPTLegacyConfig(Config): @classmethod @@ -259,6 +363,11 @@ def _from_dict( hint=FieldHint.deprecated, valid=_validate_path, ) + fim: FimConfig = Field( + default_factory=FimConfig, + desc="Configuration for Fill In the Middle (FIM).", + hint=FieldHint.feature, + ) @config_class() @@ -267,7 +376,6 @@ class GPTLegacyDatasetConfig(GPTSampledSplitDatasetConfig, GPTLegacyConfig, type def build_split_sample( self, - data: "GPTData", config: PhaseSplits[GPTSamplingConfig], default_phase: PhaseType = PhaseType.training, ) -> SampledSplitDataset: @@ -321,4 +429,4 @@ def build_split_sample( else dataset_configs[0] ) - return dataset_config.build_split_sample(data, config, default_phase) + return dataset_config.build_split_sample(config, default_phase) diff --git a/fast_llm/data/dataset/gpt/dummy.py b/fast_llm/data/dataset/gpt/dummy.py index 68a864c8..f6eb5d51 100644 --- a/fast_llm/data/dataset/gpt/dummy.py +++ b/fast_llm/data/dataset/gpt/dummy.py @@ -1,13 +1,8 @@ -import typing - import numpy as np from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig -if typing.TYPE_CHECKING: - from fast_llm.data.data.gpt.data import GPTData - class GPTDummyDataset(SamplableDataset): """ @@ -18,7 +13,7 @@ def __init__(self, name: str, sequence_length: int, vocab_size: int): self._dummy_sample = np.random.randint(0, vocab_size, size=(sequence_length + 1,), dtype=np.int64) self._name = name - def sample(self, config: GPTSamplingConfig, data: "GPTData"): + def sample(self, config: GPTSamplingConfig): return GPTDummySampledDataset(self, config) def get(self): diff --git a/fast_llm/data/dataset/gpt/fim/fim.py b/fast_llm/data/dataset/gpt/fim.py similarity index 87% rename from fast_llm/data/dataset/gpt/fim/fim.py rename to fast_llm/data/dataset/gpt/fim.py index 3afa53ba..1e325113 100644 --- a/fast_llm/data/dataset/gpt/fim/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -1,18 +1,31 @@ import numpy as np -from fast_llm.data.dataset.gpt.fim.config import FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX, FimConfig -from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.dataset.abstract import SampledDataset +from 
fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingConfig +from fast_llm.engine.distributed.config import MAX_SEED +FIM_PREFIX = "" +FIM_MIDDLE = "" +FIM_PAD = "" +FIM_SUFFIX = "" -class Fim: + +class FimDataset(SampledDataset): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py """ - def __init__(self, config: FimConfig, tokenizer: Tokenizer): - self._config = config.validate() - self._tokenizer = tokenizer + def __init__( + self, + config: FimConfig, + dataset: SampledDataset, + sampling_config: GPTSamplingConfig, + ): + self._config = config + self._dataset = dataset + self._sampling_config = sampling_config + self._tokenizer = sampling_config.tokenizer self._suffix_tok_id, self._prefix_tok_id, self._middle_tok_id, self._pad_tok_id = ( self._tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD] ) @@ -20,7 +33,20 @@ def __init__(self, config: FimConfig, tokenizer: Tokenizer): self._tokenizer.vocab[self._config.split_sample] if self._config.split_sample is not None else None ) - def __call__(self, sample, np_rng): + def __len__(self): + return len(self._dataset) + + def __getitem__(self, idx): + sample = self._fim( + self._dataset[idx], np.random.RandomState(seed=(self._sampling_config.seed + idx) % MAX_SEED) + ) + return sample + + @property + def name(self): + return f"{self._dataset.name}_fim" + + def _fim(self, sample, np_rng): # FIM # TODO: permute segments in sample_list, before concatenating. sample_len = sample.shape[0] diff --git a/fast_llm/data/dataset/gpt/fim/__init__.py b/fast_llm/data/dataset/gpt/fim/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/fast_llm/data/dataset/gpt/fim/config.py b/fast_llm/data/dataset/gpt/fim/config.py deleted file mode 100644 index d693ad86..00000000 --- a/fast_llm/data/dataset/gpt/fim/config.py +++ /dev/null @@ -1,58 +0,0 @@ -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.utils import Assert - -FIM_PREFIX = "" -FIM_MIDDLE = "" -FIM_PAD = "" -FIM_SUFFIX = "" - - -@config_class() -class FimConfig(Config): - """ - Configuration for FIM. 
- """ - - rate: float = Field( - default=0.0, - desc="FIM rate for each sample.", - hint=FieldHint.core, - valid=check_field(Assert.in_range_incl, 0, 1), - ) - max_middle_len: int | None = Field( - default=None, - desc="Maximum length of the middle segment in FIM.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) - split_sample: str | None = Field( - default=None, - desc="Split samples on this token and permute each fragment separately.", - hint=FieldHint.feature, - ) - fragment_rate: float = Field( - default=0.0, - desc="FIM rate for each fragment when using fim_split_sample.", - hint=FieldHint.feature, - valid=check_field(Assert.in_range_incl, 0, 1), - ) - ignore_prefix: str | None = Field( - default=None, - desc="Do not apply FIM to fragments that start with this prefix.", - hint=FieldHint.feature, - ) - spm_rate: float = Field( - default=0.5, - desc="TODO.", - hint=FieldHint.feature, - valid=check_field(Assert.in_range_incl, 0, 1), - ) - truncate_or_pad: bool = Field( - default=False, - desc="TODO.", - hint=FieldHint.feature, - ) - - def _validate(self): - super()._validate() - Assert.in_range_incl(self.rate, 0, 1) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index bf1fface..fe0ef6a3 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -3,7 +3,6 @@ import numpy as np -from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.data.dataset.indexed import ConcatenatedIndexedDataset, IndexedDataset, IndexedDatasetSlice @@ -34,10 +33,10 @@ def get_document_sizes(self) -> np.ndarray: and derived classes should try to avoid holding the whole array im memory. """ - def sample(self, config: GPTSamplingConfig, data: GPTData) -> "GPTSampledIndexedDataset": + def sample(self, config: GPTSamplingConfig) -> "GPTSampledIndexedDataset": from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - return GPTSampledIndexedDataset(self, config, data) + return GPTSampledIndexedDataset(self, config) class GPTDatasetSlice(IndexedDatasetSlice, GPTIndexedDataset): diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index c47923d3..b28882a1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -3,13 +3,10 @@ import numpy as np from fast_llm.core.distributed import safe_barrier -from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig -from fast_llm.data.dataset.gpt.fim.fim import Fim from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.engine.distributed.config import MAX_SEED from fast_llm.utils import Assert try: @@ -32,26 +29,20 @@ def __init__( self, indexed_dataset: GPTIndexedDataset, sampling_config: GPTSamplingConfig, - data: GPTData, ): assert isinstance(sampling_config, GPTSamplingConfig) - assert isinstance(data, GPTData) self._indexed_dataset = indexed_dataset - self._sampling_config = sampling_config - if data.config.fim.rate > 0: - assert data.tokenizer is not None - self._fim = Fim(data.config.fim, data.tokenizer) - else: - self._fim = None - - cache_prefix = f"{self.name}_ns_{self._sampling_config.num_samples}_sl_{self._sampling_config.sequence_length}_s_{self._sampling_config.seed}" + cache_prefix = ( + 
f"{self.name}_ns_{sampling_config.num_samples}_sl_{sampling_config.sequence_length}" + f"_s_{sampling_config.seed}" + ) # TODO: Any way to combine into a single file? (Memmap is harder) - self._doc_idx_filename = self._sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy") - self._sample_idx_filename = self._sampling_config.cache_directory / (cache_prefix + "_sample_idx.npy") - self._shuffle_idx_filename = self._sampling_config.cache_directory / (cache_prefix + "_shuffle_idx.npy") + self._doc_idx_filename = sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy") + self._sample_idx_filename = sampling_config.cache_directory / (cache_prefix + "_sample_idx.npy") + self._shuffle_idx_filename = sampling_config.cache_directory / (cache_prefix + "_shuffle_idx.npy") - group = data.distributed.world_group + group = sampling_config.distributed.world_group # Build the indexed mapping if it doesn't exist. # TODO: This only works if the dataset location is accessible by all job. if (group is None or group.rank() == 0) and not ( @@ -59,35 +50,33 @@ def __init__( and self._sample_idx_filename.is_file() and self._shuffle_idx_filename.is_file() ): - if self._sampling_config.verbose: + if sampling_config.verbose: log_main_rank(" > Building the index map on rank 0 ...") - doc_idx, sample_idx, shuffle_idx = self._sample() - self._sampling_config.cache_directory.mkdir(parents=True, exist_ok=True) + doc_idx, sample_idx, shuffle_idx = self._sample(sampling_config) + sampling_config.cache_directory.mkdir(parents=True, exist_ok=True) np.save(self._doc_idx_filename, doc_idx) np.save(self._sample_idx_filename, sample_idx) np.save(self._shuffle_idx_filename, shuffle_idx) safe_barrier(group, self._indexed_dataset.name) - self._load_mappings(self._sampling_config.verbose) + self._load_mappings(sampling_config.verbose) - def _sample(self): + def _sample(self, sampling_config: GPTSamplingConfig): """ Create a `GPTSampledDataset` with the requested parameters. """ document_sizes = self._indexed_dataset.get_document_sizes() num_documents = len(document_sizes) num_tokens = document_sizes.sum() - np_rng = np.random.RandomState(seed=self._sampling_config.seed) + np_rng = np.random.RandomState(seed=sampling_config.seed) - num_epochs = math.ceil( - (self._sampling_config.sequence_length * self._sampling_config.num_samples + 1) / num_tokens - ) + num_epochs = math.ceil((sampling_config.sequence_length * sampling_config.num_samples + 1) / num_tokens) # For the last epoch, decide whether include the entire epoch # in the global shuffle or not. # Get the number of samples for the last epoch - main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // self._sampling_config.sequence_length - last_epoch_samples = self._sampling_config.num_samples - main_epochs_samples - samples_per_epoch = (num_tokens - 1) // self._sampling_config.sequence_length + main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // sampling_config.sequence_length + last_epoch_samples = sampling_config.num_samples - main_epochs_samples + samples_per_epoch = (num_tokens - 1) // sampling_config.sequence_length # If we have less than 80% of the samples for the last epoch, separate out the epoch and treat it differently. # Note: the 80% number is just based on common sense and can be adjusted if needed. 
separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch @@ -106,10 +95,10 @@ def _sample(self): sample_idx = build_sample_idx( document_sizes, doc_idx, - self._sampling_config.sequence_length, + sampling_config.sequence_length, num_epochs, num_tokens, - self._sampling_config.verbose, + sampling_config.verbose, ) # shuffle-idx. @@ -126,15 +115,13 @@ def _sample(self): else: np_rng.shuffle(shuffle_idx) - Assert.geq(len(shuffle_idx), self._sampling_config.num_samples) + Assert.geq(len(shuffle_idx), sampling_config.num_samples) # TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch. - return doc_idx, sample_idx, shuffle_idx[: self._sampling_config.num_samples] + return doc_idx, sample_idx, shuffle_idx[: sampling_config.num_samples] def __getstate__(self): return ( self._indexed_dataset, - self._fim, - self._sampling_config.to_serialized(), self._doc_idx_filename, self._sample_idx_filename, self._shuffle_idx_filename, @@ -143,13 +130,10 @@ def __getstate__(self): def __setstate__(self, state): ( self._indexed_dataset, - self._fim, - sampling_config, self._doc_idx_filename, self._sample_idx_filename, self._shuffle_idx_filename, ) = state - self._sampling_config = GPTSamplingConfig.from_dict(sampling_config) self._load_mappings(False) def _load_mappings(self, verbose): @@ -193,9 +177,6 @@ def __getitem__(self, idx): sample_list, dtype=np.int64, ) - if self._fim is not None: - sample = self._fim(sample, np.random.RandomState(seed=(self._sampling_config.seed + idx) % MAX_SEED)) - return sample @property diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index eea22877..7979586a 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -2,9 +2,7 @@ import numpy as np -from fast_llm.data.data.abstract import Data -from fast_llm.data.data.config import SamplingConfig -from fast_llm.data.dataset.abstract import SamplableDataset, SamplableSplitDataset, SampledDataset +from fast_llm.data.dataset.abstract import SamplableDataset, SamplableSplitDataset from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -25,10 +23,6 @@ def __len__(self) -> int: Number of samples in the dataset. """ - @abc.abstractmethod - def sample(self, config: SamplingConfig, data: Data) -> SampledDataset: - pass - class IndexedDatasetSlice(IndexedDataset): diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py new file mode 100644 index 00000000..7892928f --- /dev/null +++ b/fast_llm/data/dataset/monitor.py @@ -0,0 +1,52 @@ +import logging +import time + +from fast_llm.data.dataset.abstract import SampledDataset + +try: + from fast_llm.csrc.data import build_blending_indices # noqa + + _extension_available = True +except ImportError: + _extension_available = False + +logger = logging.getLogger(__name__) + + +class DatasetMonitor(SampledDataset): + """ + A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. + The sampling order of each dataset is respected, but there is no strict guarantee + on the total number of samples from each dataset. + The sampling exactly matches Megatron-LM with matching parameters. 
+ """ + + def __init__( + self, + dataset: SampledDataset, + data_sample_warn_time_ms: float, + ): + self._dataset = dataset + self._data_sample_warn_time_ms = data_sample_warn_time_ms + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, idx): + start_time = time.perf_counter() + try: + sample = self._dataset[idx] + sample_time = (time.perf_counter() - start_time) * 1000 + if sample_time > self._data_sample_warn_time_ms: + logger.warning( + f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" + ) + return sample + + except Exception: + logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") + raise + + @property + def name(self): + return self._dataset.name diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index b484ea7b..d2bccdf4 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -108,7 +108,11 @@ def setup(self, distributed: Distributed, run: Run): # Setup the datasets. log_main_rank("Preparing datasets...") - self._data.setup(distributed, self._samples_per_split) + self._data.setup( + distributed, + self._samples_per_split, + None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", + ) @abc.abstractmethod def _get_data(self) -> Data: diff --git a/fast_llm/models/custom/data.py b/fast_llm/models/custom/data.py index cbf53d48..285506a7 100644 --- a/fast_llm/models/custom/data.py +++ b/fast_llm/models/custom/data.py @@ -1,5 +1,9 @@ +import pathlib + from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.abstract import PhaseSplits from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.custom.config import CustomDataConfig @@ -15,9 +19,14 @@ def __init__( # TODO: Adjust or reimplement. super().__init__(config, distributed_config, vocab_size, max_sequence_length) - def setup(self, distributed, samples_per_phase): + def setup( + self, + distributed: Distributed, + samples_per_phase: PhaseSplits[int], + cache_directory: pathlib.Path, + ): # TODO: Adjust or reimplement. 
- return super().setup(distributed, samples_per_phase) + return super().setup(distributed, samples_per_phase, cache_directory) def get_iterator( self, diff --git a/tests/common.py b/tests/common.py index e7b8906f..b14bad70 100644 --- a/tests/common.py +++ b/tests/common.py @@ -17,7 +17,6 @@ MixtralGPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM from fast_llm.tools.train import CliTrainingConfig from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs @@ -34,10 +33,13 @@ ARTIFACT_PATH = "runs/0/artifacts" -TOKENIZER_PATH = TEST_RESULTS_PATH / "data" / "tokenizer" +TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" -DATASET_PREFIX = TEST_RESULTS_PATH / "data" / "dataset/data" +DATASET_PREFIX = TEST_RESULTS_PATH / "dataset" / "common" +TEST_VOCAB_SIZE = 8192 +# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% +TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" CONFIG_BASE_FAST_LLM = [ "training.logs.interval=1", @@ -47,7 +49,7 @@ "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.num_attention_heads=8", "model.base_model.transformer.init_method_std=0.022", - "model.base_model.vocab_size=8192", + f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", @@ -89,7 +91,7 @@ "--valid-num-workers=0", "--tokenizer-type=NullTokenizer", # Megatron messes with the vocab size, so we have to subtract 1. - "--vocab-size=8191", + f"--vocab-size={TEST_VOCAB_SIZE-1}", f"--data-path={DATASET_PREFIX}", "--lr-decay-style=constant", # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) @@ -153,7 +155,7 @@ _CONFIGS = { "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None), - "sc1": ("gpt", HuggingfaceGPTModelForCausalLM, CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), + "sc1": ("gpt", CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), "starcoder2": ( "gpt", CONFIG_SC2_FAST_LLM, @@ -198,21 +200,24 @@ requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def get_test_data(): +def get_test_dataset( + prefix=DATASET_PREFIX, seed=1234, num_tokens=1000000, characters=TEST_CHARACTERS, vocab_size=TEST_VOCAB_SIZE +): if not TOKENIZER_FILE.is_file(): import transformers transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) - if not (DATASET_PREFIX.with_suffix(".idx").is_file() and DATASET_PREFIX.with_suffix(".bin").is_file()): + if not (prefix.with_suffix(".idx").is_file() and prefix.with_suffix(".bin").is_file()): import transformers - characters = (string.ascii_lowercase) * 5 + " " * 30 + "\n" - documents = "".join(random.Random(1234).choices(characters, k=1000000)).splitlines() + documents = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) - documents = [np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % 8192 for document in documents] - GPTMemmapDataset.write_dataset(DATASET_PREFIX, documents) + documents = [ + np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size for document in 
documents + ] + GPTMemmapDataset.write_dataset(prefix, documents) def run_test_script( @@ -269,7 +274,7 @@ def run_test_script( if skip: print("Reusing existing run.") else: - get_test_data() + get_test_dataset() if num_gpus == 1 and not is_megatron: CliTrainingConfig.parse_and_run(script) else: diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 00000000..25f3d0e2 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,61 @@ +import pathlib + +import numpy as np + +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.abstract import PhaseSplits +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.utils import Assert +from tests.common import DATASET_PREFIX, TEST_RESULTS_PATH, TEST_VOCAB_SIZE, get_test_dataset + + +def get_test_data( + config: dict, + samples_per_phase: dict[PhaseType, int], + cache_directory: pathlib.Path | None = None, + sequence_length: int = 512, +): + distributed_config = DistributedConfig() + distributed = Distributed(distributed_config, use_cpu=True) + data = GPTData(GPTDataConfig.from_dict(config), distributed_config, TEST_VOCAB_SIZE, sequence_length) + data.setup(distributed, PhaseSplits[int](samples_per_phase), cache_directory) + return data + + +DATASET_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" + + +def get_test_datasets( + config: dict, + samples_per_phase: dict[PhaseType, int], + cache_directory: pathlib.Path | None = None, + sequence_length: int = 512, +): + return get_test_data({"dataset": config}, samples_per_phase, cache_directory, sequence_length)._datasets + + +def test_dummy_dataset(): + datasets = get_test_datasets( + {"type": "dummy"}, + {PhaseType.training: 7, PhaseType.test: 4}, + ) + Assert.eq(datasets.keys(), {PhaseType.training, PhaseType.test}) + train = datasets[PhaseType.training] + Assert.eq(len(train), 7) + assert all(np.all(train[i] == train._dataset._dummy_sample) for i in range(7)) + test = datasets[PhaseType.test] + Assert.eq(len(test), 4) + assert all(np.all(test[i] == test._dataset._dummy_sample) for i in range(4)) + + +def test_memmap_dataset(): + get_test_dataset() + dataset = get_test_datasets( + {"type": "memmap", "path": DATASET_PREFIX}, + {PhaseType.training: 1}, + sequence_length=5, + )[PhaseType.training] + Assert.eq(len(dataset), 5) + raise AssertionError() From 5532b9722f1b8f70eff44841f79221f70dd88975 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sat, 11 Jan 2025 14:11:39 -0500 Subject: [PATCH 09/19] Make tests pass --- fast_llm/data/dataset/abstract.py | 8 +++++--- fast_llm/data/dataset/config.py | 6 ++++-- fast_llm/data/dataset/gpt/config.py | 13 +++++++------ tests/test_dataset.py | 4 ++++ 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index c6448bc6..e69f819e 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,9 +1,11 @@ import abc import typing -from fast_llm.data.dataset.config import SamplingConfig from fast_llm.engine.distributed.config import PhaseType +if typing.TYPE_CHECKING: + from fast_llm.data.dataset.config import SamplingConfig + class Dataset(abc.ABC): """ @@ -43,7 +45,7 @@ def as_split(self, default_phase: PhaseType = PhaseType.training): class SamplableDataset(Dataset): @abc.abstractmethod - def sample(self, config: SamplingConfig) -> SampledDataset: + def 
sample(self, config: "SamplingConfig") -> SampledDataset: pass def as_split(self, default_phase: PhaseType = PhaseType.training) -> "SplitDataset": @@ -78,7 +80,7 @@ class SampledSplitDataset(SplitDataset[_SampledDatasetType], typing.Generic[_Sam class SamplableSplitDataset(SplitDataset[_SamplableDatasetType], typing.Generic[_SamplableDatasetType]): - def sample(self, sampling_configs: PhaseSplits[SamplingConfig]): + def sample(self, sampling_configs: "PhaseSplits[SamplingConfig]"): return SampledSplitDataset( f"{self.name}_sampled", {phase: self[phase].sample(sampling_config) for phase, sampling_config in sampling_configs.items()}, diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 682f67ca..afbf2b9f 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -171,7 +171,7 @@ def build_sample( sampled_datasets = [ dataset.build_sample( # Blending is deterministic and the error will never be higher than 1. - config.to_copy({"num_samples": math.ceil(weight * config.num_samples) + 1}), + dataclasses.replace(config, num_samples=math.ceil(weight * config.num_samples) + 1), ) for dataset, weight in zip(self.datasets, self.weights, strict=True) ] @@ -200,7 +200,9 @@ def build_split_sample( # Blending is deterministic and the error will never be higher than 1. PhaseSplits[SamplingConfig]( { - phase: phase_config.to_copy({"num_samples": math.ceil(weight * phase_config.num_samples) + 1}) + phase: dataclasses.replace( + phase_config, num_samples=math.ceil(weight * phase_config.num_samples) + 1 + ) for phase, phase_config in config.items() } ), diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 247145b8..e40731c5 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -236,13 +236,9 @@ class FimConfig(Config): Configuration for FIM. """ - dataset: GPTSampledDatasetConfig = Field( - default=None, - desc="The dataset to wrap with fim.", - hint=FieldHint.core, - ) rate: float = Field( - default=0.5, + # TODO: Use meaningful default now that fim is a wrapper? 
(bad for legacy config) + default=0.0, desc="FIM rate for each sample.", hint=FieldHint.core, valid=check_field(Assert.in_range_incl, 0, 1), @@ -428,5 +424,10 @@ def build_split_sample( if len(dataset_configs) > 1 else dataset_configs[0] ) + if self.fim.rate > 0: + dataset_config = FimSampledDatasetConfig.from_dict( + self.fim, + {"dataset": dataset_config}, + ) return dataset_config.build_split_sample(config, default_phase) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 25f3d0e2..ed66eb0b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -17,6 +17,7 @@ def get_test_data( cache_directory: pathlib.Path | None = None, sequence_length: int = 512, ): + # TODO: Update distributed_config = DistributedConfig() distributed = Distributed(distributed_config, use_cpu=True) data = GPTData(GPTDataConfig.from_dict(config), distributed_config, TEST_VOCAB_SIZE, sequence_length) @@ -33,10 +34,12 @@ def get_test_datasets( cache_directory: pathlib.Path | None = None, sequence_length: int = 512, ): + # TODO: Update return get_test_data({"dataset": config}, samples_per_phase, cache_directory, sequence_length)._datasets def test_dummy_dataset(): + # TODO: Update datasets = get_test_datasets( {"type": "dummy"}, {PhaseType.training: 7, PhaseType.test: 4}, @@ -51,6 +54,7 @@ def test_dummy_dataset(): def test_memmap_dataset(): + # TODO: Update get_test_dataset() dataset = get_test_datasets( {"type": "memmap", "path": DATASET_PREFIX}, From 5d5e0ab3a2537310f0db280ed64d68c456275111 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sat, 11 Jan 2025 18:53:36 -0500 Subject: [PATCH 10/19] Remove split datasets --- fast_llm/data/data/gpt/config.py | 21 +-- fast_llm/data/data/gpt/data.py | 31 ++-- fast_llm/data/dataset/abstract.py | 52 ------ fast_llm/data/dataset/config.py | 228 +++++++++++---------------- fast_llm/data/dataset/gpt/config.py | 165 +++++-------------- fast_llm/data/dataset/gpt/dummy.py | 2 +- fast_llm/data/dataset/gpt/indexed.py | 6 +- fast_llm/data/dataset/indexed.py | 27 +--- 8 files changed, 159 insertions(+), 373 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 0b120281..f1d19581 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -3,7 +3,8 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.data.config import MultiprocessingContext, TokenizerConfig from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.gpt.config import GPTLegacyConfig, GPTLegacyDatasetConfig, GPTSampledSplitDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTLegacyConfig, GPTLegacyDatasetConfig, GPTSampledDatasetConfig +from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -24,8 +25,9 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the tokenizer (for FIM).", hint=FieldHint.feature, ) - dataset: GPTSampledSplitDatasetConfig = Field( - default_factory=GPTSampledSplitDatasetConfig, + # TODO: Review field. Move closer to phase definition in training config? + datasets: dict[PhaseType, GPTSampledDatasetConfig] = Field( + default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, ) @@ -42,11 +44,12 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): ) def _validate(self): - if self.dataset.type is None: - logger.warning("Using the legacy dataset definition format." 
" Specify it through `data.dataset` instead.") - self.dataset = GPTLegacyDatasetConfig( - ratio=self.ratio, - format=self.format, - path=self.path, + if not self.datasets: + logger.warning( + "Using the legacy dataset definition format." " Specify it through `data.datasets` instead." ) + self.datasets = { + phase: GPTLegacyDatasetConfig.from_dict(self, strict=False) + for phase in (PhaseType.training, PhaseType.validation, PhaseType.test) + } super()._validate() diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index b8aff67c..0b63b02a 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -7,7 +7,7 @@ from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.data.dataset.abstract import PhaseSplits, SampledSplitDataset +from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator @@ -28,7 +28,7 @@ class GPTData(Data): TODO: Separate generic and GPT classes. """ - _datasets: SampledSplitDataset + _datasets: dict[PhaseType, SampledDataset] _config: GPTDataConfig _tokenizer: Tokenizer | None _distributed: Distributed @@ -60,7 +60,7 @@ def max_sequence_length(self) -> int: def setup( self, distributed: Distributed, - samples_per_phase: PhaseSplits[int], + samples_per_phase: dict[PhaseType, int], cache_directory: pathlib.Path, ): """ @@ -74,30 +74,25 @@ def setup( if self._cache_directory is None: # TODO: Avoid this warnings.warn(f"Using the dataset directory for the index cache.") - sampling_config = PhaseSplits[GPTSamplingConfig]( - { - phase: GPTSamplingConfig( + + self._datasets = {} + for phase, num_samples in samples_per_phase.items(): + if num_samples > 0: + # TODO: Do the check earlier. + assert phase in self._config.datasets + sampling_config = GPTSamplingConfig( num_samples=samples_per_phase[phase], seed=self._distributed_config.seed, cache_directory=self._cache_directory, verbose=True, distributed=distributed, + phase=phase, sequence_length=self._max_sequence_length, vocab_size=self._vocab_size, tokenizer=self._tokenizer, ) - for phase, num_samples in samples_per_phase.items() - if num_samples > 0 - } - ) - datasets = self._config.dataset.build_split_sample(sampling_config) - self._datasets = SampledSplitDataset( - datasets.name, - { - phase: DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) - for phase, dataset in datasets.items() - }, - ) + dataset = self._config.datasets[phase].build_and_sample(sampling_config) + self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) self._is_setup = True @property diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index e69f819e..c16587d1 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,8 +1,6 @@ import abc import typing -from fast_llm.engine.distributed.config import PhaseType - if typing.TYPE_CHECKING: from fast_llm.data.dataset.config import SamplingConfig @@ -19,10 +17,6 @@ def name(self): A name for the dataset to facilitate identification and debugging. 
""" - @abc.abstractmethod - def as_split(self, default_phase: PhaseType = PhaseType.training): - pass - class SampledDataset(Dataset): """ @@ -38,55 +32,9 @@ def __getitem__(self, index: int): def __len__(self): pass - def as_split(self, default_phase: PhaseType = PhaseType.training): - return SplitDataset(self.name, {default_phase: self}) - class SamplableDataset(Dataset): @abc.abstractmethod def sample(self, config: "SamplingConfig") -> SampledDataset: pass - - def as_split(self, default_phase: PhaseType = PhaseType.training) -> "SplitDataset": - return SplitDataset(self.name, {default_phase: self}) - - -_SplittableType = typing.TypeVar("_SplittableType") -_DatasetType = typing.TypeVar("_DatasetType", bound=Dataset) -_SampledDatasetType = typing.TypeVar("_SampledDatasetType", bound=SampledDataset) -_SamplableDatasetType = typing.TypeVar("_SamplableDatasetType", bound=SamplableDataset) - - -class PhaseSplits(dict[PhaseType, _SplittableType], typing.Generic[_SplittableType]): - pass - - -class SplitDataset(Dataset, PhaseSplits[_DatasetType], typing.Generic[_DatasetType]): - def __init__(self, name: str, datasets: dict[PhaseType, _DatasetType]): - super().__init__(datasets) - self._name = name - - def as_split(self, default_phase: PhaseType = PhaseType.training): - return self - - @property - def name(self): - return self._name - - -class SampledSplitDataset(SplitDataset[_SampledDatasetType], typing.Generic[_SampledDatasetType]): - pass - - -class SamplableSplitDataset(SplitDataset[_SamplableDatasetType], typing.Generic[_SamplableDatasetType]): - def sample(self, sampling_configs: "PhaseSplits[SamplingConfig]"): - return SampledSplitDataset( - f"{self.name}_sampled", - {phase: self[phase].sample(sampling_config) for phase, sampling_config in sampling_configs.items()}, - ) - - -class CopySplitDataset(SamplableSplitDataset): - def __init__(self, name: str, dataset: _SplittableType, phases: list[PhaseType]): - super().__init__(name, {phase: dataset for phase in phases}) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index afbf2b9f..e842b87c 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -5,136 +5,132 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.dataset.abstract import ( - PhaseSplits, - SamplableDataset, - SamplableSplitDataset, - SampledDataset, - SampledSplitDataset, -) +from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset + from fast_llm.engine.distributed.distributed import Distributed + @config_class() class DatasetConfig(Config): _abstract = True -@dataclasses.dataclass +@dataclasses.dataclass(kw_only=True) class SamplingConfig: - num_samples: int = 1 - seed: int = 0 - cache_directory: pathlib.Path | None = None - verbose: bool = True - distributed: typing.Any = None + # TODO: Have a separate configuration (subset?) for `build`? 
+ num_samples: int + seed: int + cache_directory: pathlib.Path | None + verbose: bool + distributed: "Distributed" + phase: PhaseType @config_class() -class SampledSplitDatasetConfig(DatasetConfig): - - def build_split_sample( - self, - config: PhaseSplits[SamplingConfig], - default_phase: PhaseType = PhaseType.training, - ) -> SampledSplitDataset: - raise NotImplementedError() - - @property - def sampled(self): - # Generally hard-coded, but some classes allow for more flexible values. - return True - - @property - def split(self): - # Generally hard-coded, but some classes allow for more flexible values. - return True - - -@config_class() -class SampledDatasetConfig(SampledSplitDatasetConfig): +class SampledDatasetConfig(DatasetConfig): """ A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. (See `fast_llm.data.sampler.Sampler`.) """ - def build_sample(self, config: SamplingConfig) -> SampledDataset: + def build_and_sample(self, config: SamplingConfig) -> SampledDataset: raise NotImplementedError() - def build_split_sample( - self, - config: PhaseSplits[SamplingConfig], - default_phase: PhaseType = PhaseType.training, - ) -> SampledSplitDataset: - dataset = self.build_sample(config[default_phase]) - return SampledSplitDataset(dataset.name, {default_phase: dataset}) - @property - def sampled(self): - return True +@config_class() +class SamplableDatasetConfig(SampledDatasetConfig): + def build(self) -> SamplableDataset: + raise NotImplementedError() - @property - def split(self): - return False + def build_and_sample(self, config: SamplingConfig) -> SampledDataset: + return self.build().sample(config) @config_class() -class SamplableSplitDatasetConfig(SampledSplitDatasetConfig): - - def build_split( - self, - default_phase: PhaseType = PhaseType.training, - ) -> SamplableSplitDataset: +class IndexedDatasetConfig(SamplableDatasetConfig): + def build(self) -> "IndexedDataset": raise NotImplementedError() - def build_split_sample( - self, - config: PhaseSplits[SamplingConfig], - default_phase: PhaseType = PhaseType.training, - ) -> SampledSplitDataset: - split_dataset = self.build_split(default_phase) - # TODO: Name - return SampledSplitDataset( - "dataset", - {phase: split_dataset[phase].sample(phase_config) for phase, phase_config in config.items()}, - ) - @property - def sampled(self): - return False +@config_class() +class ConcatenatedDatasetConfig(SamplableDatasetConfig): + """ + Concatenate multiple indexed datasets as if they were one. + TODO: Make a post-sampling version? 
(staged training) + """ + + _abstract = False + name: str = Field( + default="concatenated", + desc="The name of the dataset.", + hint=FieldHint.core, + ) + datasets: list[IndexedDatasetConfig] = Field( + default_factory=list, + desc="The datasets to concatenate.", + hint=FieldHint.core, + valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)), + ) + + def build(self) -> "ConcatenatedDataset": + from fast_llm.data.dataset.indexed import ConcatenatedDataset + + return self._build(ConcatenatedDataset) - @property - def split(self): - return True + def _build[T: ConcatenatedDataset](self, cls: type[T]) -> T: + return cls(self.name, [dataset.build() for dataset in self.datasets]) @config_class() -class SamplableDatasetConfig(SampledDatasetConfig, SamplableSplitDatasetConfig): - def build(self) -> SamplableDataset: - raise NotImplementedError() +class DatasetSliceConfig(SamplableDatasetConfig): + """ + Use a fraction of an indexed dataset, specified by the range (begin, end). + Typically used to subsample a dataset, or to reserve part of the dataset for validation and/or testing. + Ex. use (0.0, 0.9) for train, (0.9, 1.0) for validation for a 90%-10% split. + TODO: This is suboptimal (duplication between train/test, unnecessary sub-datasets in the case of concatenation, + leads to higher resource usage than necessary; more open files?) + """ - def build_sample(self, config: SamplingConfig) -> SampledDataset: - return self.build().sample(config) + _abstract = False + dataset: IndexedDatasetConfig = Field( + default=None, + desc="The dataset to split.", + hint=FieldHint.core, + ) + begin: float = Field( + default=0, + desc="The beginning of the dataset split, as a fraction of the total samples.", + hint=FieldHint.core, + ) + end: float = Field( + default=1, + desc="The end of the dataset split, as a fraction of the total samples.", + hint=FieldHint.core, + ) - def build_split( - self, - default_phase: PhaseType = PhaseType.training, - ) -> SamplableSplitDataset: - dataset = self.build() - return SamplableSplitDataset(dataset.name, {default_phase: dataset}) + def build(self) -> "DatasetSlice": + from fast_llm.data.dataset.indexed import DatasetSlice - @property - def sampled(self): - return False + return self._build(DatasetSlice) - @property - def split(self): - return False + def _build[T: DatasetSlice](self, cls: type[T]) -> T: + dataset = self.dataset.build() + size = len(dataset) + return cls( + f"{dataset.name}_{self.begin}_{self.end}", + dataset, + round(self.begin * size), + round(self.end * size), + ) @config_class() class BlendedDatasetConfig(SampledDatasetConfig): - # [(?sampled, ?split)] -> (sampled, ?split) + _abstract = False name: str = Field( default="blended", desc="The name of the dataset.", @@ -155,21 +151,15 @@ class BlendedDatasetConfig(SampledDatasetConfig): def __post_init__(self): Assert.eq(len(self.datasets), len(self.weights)) - @property - def split(self): - return any(dataset.split for dataset in self.datasets) - - def build_sample( + def build_and_sample( self, config: SamplingConfig, ) -> SampledDataset: from fast_llm.data.dataset.blended import BlendedDataset - assert not self.split - # Build and sample the datasets. sampled_datasets = [ - dataset.build_sample( + dataset.build_and_sample( # Blending is deterministic and the error will never be higher than 1. 
dataclasses.replace(config, num_samples=math.ceil(weight * config.num_samples) + 1), ) @@ -182,45 +172,3 @@ def build_sample( self.weights, config, ) - - def build_split_sample( - self, - config: PhaseSplits[SamplingConfig], - default_phase: PhaseType = PhaseType.training, - ) -> SampledSplitDataset: - from fast_llm.data.dataset.blended import BlendedDataset - - if not self.split: - # Take the base class shortcut using build_sample if it's available. - return super().build_split_sample(config, default_phase) - - # Build, sample and split the datasets. - sampled_datasets = [ - dataset.build_split_sample( - # Blending is deterministic and the error will never be higher than 1. - PhaseSplits[SamplingConfig]( - { - phase: dataclasses.replace( - phase_config, num_samples=math.ceil(weight * phase_config.num_samples) + 1 - ) - for phase, phase_config in config.items() - } - ), - default_phase, - ) - for dataset, weight in zip(self.datasets, self.weights, strict=True) - ] - - # Blend the datasets for each phase. - return SampledSplitDataset[BlendedDataset]( - self.name, - { - phase: BlendedDataset( - f"{self.name}_{phase.value}", - [dataset[phase] for dataset in sampled_datasets], - self.weights, - phase_config, - ) - for phase, phase_config in config.items() - }, - ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index e40731c5..2fecc3ef 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -1,36 +1,36 @@ import dataclasses import enum -import functools import json import pathlib import typing from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none -from fast_llm.data.dataset.abstract import PhaseSplits, SamplableSplitDataset, SampledDataset, SampledSplitDataset +from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import ( BlendedDatasetConfig, + ConcatenatedDatasetConfig, DatasetConfig, + DatasetSliceConfig, SamplableDatasetConfig, - SamplableSplitDatasetConfig, SampledDatasetConfig, - SampledSplitDatasetConfig, SamplingConfig, ) from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, Registry, normalize_probabilities +from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.dummy import GPTDummySampledDataset from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset + from fast_llm.data.tokenizer import Tokenizer @dataclasses.dataclass class GPTSamplingConfig(SamplingConfig): # TODO: Sort these out - sequence_length: int | None = None - vocab_size: int | None = None - tokenizer: typing.Any = None + sequence_length: int + vocab_size: int + tokenizer: "Tokenizer" @config_class() @@ -78,22 +78,12 @@ def __init_subclass__(cls, type_: str | None = None, **kwargs): @config_class() -class GPTSampledSplitDatasetConfig(SampledSplitDatasetConfig, GPTDatasetConfig): +class GPTSampledDatasetConfig(SampledDatasetConfig, GPTDatasetConfig): pass @config_class() -class GPTSampledDatasetConfig(SampledDatasetConfig, GPTSampledSplitDatasetConfig): - pass - - -@config_class() -class GPTSamplableSplitDatasetConfig(SamplableSplitDatasetConfig, GPTSampledSplitDatasetConfig): - pass - - -@config_class() -class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig, GPTSamplableSplitDatasetConfig): 
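The oversampling rule above, ceil(weight * num_samples) + 1 per sub-dataset, follows from the comment that deterministic blending never over-draws a dataset by more than one sample relative to weight * num_samples. A quick check with made-up weights:

import math


def oversample_counts(weights, num_samples):
    # Number of samples requested from each blended sub-dataset.
    total = sum(weights)
    return [math.ceil(weight / total * num_samples) + 1 for weight in weights]


print(oversample_counts([0.75, 0.25], 1000))  # [751, 251]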
+class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig): pass @@ -104,8 +94,8 @@ def build(self) -> "GPTIndexedDataset": @config_class() -class GPTDummyDatasetConfig(GPTSampledSplitDatasetConfig, type_="dummy"): - # NA -> (unsampled, unsplit) +class GPTDummyDatasetConfig(GPTSampledDatasetConfig, type_="dummy"): + # TODO: Can't make it a samplable dataset because necessary info is in sampling config. _abstract = False name: str = Field( default="dummy", @@ -113,27 +103,14 @@ class GPTDummyDatasetConfig(GPTSampledSplitDatasetConfig, type_="dummy"): hint=FieldHint.core, ) - def build_split_sample( - self, - config: PhaseSplits[GPTSamplingConfig], - default_phase: PhaseType = PhaseType.training, - ) -> "SampledSplitDataset[GPTDummySampledDataset]": - from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset, GPTDummySampledDataset - - return SampledSplitDataset[GPTDummySampledDataset]( - self.name, - { - phase: GPTDummyDataset( - f"{self.name}_{phase.value}", phase_config.sequence_length, phase_config.vocab_size - ).sample(phase_config) - for phase, phase_config in config.items() - }, - ) + def build_and_sample(self, config: GPTSamplingConfig) -> "GPTDummySampledDataset": + from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset + + return GPTDummyDataset(self.name, config.sequence_length, config.vocab_size).sample(config) @config_class() class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig, type_="memmap"): - # Path -> (unsampled, unsplit) _abstract = False path: pathlib.Path = Field( default=None, @@ -148,61 +125,23 @@ def build(self) -> "GPTMemmapDataset": @config_class() -class GPTConcatenatedDatasetConfig(GPTDatasetConfig, SamplableDatasetConfig, type_="concatenated"): - """ - Concatenate multiple datasets as if they were one. - Must be done before sampling and splitting. - TODO: OK after sampling (staged training?) or splitting (Equal split for each sub-dataset, probably better? - [(unsampled, unsplit)] -> (unsampled, unsplit) - """ - +class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig, type_="concatenated"): _abstract = False - name: str = Field( - default="concatenated", - desc="The name of the dataset.", - hint=FieldHint.core, - ) - datasets: list[GPTIndexedDatasetConfig] = Field( - default_factory=list, - desc="The datasets to concatenate.", - hint=FieldHint.core, - valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)), - ) def build(self) -> "GPTConcatenatedDataset": from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset - return GPTConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets]) + return self._build(GPTConcatenatedDataset) @config_class() -class GPTSplitDatasetConfig(GPTSamplableSplitDatasetConfig, type_="split"): - """ - Split a single dataset into multiple phases. - Must be done before sampling. - TODO: Ok after sampling? 
- (unsampled, unsplit) -> (unsampled, split) - """ - +class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig, type_="slice"): _abstract = False - dataset: GPTIndexedDatasetConfig = Field( - default=None, - desc="The dataset to split.", - hint=FieldHint.core, - ) - ratios: dict[PhaseType, float] = Field( - default=None, - desc="The split ratio for each phase", - hint=FieldHint.core, - ) - def build_split( - self, - default_phase: PhaseType = PhaseType.training, - ) -> "SamplableSplitDataset[GPTDatasetSlice]": + def build(self) -> "GPTDatasetSlice": from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice - return GPTDatasetSlice.from_splits(self.dataset.build(), self.ratios) + return self._build(GPTDatasetSlice) @config_class() @@ -290,43 +229,13 @@ class FimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig, type_="fim"): hint=FieldHint.core, ) - @property - def split(self): - return self.dataset.split - - def build_sample( + def build_and_sample( self, config: GPTSamplingConfig, ) -> SampledDataset: from fast_llm.data.dataset.gpt.fim import FimDataset - assert not self.split - return FimDataset(self, self.dataset.build_sample(config), config) - - def build_split_sample( - self, - config: PhaseSplits[GPTSamplingConfig], - default_phase: PhaseType = PhaseType.training, - ) -> SampledSplitDataset: - from fast_llm.data.dataset.gpt.fim import FimDataset - - if not self.split: - # Take the base class shortcut using build_sample if it's available. - return super().build_split_sample(config, default_phase) - - # Build, sample and split the datasets. - sampled_datasets = self.dataset.build_split_sample( - # Blending is deterministic and the error will never be higher than 1. - PhaseSplits[SamplingConfig]({phase: phase_config for phase, phase_config in config.items()}), - default_phase, - ) - - # Blend the datasets for each phase. - return SampledSplitDataset[FimDataset]( - # TODO: Name - "fim", - {phase: FimDataset(self, sampled_datasets[phase], phase_config) for phase, phase_config in config.items()}, - ) + return FimDataset(self, self.dataset.build_and_sample(config), config) @config_class() @@ -367,18 +276,13 @@ def _from_dict( @config_class() -class GPTLegacyDatasetConfig(GPTSampledSplitDatasetConfig, GPTLegacyConfig, type_="legacy"): +class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig, type_="legacy"): _abstract = False - def build_split_sample( - self, - config: PhaseSplits[GPTSamplingConfig], - default_phase: PhaseType = PhaseType.training, - ) -> SampledSplitDataset: + def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset: if self.format == LegacyDatasetSource.random: Assert.eq(len(self.path), 0) - # TODO: Multiple phase. dataset_config = GPTDummyDatasetConfig() else: if self.format == LegacyDatasetSource.file: @@ -404,14 +308,19 @@ def build_split_sample( else: raise NotImplementedError(self.format) + phase_splits = padded_cumsum(self.ratio) + phase_index = { + PhaseType.training: 0, + PhaseType.validation: 1, + PhaseType.test: 2, + }[config.phase] + dataset_configs = [ - GPTSplitDatasetConfig( + GPTDatasetSliceConfig( + # TODO: this duplicates memmap datasets for each phase. 
dataset=GPTMemmapDatasetConfig(path=prefix), - ratios={ - PhaseType.training: self.ratio[0], - PhaseType.validation: self.ratio[1], - PhaseType.test: self.ratio[2], - }, + begin=phase_splits[phase_index], + end=phase_splits[phase_index + 1], ) for prefix in dataset_prefixes ] @@ -430,4 +339,4 @@ def build_split_sample( {"dataset": dataset_config}, ) - return dataset_config.build_split_sample(config, default_phase) + return dataset_config.build_and_sample(config) diff --git a/fast_llm/data/dataset/gpt/dummy.py b/fast_llm/data/dataset/gpt/dummy.py index f6eb5d51..d96034c4 100644 --- a/fast_llm/data/dataset/gpt/dummy.py +++ b/fast_llm/data/dataset/gpt/dummy.py @@ -13,7 +13,7 @@ def __init__(self, name: str, sequence_length: int, vocab_size: int): self._dummy_sample = np.random.randint(0, vocab_size, size=(sequence_length + 1,), dtype=np.int64) self._name = name - def sample(self, config: GPTSamplingConfig): + def sample(self, config: GPTSamplingConfig) -> "GPTDummySampledDataset": return GPTDummySampledDataset(self, config) def get(self): diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index fe0ef6a3..20bdeb30 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -4,7 +4,7 @@ import numpy as np from fast_llm.data.dataset.gpt.config import GPTSamplingConfig -from fast_llm.data.dataset.indexed import ConcatenatedIndexedDataset, IndexedDataset, IndexedDatasetSlice +from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset @@ -39,7 +39,7 @@ def sample(self, config: GPTSamplingConfig) -> "GPTSampledIndexedDataset": return GPTSampledIndexedDataset(self, config) -class GPTDatasetSlice(IndexedDatasetSlice, GPTIndexedDataset): +class GPTDatasetSlice(DatasetSlice, GPTIndexedDataset): """ A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. """ @@ -51,7 +51,7 @@ def get_document_sizes(self): return self._dataset.get_document_sizes()[self._begin : self._end] -class GPTConcatenatedDataset(ConcatenatedIndexedDataset, GPTIndexedDataset): +class GPTConcatenatedDataset(ConcatenatedDataset, GPTIndexedDataset): _datasets: list[GPTIndexedDataset] def get_document_sizes(self) -> np.ndarray: diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 7979586a..28269183 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -2,9 +2,8 @@ import numpy as np -from fast_llm.data.dataset.abstract import SamplableDataset, SamplableSplitDataset -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum +from fast_llm.data.dataset.abstract import SamplableDataset +from fast_llm.utils import Assert, padded_cumsum class IndexedDataset(SamplableDataset): @@ -24,7 +23,7 @@ def __len__(self) -> int: """ -class IndexedDatasetSlice(IndexedDataset): +class DatasetSlice(IndexedDataset): def __init__( self, @@ -61,24 +60,8 @@ def __len__(self): def name(self): return self._name - @classmethod - def from_splits(cls, dataset: IndexedDataset, phase_split: dict[PhaseType, float]): - """ - Create a set of GPT datasets from a MMapIndexedDataset, - each containing approximately the requested proportion of the total tokens. 
- """ - probabilities = normalize_probabilities(list(phase_split.values())) - splits = [round(x) for x in padded_cumsum(probabilities) * len(dataset)] - return SamplableSplitDataset[cls]( - f"{dataset.name}_split", - { - phase: cls(f"{dataset.name}_{phase.value}", dataset, split_begin, split_end) - for phase, split_begin, split_end in zip(phase_split, splits[:-1], splits[1:]) - }, - ) - - -class ConcatenatedIndexedDataset(IndexedDataset): + +class ConcatenatedDataset[IndexedDatasetType: IndexedDataset](IndexedDataset): def __init__( self, From baacc4e8cc9975b2f8282f9d4d0627453edfbb25 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 13 Jan 2025 11:56:44 -0500 Subject: [PATCH 11/19] Make tests pass --- ...ral-4-node-benchmark.yaml => mistral.yaml} | 7 ++++--- fast_llm/data/dataset/gpt/config.py | 5 ++++- tests/common.py | 20 +++++++++++++------ tests/test_dataset.py | 1 - 4 files changed, 22 insertions(+), 11 deletions(-) rename examples/{mistral-4-node-benchmark.yaml => mistral.yaml} (92%) diff --git a/examples/mistral-4-node-benchmark.yaml b/examples/mistral.yaml similarity index 92% rename from examples/mistral-4-node-benchmark.yaml rename to examples/mistral.yaml index 0a71d392..51c01595 100644 --- a/examples/mistral-4-node-benchmark.yaml +++ b/examples/mistral.yaml @@ -11,8 +11,9 @@ batch: micro_batch_size: 1 batch_size: 32 data: - dataset: - type: dummy + datasets: + Training: + type: dummy optimizer: learning_rate: base: 1.0e-05 @@ -52,4 +53,4 @@ model: distributed_timeout: 3600 seed: 984059 run: - experiment_dir: mistral_4_nodes_benchmark + experiment_dir: mistral_example diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 2fecc3ef..90c39bf3 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -11,6 +11,7 @@ ConcatenatedDatasetConfig, DatasetConfig, DatasetSliceConfig, + IndexedDatasetConfig, SamplableDatasetConfig, SampledDatasetConfig, SamplingConfig, @@ -88,7 +89,7 @@ class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig) @config_class() -class GPTIndexedDatasetConfig(GPTSamplableDatasetConfig): +class GPTIndexedDatasetConfig(GPTSamplableDatasetConfig, IndexedDatasetConfig): def build(self) -> "GPTIndexedDataset": raise NotImplementedError() @@ -127,6 +128,7 @@ def build(self) -> "GPTMemmapDataset": @config_class() class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig, type_="concatenated"): _abstract = False + datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() def build(self) -> "GPTConcatenatedDataset": from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset @@ -137,6 +139,7 @@ def build(self) -> "GPTConcatenatedDataset": @config_class() class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig, type_="slice"): _abstract = False + dataset: GPTIndexedDatasetConfig = FieldUpdate() def build(self) -> "GPTDatasetSlice": from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice diff --git a/tests/common.py b/tests/common.py index b14bad70..8da60bb9 100644 --- a/tests/common.py +++ b/tests/common.py @@ -60,12 +60,20 @@ "training.num_workers=0", "batch.batch_size=8", "batch.sequence_length=512", - "data.dataset.type=split", - "data.dataset.dataset.type=memmap", - f"data.dataset.dataset.path={DATASET_PREFIX}", - f"data.dataset.ratios.Training=969", - f"data.dataset.ratios.Validation=30", - f"data.dataset.ratios.Test=1", + "data.datasets.Training.type=slice", + "data.datasets.Training.end=0.969", + 
"data.datasets.Training.dataset.type=memmap", + f"data.datasets.Training.dataset.path={DATASET_PREFIX}", + "data.datasets.Validation.type=slice", + "data.datasets.Validation.begin=0.969", + "data.datasets.Validation.end=0.999", + "data.datasets.Validation.dataset.type=memmap", + f"data.datasets.Validation.dataset.path={DATASET_PREFIX}", + "data.datasets.Test.type=slice", + "data.datasets.Test.begin=0.999", + "data.datasets.Test.end=1", + "data.datasets.Test.dataset.type=memmap", + f"data.datasets.Test.dataset.path={DATASET_PREFIX}", "optimizer.learning_rate.base=0.0001", ] CONFIG_BASE_MEGATRON = [ diff --git a/tests/test_dataset.py b/tests/test_dataset.py index ed66eb0b..054a8f5b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -4,7 +4,6 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData -from fast_llm.data.dataset.abstract import PhaseSplits from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.utils import Assert From 09640d8dd3ac23b31d7d7cd464103299b53262d8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 13 Jan 2025 14:31:53 -0500 Subject: [PATCH 12/19] misc --- tests/test_dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 054a8f5b..969db0e0 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -4,6 +4,7 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.gpt.config import GPTDatasetConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.utils import Assert @@ -27,7 +28,7 @@ def get_test_data( DATASET_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" -def get_test_datasets( +def get_dataset( config: dict, samples_per_phase: dict[PhaseType, int], cache_directory: pathlib.Path | None = None, @@ -39,6 +40,8 @@ def get_test_datasets( def test_dummy_dataset(): # TODO: Update + dataset = GPTDatasetConfig.from_dict({"type": "dummy"}) + datasets = get_test_datasets( {"type": "dummy"}, {PhaseType.training: 7, PhaseType.test: 4}, From a73acf6cd257e69bbbc02d528d91cc07891031c9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 13 Jan 2025 14:33:46 -0500 Subject: [PATCH 13/19] misc --- fast_llm/config.py | 65 ++++++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index d3deda23..525adff0 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -291,7 +291,7 @@ def __post_init__(self): if _AUTO_VALIDATE: self.validate() - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: typing.Any) -> None: """ Make the class read-only after validation. """ @@ -307,7 +307,7 @@ def __setattr__(self, key, value): ) super().__setattr__(key, value) - def __delattr__(self, key): + def __delattr__(self, key: str) -> None: """ Make the class read-only after validation. """ @@ -318,7 +318,7 @@ def __delattr__(self, key): ) super().__delattr__(key) - def validate(self, *, _is_validating=False): + def validate[T](self: T, *, _is_validating: bool = False) -> T: """ Validate a class and mark it as read-only This should not be overridden in derived classes. 
@@ -334,7 +334,7 @@ def validate(self, *, _is_validating=False): self._validated = True return self - def _validate(self): + def _validate(self) -> None: """ Verify that the type hints are respected, and fix some know entries compatible with the type hint (ex. `int -> float`, `str -> pathlib.Path`) @@ -522,7 +522,7 @@ def fields(cls) -> typing.Iterable[tuple[str, Field]]: return cls.__dataclass_fields__.items() # noqa @classmethod - def get_field(cls, name) -> Field: + def get_field(cls, name: str) -> Field: return cls.__dataclass_fields__[name] # noqa def _to_dict( @@ -531,7 +531,7 @@ def _to_dict( all_fields: bool = False, format_: _ConfigDictFormat = _ConfigDictFormat.nested, serializable: bool = False, - ): + ) -> dict[str, typing.Any]: """ Serialize the config to a dict that can (generally) be used to reconstruct an identical `Config`. When not flat, the dict includes a `__class__` entry which allows support for derived classes. @@ -561,12 +561,12 @@ def _add_field_to_args( args: dict | list, name: str | None, field: Field | None, - value, + value: typing.Any, verbose: int | None = None, all_fields: bool = False, format_: _ConfigDictFormat = _ConfigDictFormat.nested, serializable: bool = False, - ): + ) -> None: if ( field is not None and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR) @@ -622,7 +622,7 @@ def _add_field_to_args( raise NotImplementedError(format_) @classmethod - def _serialize_value(cls, value): + def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None: value = value if hasattr(value, "__fast_llm_serialize__"): value = value.__fast_llm_serialize__() @@ -634,24 +634,24 @@ def _serialize_value(cls, value): value = str(value) return value - def to_copy( - self, - *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], - strict: bool = True, - ): + def to_copy[ + T + ](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T: return self.from_dict(self, *updates, strict=strict) - def to_serialized(self, verbose: int | None = FieldVerboseLevel.core): + def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]: return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True) - def to_logs( + def to_logs[ + T + ]( self, verbose: int | None = FieldVerboseLevel.core, - log_fn=logger.info, + log_fn: typing.Callable[[str], T] = logger.info, title: str | None = None, width: int = 80, fill_char: str = "-", - ): + ) -> T: arg_dict = self.to_serialized(verbose=verbose) if title is None: title = self._get_class_name() @@ -662,16 +662,18 @@ def to_logs( ) @classmethod - def _get_class_name(cls): + def _get_class_name(cls) -> str: return get_type_name(cls) @classmethod - def from_dict( - cls, + def from_dict[ + T + ]( + cls: type[T], default: typing.Union["Config", dict[str, typing.Any]], *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True, - ): + ) -> T: if isinstance(default, Config): default = default._to_dict() for update in updates: @@ -683,21 +685,16 @@ def from_dict( return cls._from_dict(default, strict) @classmethod - def from_flat_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - ): + def from_flat_dict[ + T + ](cls: type[T], default: dict[str, typing.Any], strict: bool = True,) -> T: # TODO v0.3: Remove flat format return cls._from_dict(default, strict, True) @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - 
strict: bool = True, - flat: bool = False, - ): + def _from_dict[ + T + ](cls: type[T], default: dict[str, typing.Any], strict: bool = True, flat: bool = False,) -> T: # TODO v0.3: Remove flat format out_arg_dict = {} @@ -814,7 +811,7 @@ def _handle_renamed_field( old_name: str | tuple[str, ...], new_name: str | tuple[str, ...], fn: typing.Callable | None = None, - ): + ) -> None: if old_name in default: warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.") value = pop_nested_dict_value(default, old_name) From 148b4489273571dd2b88807cb9858d5c06dd3517 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 15 Jan 2025 15:57:09 -0500 Subject: [PATCH 14/19] Type hints --- fast_llm/data/auto.py | 3 ++- fast_llm/data/data/abstract.py | 11 ++++------- fast_llm/data/data/gpt/config.py | 2 +- fast_llm/data/data/gpt/data.py | 12 ++++++------ fast_llm/data/dataset/abstract.py | 6 +++--- fast_llm/data/dataset/blended.py | 13 +++++++------ fast_llm/data/dataset/config.py | 2 +- fast_llm/data/dataset/gpt/config.py | 10 +++++----- fast_llm/data/dataset/gpt/dummy.py | 10 +++++----- fast_llm/data/dataset/gpt/fim.py | 19 ++++++++++--------- fast_llm/data/dataset/gpt/indexed.py | 2 +- fast_llm/data/dataset/gpt/memmap.py | 10 +++++----- fast_llm/data/dataset/gpt/sampled.py | 16 +++++++++------- fast_llm/data/dataset/indexed.py | 13 +++++++------ fast_llm/data/dataset/monitor.py | 7 ++++--- fast_llm/data/iterator.py | 6 ++++-- fast_llm/data/preparator/config.py | 11 ++--------- fast_llm/data/preparator/gpt_memmap/config.py | 8 +++++--- .../data/preparator/gpt_memmap/prepare.py | 12 ++++++------ fast_llm/data/tokenizer.py | 14 ++++++++------ fast_llm/models/auto.py | 6 ++++-- 21 files changed, 99 insertions(+), 94 deletions(-) diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py index 1a5b7986..902faf1c 100644 --- a/fast_llm/data/auto.py +++ b/fast_llm/data/auto.py @@ -1,7 +1,8 @@ +from fast_llm.data.preparator.config import DatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.utils import Registry -dataset_preparator_registry = Registry( +dataset_preparator_registry = Registry[str, DatasetPreparatorConfig]( "DatasetPreparator", { dataset_preparator.preparator_name: dataset_preparator diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index 584451f1..e737bf7a 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -2,6 +2,7 @@ import pathlib import typing +from fast_llm.config import Configurable from fast_llm.data.data.config import DataConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig @@ -10,13 +11,13 @@ from fast_llm.engine.distributed.distributed import Distributed -class Data(abc.ABC): +class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): _distributed: "Distributed" _samples_per_phase: dict[PhaseType, int] _cache_directory: pathlib.Path | None def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None: - self._config = config + super().__init__(config) self._distributed_config = distributed_config # TODO: Improve interface @@ -25,15 +26,11 @@ def setup( distributed: "Distributed", samples_per_phase: dict[PhaseType, int], cache_directory: pathlib.Path, - ): + ) -> None: self._distributed = distributed self._samples_per_phase = samples_per_phase self._cache_directory = cache_directory - 
@property - def config(self): - return self._config - @property def distributed(self): return self._distributed diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index f1d19581..27a65b31 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -43,7 +43,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): hint=FieldHint.expert, ) - def _validate(self): + def _validate(self) -> None: if not self.datasets: logger.warning( "Using the legacy dataset definition format." " Specify it through `data.datasets` instead." diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 0b63b02a..abc403e1 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,5 +1,6 @@ import logging import pathlib +import typing import warnings import torch @@ -21,7 +22,7 @@ logger = logging.getLogger(__name__) -class GPTData(Data): +class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ A global class for all dataset needs, including loading, splitting, sampling and iteration. Currently hard-coded to a GPT dataset. @@ -29,7 +30,6 @@ class GPTData(Data): """ _datasets: dict[PhaseType, SampledDataset] - _config: GPTDataConfig _tokenizer: Tokenizer | None _distributed: Distributed _is_setup: bool = False @@ -59,10 +59,10 @@ def max_sequence_length(self) -> int: def setup( self, - distributed: Distributed, + distributed: "Distributed", samples_per_phase: dict[PhaseType, int], cache_directory: pathlib.Path, - ): + ) -> None: """ Load the datasets, and prepare or load the samplings. This may take a while and a significant amount of cpu memory. @@ -96,7 +96,7 @@ def setup( self._is_setup = True @property - def tokenizer(self): + def tokenizer(self) -> Tokenizer: assert self._is_setup return self._tokenizer @@ -108,7 +108,7 @@ def get_iterator( consumed_samples: int, num_workers: int, prefetch_factor: int | None = None, - ): + ) -> typing.Iterator[typing.Any]: assert self._is_setup Assert.incl(phase, self._datasets) Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index c16587d1..84d98146 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -12,7 +12,7 @@ class Dataset(abc.ABC): @property @abc.abstractmethod - def name(self): + def name(self) -> str: """ A name for the dataset to facilitate identification and debugging. 
""" @@ -25,11 +25,11 @@ class SampledDataset(Dataset): """ @abc.abstractmethod - def __getitem__(self, index: int): + def __getitem__(self, index: int) -> typing.Any: pass @abc.abstractmethod - def __len__(self): + def __len__(self) -> int: pass diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 9fce891c..323cc307 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -1,5 +1,6 @@ import logging import pathlib +import typing import numpy as np @@ -66,7 +67,7 @@ def __init__( safe_barrier(group, self._name) self._load_mappings(sampling_config.verbose) - def __getstate__(self): + def __getstate__(self) -> tuple[typing.Any, ...]: return ( self._datasets, self._name, @@ -76,7 +77,7 @@ def __getstate__(self): self._sample_index if self._sample_idx_filename is None else self._sample_idx_filename, ) - def __setstate__(self, state): + def __setstate__(self, state: tuple[typing.Any, ...]): ( self._datasets, self._name, @@ -92,7 +93,7 @@ def __setstate__(self, state): self._dataset_idx_filename, self._sample_idx_filename = None, None self._dataset_index, self._sample_index = dataset_index, sample_index - def _load_mappings(self, verbose): + def _load_mappings(self, verbose: bool) -> None: if verbose: log_main_rank(lambda: f" > loading blending dataset index mapping from {self._dataset_idx_filename}") self._dataset_index = np.load(self._dataset_idx_filename, mmap_mode="r") @@ -100,10 +101,10 @@ def _load_mappings(self, verbose): log_main_rank(lambda: f" > loading blending dataset index mapping from {self._sample_idx_filename}") self._sample_index = np.load(self._sample_idx_filename, mmap_mode="r") - def __len__(self): + def __len__(self) -> int: return self._num_samples - def _build_blending_indices(self, verbose: bool): + def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray]: assert _extension_available, ( "The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly." ) @@ -135,7 +136,7 @@ def _build_blending_indices(self, verbose: bool): ) return dataset_index, dataset_sample_index - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> typing.Any: return self._datasets[self._dataset_index[idx]][self._sample_index[idx]] @property diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index e842b87c..0cdb90ba 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -148,7 +148,7 @@ class BlendedDatasetConfig(SampledDatasetConfig): hint=FieldHint.core, ) - def __post_init__(self): + def __post_init__(self) -> None: Assert.eq(len(self.datasets), len(self.weights)) def build_and_sample( diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 90c39bf3..1c3db56e 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -47,7 +47,7 @@ class GPTDatasetConfig(DatasetConfig): hint=FieldHint.core, ) - def _validate(self): + def _validate(self) -> None: if self.type is not None: # Should be handled in `from_dict`, but can fail if instantiating directly. 
Assert.eq(self.type, self.type_) @@ -59,7 +59,7 @@ def _from_dict( default: dict[str, typing.Any], strict: bool = True, flat: bool = False, - ): + ) -> typing.Self: type_ = default.get("type") if type_ is None: actual_cls = cls @@ -71,7 +71,7 @@ def _from_dict( else: return actual_cls._from_dict(default, strict=strict, flat=flat) - def __init_subclass__(cls, type_: str | None = None, **kwargs): + def __init_subclass__(cls, type_: str | None = None, **kwargs) -> None: if type_ is not None: GPTDatasetConfig._registry[type_] = cls cls.type_ = type_ @@ -163,12 +163,12 @@ class LegacyDatasetSource(str, enum.Enum): random = "random" -def _validate_split(value): +def _validate_split(value: list[int]) -> list[int]: Assert.leq(len(value), 3) return value + [0] * (len(value) - 3) -def _validate_path(value): +def _validate_path(value: str | list[str]) -> list[str]: return [value] if isinstance(value, str) else value diff --git a/fast_llm/data/dataset/gpt/dummy.py b/fast_llm/data/dataset/gpt/dummy.py index d96034c4..484d811c 100644 --- a/fast_llm/data/dataset/gpt/dummy.py +++ b/fast_llm/data/dataset/gpt/dummy.py @@ -16,11 +16,11 @@ def __init__(self, name: str, sequence_length: int, vocab_size: int): def sample(self, config: GPTSamplingConfig) -> "GPTDummySampledDataset": return GPTDummySampledDataset(self, config) - def get(self): + def get(self) -> np.ndarray: return self._dummy_sample @property - def name(self): + def name(self) -> str: return self._name @@ -29,12 +29,12 @@ def __init__(self, dataset: GPTDummyDataset, config: GPTSamplingConfig): self._config = config self._dataset = dataset - def __len__(self): + def __len__(self) -> int: return self._config.num_samples - def __getitem__(self, idx): + def __getitem__(self, idx) -> np.ndarray: return self._dataset.get() @property - def name(self): + def name(self) -> str: return self._dataset.name diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 1e325113..0ed76d80 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -33,20 +33,20 @@ def __init__( self._tokenizer.vocab[self._config.split_sample] if self._config.split_sample is not None else None ) - def __len__(self): + def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> np.ndarray: sample = self._fim( self._dataset[idx], np.random.RandomState(seed=(self._sampling_config.seed + idx) % MAX_SEED) ) return sample @property - def name(self): + def name(self) -> str: return f"{self._dataset.name}_fim" - def _fim(self, sample, np_rng): + def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: # FIM # TODO: permute segments in sample_list, before concatenating. sample_len = sample.shape[0] @@ -81,8 +81,9 @@ def _fim(self, sample, np_rng): sample = np.concatenate([sample, np.full((-1 * diff), self._pad_tok_id)]) assert sample.shape[0] == sample_len + return sample - def _fim_split_and_permute_sequence(self, sequence, np_rng): + def _fim_split_and_permute_sequence(self, sequence: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: """ fragment_fim_rate: if set, apply fim with this rate to each fragment. """ @@ -116,10 +117,10 @@ def _fim_split_and_permute_sequence(self, sequence, np_rng): def _fim_permute_sequence( self, - sequence, - np_rng, - rate, - ): + sequence: np.ndarray, + np_rng: np.random.RandomState, + rate: float, + ) -> np.ndarray: """ Take in a sample (np array w/ size (0,chunklength)) and perform a FIM transformation on it. 
truncate_or_pad: if True, maintain the same sample length (if transform creates a few extra tokens, drop them). diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 20bdeb30..833c569e 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -46,7 +46,7 @@ class GPTDatasetSlice(DatasetSlice, GPTIndexedDataset): _dataset: GPTIndexedDataset - def get_document_sizes(self): + def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. return self._dataset.get_document_sizes()[self._begin : self._end] diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 0ee0aa50..cb531b56 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -22,7 +22,7 @@ class GPTMemmapDataset(GPTIndexedDataset): def __init__(self, name: str, prefix: pathlib.Path | str): self._init(name, prefix) - def _init(self, name: str, prefix: pathlib.Path | str): + def _init(self, name: str, prefix: pathlib.Path | str) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) @@ -51,10 +51,10 @@ def _init(self, name: str, prefix: pathlib.Path | str): self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - def __getstate__(self): + def __getstate__(self) -> tuple[str, pathlib.Path]: return (self._name, self._prefix) - def __setstate__(self, state): + def __setstate__(self, state: tuple[str, pathlib.Path]): self._init(*state) def __del__(self): @@ -63,7 +63,7 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - def get(self, idx, offset=0, length=None): + def get(self, idx, offset=0, length=None) -> np.ndarray: return np.frombuffer( self._bin_buffer, dtype=self._dtype, @@ -72,7 +72,7 @@ def get(self, idx, offset=0, length=None): ) @property - def name(self): + def name(self) -> str: return self._name def __len__(self) -> int: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index b28882a1..943873a9 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -1,4 +1,6 @@ import math +import pathlib +import typing import numpy as np @@ -61,7 +63,7 @@ def __init__( safe_barrier(group, self._indexed_dataset.name) self._load_mappings(sampling_config.verbose) - def _sample(self, sampling_config: GPTSamplingConfig): + def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Create a `GPTSampledDataset` with the requested parameters. """ @@ -119,7 +121,7 @@ def _sample(self, sampling_config: GPTSamplingConfig): # TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch. 
return doc_idx, sample_idx, shuffle_idx[: sampling_config.num_samples] - def __getstate__(self): + def __getstate__(self) -> tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path]: return ( self._indexed_dataset, self._doc_idx_filename, @@ -127,7 +129,7 @@ def __getstate__(self): self._shuffle_idx_filename, ) - def __setstate__(self, state): + def __setstate__(self, state: tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path]) -> None: ( self._indexed_dataset, self._doc_idx_filename, @@ -136,7 +138,7 @@ def __setstate__(self, state): ) = state self._load_mappings(False) - def _load_mappings(self, verbose): + def _load_mappings(self, verbose: bool) -> None: if verbose: log_main_rank(lambda: f" > loading doc-idx mapping from {self._doc_idx_filename}") self._doc_idx = np.load(self._doc_idx_filename, mmap_mode="r") @@ -149,12 +151,12 @@ def _load_mappings(self, verbose): if verbose: log_main_rank(lambda: f" loaded dataset with {len(self)} samples.") - def __len__(self): + def __len__(self) -> int: # -1 is due to data structure used to retrieve the index: # sample i --> [sample_idx[i], sample_idx[i+1]) return self._shuffle_idx.shape[0] - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> typing.Any: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. @@ -180,5 +182,5 @@ def __getitem__(self, idx): return sample @property - def name(self): + def name(self) -> str: return self._indexed_dataset.name diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 28269183..b9226724 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -1,4 +1,5 @@ import abc +import typing import numpy as np @@ -13,7 +14,7 @@ class IndexedDataset(SamplableDataset): """ @abc.abstractmethod - def get(self, index: int, *args, **kwargs): + def get(self, index: int, *args, **kwargs) -> typing.Any: pass @abc.abstractmethod @@ -45,7 +46,7 @@ def __init__( except Exception as e: raise AssertionError(f"Invalid document indices for dataset {name} with length {num_samples}") from e - def get(self, document: int, offset: int = 0, length: int | None = None): + def get(self, document: int, offset: int = 0, length: int | None = None) -> typing.Any: """ Get the sample (document) with the given index (in the dataset slice), optionally sub-sampled to a specific offset (starting point) and maximum length @@ -53,11 +54,11 @@ def get(self, document: int, offset: int = 0, length: int | None = None): """ return self._dataset.get(document + self._begin, offset, length) - def __len__(self): + def __len__(self) -> int: return self._end - self._begin @property - def name(self): + def name(self) -> str: return self._name @@ -74,11 +75,11 @@ def __init__( self._dataset_splits = padded_cumsum(sizes) def __len__(self) -> int: - return self._dataset_splits[-1] + return self._dataset_splits[-1].item() def get(self, index: int, *args, **kwargs): dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get(index - self._dataset_splits[dataset], *args, **kwargs) + return self._datasets[dataset].get(index - self._dataset_splits[dataset].item(), *args, **kwargs) @property def name(self) -> str: diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 7892928f..86bc080f 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -1,5 +1,6 @@ import logging import time 
+import typing from fast_llm.data.dataset.abstract import SampledDataset @@ -29,10 +30,10 @@ def __init__( self._dataset = dataset self._data_sample_warn_time_ms = data_sample_warn_time_ms - def __len__(self): + def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx): + def __getitem__(self, idx) -> typing.Any: start_time = time.perf_counter() try: sample = self._dataset[idx] @@ -48,5 +49,5 @@ def __getitem__(self, idx): raise @property - def name(self): + def name(self) -> str: return self._dataset.name diff --git a/fast_llm/data/iterator.py b/fast_llm/data/iterator.py index 8a8fdcd2..a407c025 100644 --- a/fast_llm/data/iterator.py +++ b/fast_llm/data/iterator.py @@ -1,3 +1,5 @@ +import typing + import torch.utils.data @@ -15,9 +17,9 @@ def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data self._start_idx = data_rank * micro_batch_size self._end_idx = (data_rank + 1) * micro_batch_size - def __len__(self): + def __len__(self) -> int: return self._total_samples - def __iter__(self): + def __iter__(self) -> typing.Iterator[list[int]]: for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size): yield list(range(idx + self._start_idx, idx + self._end_idx)) diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index d6701be2..edf088c0 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -1,9 +1,8 @@ import abc import typing -from fast_llm.config import config_class +from fast_llm.config import Configurable, config_class from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.utils import Assert @config_class() @@ -19,15 +18,9 @@ def _get_runnable(self) -> typing.Callable[[], None]: return dataset_preparator.run -class DatasetPreparator(abc.ABC): - _config: DatasetPreparatorConfig +class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[ConfigType], abc.ABC): config_class: typing.ClassVar[type[DatasetPreparatorConfig]] = DatasetPreparatorConfig - def __init__(self, config: DatasetPreparatorConfig) -> None: - Assert.custom(isinstance, config, self.config_class) - config.validate() - self._config = config - @abc.abstractmethod def run(self) -> None: raise NotImplementedError diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 29730448..fcae902e 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -8,6 +8,8 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator MEMMAP_DTYPES = { 1: DataType.uint8, 2: DataType.int8, @@ -86,7 +88,7 @@ class DatasetPreparatorDistributedConfig(Config): hint=FieldHint.optional, ) - def _validate(self): + def _validate(self) -> None: if self.world_size is None: self.world_size = self.default_world_size if self.rank is None: @@ -144,14 +146,14 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.feature, ) - def _validate(self): + def _validate(self) -> None: assert self.tokenizer.path is not None if self.dataset.data_type is not None: Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() @classmethod - def get_dataset_preparator_class(cls): + def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: from 
fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator return GPTMemmapDatasetPreparator diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index dd475829..d35c3c4a 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,6 +1,7 @@ import json import multiprocessing import pathlib +import typing import datasets import numpy as np @@ -15,14 +16,13 @@ from fast_llm.engine.config_utils.data_type import DataType -class GPTMemmapDatasetPreparator(DatasetPreparator): - _config: GPTMemmapDatasetPreparatorConfig - config_class = GPTMemmapDatasetPreparatorConfig +class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): + config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig _tokenizer: Tokenizer _data_type: DataType - def _tokenize_batch(self, batch): + def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._config.dataset.field] @@ -33,7 +33,7 @@ def _tokenize_batch(self, batch): "num_tokens": num_tokens, } - def _save_shard(self, args) -> dict: + def _save_shard(self, args: tuple[int, datasets.Dataset]) -> dict[str, typing.Any]: shard_idx, shard_dataset = args prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" shard_output_path = self._config.output_path / prefix @@ -62,7 +62,7 @@ def _load_dataset(self) -> datasets.Dataset: assert isinstance(dataset, datasets.Dataset) return dataset - def run(self): + def run(self) -> None: # Set transformers logging verbosity transformers.logging.set_verbosity_error() diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index f5fde98d..a9e99278 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -1,3 +1,5 @@ +import numpy as np +import torch from transformers import PreTrainedTokenizerFast from fast_llm.data.config import TokenizerConfig @@ -11,7 +13,7 @@ class Tokenizer: def __init__(self, config: TokenizerConfig): log_main_rank(f"> loading tokenizer from {config.path} ...") - self.tokenizer = PreTrainedTokenizerFast.from_pretrained( + self.tokenizer: PreTrainedTokenizerFast = PreTrainedTokenizerFast.from_pretrained( pretrained_model_name_or_path=config.path, errors="replace", max_len=None ) if self.tokenizer.eos_token_id is None: @@ -20,21 +22,21 @@ def __init__(self, config: TokenizerConfig): self._inv_vocab = {v: k for k, v in self.vocab.items()} @property - def vocab_size(self): + def vocab_size(self) -> int: return len(self.tokenizer) @property - def vocab(self): + def vocab(self) -> dict[str, int]: return self.tokenizer.vocab @property - def inv_vocab(self): + def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str): + def tokenize(self, text: str) -> list[int]: return self.tokenizer.encode(text) - def detokenize(self, token_ids): + def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) @property diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 6437e3b3..905552a8 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -1,8 +1,10 @@ +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig from 
fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig from fast_llm.utils import Registry -model_registry = Registry( +model_registry = Registry[str, FastLLMModelConfig]( "Model", { model.model_name: model @@ -13,7 +15,7 @@ }, ) -trainer_registry = Registry( +trainer_registry = Registry[str, TrainerConfig]( "Model", { trainer.get_field("model").type.model_name: trainer From c51a5d8bc726f6f628d3e59c9410322145cb4e45 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 15 Jan 2025 17:52:48 -0500 Subject: [PATCH 15/19] fix --- fast_llm/data/dataset/gpt/sampled.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index fcbd5066..5361561c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -1,15 +1,14 @@ import logging import math +import typing import numpy as np from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig -from fast_llm.data.dataset.gpt.fim.fim import Fim from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.engine.distributed.config import MAX_SEED logger = logging.getLogger(__name__) @@ -41,12 +40,6 @@ def __init__( self._sampling_config = sampling_config self._shuffle_epochs = data.config.shuffle_epochs - if data.config.fim.rate > 0: - assert data.tokenizer is not None - self._fim = Fim(data.config.fim, data.tokenizer) - else: - self._fim = None - cache_prefix = f"{self.name}_ns_{self._sampling_config.num_samples}_sl_{self._sampling_config.sequence_length}_s_{self._sampling_config.seed}" # TODO: Any way to combine into a single file? (Memmap is harder) self._doc_idx_filename = self._sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy") @@ -198,12 +191,12 @@ def _load_mappings(self, verbose=False): if verbose: log_main_rank(lambda: f" loaded dataset with {len(self)} samples.") - def __len__(self): + def __len__(self) -> int: # -1 is due to data structure used to retrieve the index: # sample i --> [sample_idx[i], sample_idx[i+1]) return self._shuffle_idx.shape[0] - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> typing.Any: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. 
@@ -254,11 +247,8 @@ def __getitem__(self, idx): sample_list, dtype=np.int64, ) - if self._fim is not None: - sample = self._fim(sample, np.random.RandomState(seed=(self._sampling_config.seed + idx) % MAX_SEED)) - return sample @property - def name(self): + def name(self) -> str: return self._indexed_dataset.name From 34030344ea5d784c2baedc5b4277f76210f5883b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 27 Jan 2025 19:49:18 -0500 Subject: [PATCH 16/19] Fix merge --- fast_llm/csrc/data.cpp | 64 ++++++++++++++++++++------------ fast_llm/data/dataset/blended.py | 16 +++----- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/fast_llm/csrc/data.cpp b/fast_llm/csrc/data.cpp index 67ae946d..b7e52924 100644 --- a/fast_llm/csrc/data.cpp +++ b/fast_llm/csrc/data.cpp @@ -104,7 +104,9 @@ void build_blending_indices(py::array_t& dataset_index, py::array build_sample_idx(const py::array_t& sizes_, const py::array_t& doc_idx_, const int32_t seq_length, - const int64_t num_samples) { + const int32_t num_epochs, + const int64_t tokens_per_epoch, + const bool verbose) { /* Sample index (sample_idx) is used for gpt2 like dataset for which the documents are flattened and the samples are built based on this 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] @@ -113,14 +115,29 @@ py::array build_sample_idx(const py::array_t& sizes_, // Consistency checks. assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); // Remove bound checks. auto sizes = sizes_.unchecked<1>(); auto doc_idx = doc_idx_.unchecked<1>(); // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + if (verbose) { + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + } + // Index into sample_idx. int64_t sample_index = 0; // Index into doc_idx. @@ -134,29 +151,30 @@ py::array build_sample_idx(const py::array_t& sizes_, while (sample_index <= num_samples) { // Start with a fresh sequence. - int32_t remaining_seq_length = seq_length + 1; - while (remaining_seq_length > 0) { + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { // Get the document length. - auto doc_id = doc_idx[doc_idx_index]; - auto doc_length = sizes[doc_id] - doc_offset; - // And add it to the current sequence. - remaining_seq_length -= doc_length; - // If we have more than a full sequence, adjust offset and set - // remaining length to zero so we return from the while loop. - // Note that -1 here is for the same reason we have -1 in - // `_num_epochs` calculations. - if (remaining_seq_length <= 0) { - doc_offset += (remaining_seq_length + doc_length - 1); - } else { - // Otherwise, start from the beginning of the next document. - ++doc_idx_index; - doc_offset = 0; - } - } - // Record the sequence. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. 
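+            // doc_offset is non-zero when the previous sample ended partway through this document,
+            // so doc_length covers only the part of the document that has not been consumed yet.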
+ remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the beginning of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; } // Method to deallocate memory. diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 783f4df3..23c45075 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -4,10 +4,8 @@ import numpy as np -from fast_llm.data.data.config import SamplingConfig from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingConfig -from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert, normalize_probabilities try: @@ -85,20 +83,16 @@ def __setstate__(self, state: tuple[typing.Any, ...]): self._dataset_idx_filename, self._sample_idx_filename = None, None self._dataset_index, self._sample_index = dataset_index, sample_index - def _load_mappings(self, verbose: bool = False) -> None: - if hasattr(self, "_dataset_index"): + def _load_mappings(self) -> None: + if hasattr(self, "_dataset_index") and hasattr(self, "_sample_index"): return - if verbose: - log_main_rank(lambda: f" > loading blending dataset index mapping from {self._dataset_idx_filename}") self._dataset_index = np.load(self._dataset_idx_filename, mmap_mode="r") - if verbose: - log_main_rank(lambda: f" > loading blending dataset index mapping from {self._sample_idx_filename}") self._sample_index = np.load(self._sample_idx_filename, mmap_mode="r") def __len__(self) -> int: return self._num_samples - def _build_blending_indices(self, verbose: bool = False) -> tuple[np.ndarray, np.ndarray]: + def _build_blending_indices(self) -> tuple[np.ndarray, np.ndarray]: assert _extension_available, ( "The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly." ) @@ -111,8 +105,7 @@ def _build_blending_indices(self, verbose: bool = False) -> tuple[np.ndarray, np self._weights, len(self._datasets), self._num_samples, - # TODO: Verbose option? 
- True, # verbose + True, # Verbose ) available_samples_per_dataset = np.array([len(dataset) for dataset in self._datasets]) sampled_per_dataset = np.bincount(dataset_index) @@ -132,6 +125,7 @@ def _build_blending_indices(self, verbose: bool = False) -> tuple[np.ndarray, np return dataset_index, dataset_sample_index def __getitem__(self, idx: int) -> typing.Any: + self._load_mappings() return self._datasets[self._dataset_index[idx]][self._sample_index[idx].item()] @property From f6c4b56e34caaf723b21ba34f528c8c9663e56d2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 27 Jan 2025 19:51:10 -0500 Subject: [PATCH 17/19] Fix merge --- examples/mistral.yaml | 2 +- fast_llm/core/distributed.py | 21 +- fast_llm/engine/checkpoint/config.py | 10 +- fast_llm/engine/checkpoint/distributed.py | 5 +- fast_llm/engine/checkpoint/safe_load.py | 5 +- fast_llm/engine/checkpoint/state_dict.py | 4 +- fast_llm/engine/config_utils/run.py | 5 - fast_llm/engine/distributed/config.py | 3 +- fast_llm/engine/distributed/distributed.py | 11 +- fast_llm/engine/multi_stage/fast_llm_model.py | 6 +- fast_llm/engine/training/config.py | 20 +- fast_llm/engine/training/trainer.py | 33 +- tests/common.py | 33 +- tests/test_dataset.py | 674 +++++++++++++++++- tests/test_simple.py | 26 + 15 files changed, 781 insertions(+), 77 deletions(-) diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 5725f0c3..d60a7802 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -13,7 +13,7 @@ batch: data: datasets: Training: - type: dummy + type: random optimizer: learning_rate: base: 1.0e-05 diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index a642c51b..e82e0801 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -7,6 +7,7 @@ """ import contextlib +import datetime import logging import typing @@ -25,12 +26,21 @@ logger = logging.getLogger(__name__) -def broadcast(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False) -> Work | None: +def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None: + if group is not None and timeout is not None: + # TODO: Only works for nccl? 
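+        # Assumed semantics of this private ProcessGroup API: the extra time is added on top of the group's
+        # default timeout for the upcoming collective(s), so callers such as `safe_barrier(..., timeout=...)`
+        # can stretch the limit for one lengthy operation (e.g. a checkpoint load) without reconfiguring the group.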
+ group._add_ephemeral_timeout(datetime.timedelta(seconds=timeout)) + + +def broadcast( + tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, timeout: float | None = None +) -> Work | None: """Same as torch.distributed.broadcast, but without the complication of going through the global rank.""" assert group is not None opts = torch.distributed.BroadcastOptions() opts.rootRank = src opts.rootTensor = 0 + add_ephemeral_timeout(group, timeout) work = group.broadcast([tensor], opts) if async_op: return work @@ -53,10 +63,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: ) -def safe_barrier(group: ProcessGroup | None, value: int | str = 1) -> None: +def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None: if group: hashed = hash(value) % 2**32 - out = allreduce_scalar(hashed, dtype=torch.int64, group=group) + out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout) if out != hashed * group.size(): raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})") @@ -66,9 +76,11 @@ def allreduce_scalar( dtype: torch.dtype = torch.float64, group: torch.distributed.ProcessGroup | None = None, op=ReduceOp.SUM, + timeout: float | None = None, ) -> float | int: if group: value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) + add_ephemeral_timeout(group, timeout) torch.distributed.all_reduce(value, op=op, group=group) return value.item() else: @@ -80,13 +92,14 @@ def broadcast_scalar( dtype: torch.dtype = torch.float64, group: torch.distributed.ProcessGroup | None = None, src: int = 0, + timeout: float | None = None, ) -> float | int: if not group: return value tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device())) if group.rank() == src: tensor.fill_(value) - broadcast(tensor, src, group) + broadcast(tensor, src, group, timeout=timeout) return tensor.item() diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index e9b6f40c..92f1165d 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -7,7 +7,7 @@ import yaml -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -179,6 +179,12 @@ class CheckpointPathConfigBase(CheckpointConfigBase): desc="Location of the checkpoint.", hint=FieldHint.core, ) + timeout: float | None = Field( + default=None, + desc="Custom timeout for lengthy operations.", + hint=FieldHint.optional, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) @config_class() @@ -248,5 +254,5 @@ def load(self, config: CheckpointLoadConfig, metadata: "CheckpointMetadata"): def get_num_shards(self, config: CheckpointStateConfigBase) -> int: return len(self._model.state_shard_names) if config.optimizer_state else 1 - def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str]: + def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str, ...]: return self._model.state_shard_names if config.optimizer_state else self._model.state_shard_names[:1] diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index f3ba4d2a..9c171bef 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ 
b/fast_llm/engine/checkpoint/distributed.py @@ -5,7 +5,7 @@ import torch import yaml -from fast_llm.core.distributed import broadcast_scalar, safe_barrier +from fast_llm.core.distributed import broadcast_scalar from fast_llm.engine.checkpoint.config import ( CheckpointFormat, CheckpointHandler, @@ -70,7 +70,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No else: log_main_rank("Checkpoint format doesn't match, using safe load") self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn) - with SafeLoad(self._model, num_shards=num_shards) as context: + with SafeLoad(self._model, num_shards=num_shards, timeout=config.timeout) as context: for rank in range(loaded_config.distributed.world_size): loaded_model = self._model.__class__( loaded_config.to_copy({("distributed", "rank"): rank}), @@ -79,7 +79,6 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No ) path = config.path / f"rank_{rank}.safetensors" log_main_rank(f"Loading from {path}") - safe_barrier(self._model.distributed.world_group, f"load {path}") # TODO: skip shards without overlap. with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f: # TODO: Use self_shard diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 2b0d0e0c..4cf9263b 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -4,6 +4,7 @@ import torch from torch.distributed import all_reduce +from fast_llm.core.distributed import add_ephemeral_timeout from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.functional.triton.pointwise import triton_fill from fast_llm.utils import Assert @@ -23,11 +24,12 @@ class SafeLoad: In case of failure, it will attempt to find out as precisely as possible where the problem comes from. """ - def __init__(self, model: "FastLLMModel", *, num_shards: int): + def __init__(self, model: "FastLLMModel", *, num_shards: int, timeout: float | None = None): self._model = model self._distributed = self._model.distributed self._num_shards = num_shards self._self_shard = self._model.state_shard[: self._num_shards] + self._timeout = timeout def __enter__(self) -> "SafeLoad": self._loaded = 0 @@ -145,6 +147,7 @@ def _check_parameters(self, errors: list[str]) -> None: counter_tensor = torch.tensor( [counter or 0 for counter in counter_per_parameter.values()], dtype=torch.int64 ).to(self._distributed.device) + add_ephemeral_timeout(self._distributed.world_group, self._timeout) all_reduce(counter_tensor, group=self._distributed.world_group) counter_per_parameter = { parameter_name: counter diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index a42b322e..5d2e913c 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -72,7 +72,7 @@ def _serialize_metadata( return metadata.to_serialized() def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: - with SafeLoad(self._model, num_shards=self.get_num_shards(config)) as context: + with SafeLoad(self._model, num_shards=self.get_num_shards(config), timeout=config.timeout) as context: # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from # `state_dict` that are ready for conversion, # and return a dict containing the converted tensors(s). 
@@ -222,7 +222,7 @@ def _merge_index(self) -> None: yaml.dump( self._index, (self._config.path / f"index_{self._distributed_config.pipeline_rank}.yaml").open("w") ) - safe_barrier(self._distributed.pipeline_group, "save state dict") + safe_barrier(self._distributed.pipeline_group, "save state dict", timeout=self._config.timeout) self._index = {} if self._distributed_config.pipeline_rank == 0: for rank in range(self._distributed_config.pipeline_parallel): diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 0e12adc2..886e341d 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -199,11 +199,6 @@ def save_logged_tensors(self, iteration: int | str) -> None: torch.save(tensor_stats, self.open_artifact(f"tensor_logs_{iteration}.pt", mode="wb")) TensorLogs.reset(self._config.tensor_logs) - def barrier(self, value: int | str = 1) -> None: - from fast_llm.core.distributed import safe_barrier - - safe_barrier(self._distributed.world_group, value) - def broadcast_int(self, value: int) -> int: import torch diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index bc961c3d..1b3e73bb 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -191,7 +191,7 @@ class DistributedConfig(Config): desc="Prioritize the pipeline groups for placement of nearby ranks over data groups.", hint=FieldHint.expert, ) - distributed_timeout: float = Field( + timeout: float = Field( default=60, desc="Timeout for distributed operations.", hint=FieldHint.optional, @@ -382,4 +382,5 @@ def _from_dict( del default["sequence_first"] if "separate_init_generators" in default and strict: del default["separate_init_generators"] + cls._handle_renamed_field(default, "distributed_timeout", "timeout") return super()._from_dict(default, strict, flat) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 8cc9288c..a612d9cf 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -44,8 +44,6 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self.device = torch.device(self._config.local_rank) torch.cuda.set_device(self.device) - timeout = datetime.timedelta(seconds=self._config.distributed_timeout) - # We bypass `torch.distributed.init_process_group` which makes things way more complicated for no reason. # TODO: Allow other init methods? 
@@ -53,7 +51,12 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): if self._config.world_size > 1: self._config.log_first_rank("Initializing TCP store.") self.store, _, _ = next( - torch.distributed.rendezvous("env://", self._config.rank, self._config.world_size, timeout=timeout) + torch.distributed.rendezvous( + "env://", + self._config.rank, + self._config.world_size, + timeout=datetime.timedelta(seconds=self._config.timeout), + ) ) self._process_groups = {} for name, distributed_dim in self._config.distributed_dims.items(): @@ -139,7 +142,7 @@ def add_group(self, distributed_dim: DistributedDim) -> ProcessGroup | None: torch.distributed.PrefixStore(prefix + "/", self.store), distributed_dim.rank, distributed_dim.size, - datetime.timedelta(seconds=self._config.distributed_timeout), + datetime.timedelta(seconds=self._config.timeout), ) self._process_groups[distributed_dim.name] = group distributed_dim.setup(group) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index c9f7a6dc..b268ec29 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -85,13 +85,15 @@ def from_pretrained( model.initialize_weights() return model - def initialize_weights(self) -> None: + def initialize_weights(self, timeout: float | None = None) -> None: assert self._is_setup for stage in self._stages: stage.initialize_weights() for name, tied_parameter in self._tied_parameters.items(): if tied_parameter.group is not None: - broadcast(self._stages[tied_parameter.main_stage].weight_shard, 0, tied_parameter.group) + broadcast( + self._stages[tied_parameter.main_stage].weight_shard, 0, tied_parameter.group, timeout=timeout + ) self._finalize_load(reset_optimizer=True) def _finalize_load(self, reset_optimizer: bool = True) -> None: diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index b977323e..30add2f4 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -184,7 +184,7 @@ class TrainingCheckpointBaseConfig(IntervalConfig): def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: pass - def get_save_config(self, path: pathlib.Path) -> CheckpointSaveConfig: + def get_save_config(self, path: pathlib.Path, timeout: float | None) -> CheckpointSaveConfig: raise NotImplementedError() def to_delete(self, iterations: list[int]) -> list[int]: @@ -216,20 +216,22 @@ def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path new_path = experiment_directory / "checkpoint" return old_path if old_path.is_dir() and not new_path.is_dir() else new_path - def get_save_config(self, path: pathlib.Path) -> CheckpointSaveConfig: + def get_save_config(self, path: pathlib.Path, timeout: float | None) -> CheckpointSaveConfig: return CheckpointSaveConfig( path=path, format=DistributedCheckpointFormat, model_weights=True, optimizer_state=True, + timeout=timeout, ) - def get_load_config(self, path: pathlib.Path) -> CheckpointLoadConfig: + def get_load_config(self, path: pathlib.Path, timeout: float | None) -> CheckpointLoadConfig: return CheckpointLoadConfig( path=path, format=DistributedCheckpointFormat, model_weights=True, optimizer_state=True, + timeout=timeout, ) @@ -247,8 +249,8 @@ class TrainingExportConfig(TrainingCheckpointBaseConfig, CheckpointStateSaveConf def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: return experiment_directory / "export" / 
self.format.name - def get_save_config(self, path: pathlib.Path) -> CheckpointSaveConfig: - return CheckpointSaveConfig.from_dict(self, {"path": path}, strict=False) + def get_save_config(self, path: pathlib.Path, timeout: float | None) -> CheckpointSaveConfig: + return CheckpointSaveConfig.from_dict(self, {"path": path, "timeout": timeout}, strict=False) @config_class() @@ -305,6 +307,14 @@ class TrainingConfig(Config): valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + timeout: float | None = Field( + default=3600, + desc="Timeout for lengthy operations such as checkpoint saving and loading," + " and dataset preparation and sampling.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + def _validate(self) -> None: super()._validate() self.shutdown.assert_sub_interval(self.checkpoint) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 4d0e68c1..d43abe56 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -82,7 +82,6 @@ def __init__(self, config: TrainerConfig): def setup(self, distributed: Distributed, run: Run) -> None: assert distributed.config is self._config.model.distributed assert not self._is_setup - self._is_setup = True self._distributed = distributed self._run = run self._wandb = Wandb(self._config.training.wandb, self._run, self._config) @@ -110,7 +109,9 @@ def setup(self, distributed: Distributed, run: Run) -> None: distributed, self._samples_per_split, None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", + timeout=self._config.training.timeout, ) + self._is_setup = True @abc.abstractmethod def _get_data(self) -> Data: @@ -314,7 +315,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: if self._config.training.export.enabled(None if done else self._completed_steps): self._save_checkpoint(self._config.training.export, metrics) - + # The profiler calls the trace_fn at the end and this could lead to + profiler.step() return done, metrics def _evaluate( @@ -336,7 +338,10 @@ def _evaluate( total_losses[name] += value self._run.save_logged_tensors(f"{phase}_{self._completed_steps}_{iter_}") - safe_barrier(self._distributed.world_group, f"{phase.value} end") + safe_barrier( + self._distributed.world_group, + f"{phase.value} end", + ) end_time = time.perf_counter() time_per_iteration = (end_time - begin_time) / num_iters model_tflops, hardware_tflops = self.get_tflops(phase, time_per_iteration) @@ -406,7 +411,7 @@ def _save_checkpoint( logger.info(f"Saving {config.save_name} at iteration {self._completed_steps}") checkpoint_directory.mkdir(exist_ok=False, parents=True) # Barrier to ensure the directory is created correctly (and didn't exist before). - self._run.barrier(f"{config.save_name} {self._completed_steps} enter") + safe_barrier(self._distributed.world_group, f"{config.save_name} {self._completed_steps} enter") metadata = { "optimizer": self._optimizer.save(), @@ -414,10 +419,16 @@ def _save_checkpoint( } if metrics is not None: metadata["metrics"] = {key.value: value for key, value in metrics.items()} - self._multi_stage.save_checkpoint(config.get_save_config(checkpoint_directory), metadata) + self._multi_stage.save_checkpoint( + config.get_save_config(checkpoint_directory, timeout=self._config.training.timeout), metadata + ) # Barrier to ensure everyone is done. 
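+        # Checkpoint writing can take much longer than a regular step, so the exit barrier below
+        # uses the dedicated training timeout rather than the distributed default.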
- self._run.barrier(f"{config.save_name} {self._completed_steps} exit") + safe_barrier( + self._distributed.world_group, + f"{config.save_name} {self._completed_steps} exit", + timeout=self._config.training.timeout, + ) # Mark the checkpoint as complete. if self._run.is_main_rank: (checkpoint_directory / "ok").open("w") @@ -439,7 +450,9 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> checkpoint_directory = config.get_save_directory(self._run.experiment_directory) / str(iteration) Assert.custom(pathlib.Path.is_file, checkpoint_directory / "ok") - metadata = self._multi_stage.load_checkpoint(config.get_load_config(checkpoint_directory)) + metadata = self._multi_stage.load_checkpoint( + config.get_load_config(checkpoint_directory, timeout=self._config.training.timeout) + ) self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. @@ -447,7 +460,11 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> else: self._completed_steps = metadata["completed_steps"] # TODO v0.3: Move barrier, ok file to FastLLMModel - self._run.barrier(f"load {config.save_name} {iteration} exit") + safe_barrier( + self._distributed.world_group, + f"load {config.save_name} {iteration} exit", + timeout=self._config.training.timeout, + ) def _get_last_checkpoint(self) -> int | None: if self._run.experiment_directory is None: diff --git a/tests/common.py b/tests/common.py index 8da60bb9..69048f8c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -35,11 +35,14 @@ TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" -DATASET_PREFIX = TEST_RESULTS_PATH / "dataset" / "common" +DATASET_CACHE = TEST_RESULTS_PATH / "dataset" +DATASET_PREFIX = DATASET_CACHE / "common" +DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" TEST_VOCAB_SIZE = 8192 # Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" +TEST_DATASET_TOKENS = 1000000 CONFIG_BASE_FAST_LLM = [ "training.logs.interval=1", @@ -209,7 +212,11 @@ def get_test_dataset( - prefix=DATASET_PREFIX, seed=1234, num_tokens=1000000, characters=TEST_CHARACTERS, vocab_size=TEST_VOCAB_SIZE + prefix: pathlib.Path = DATASET_PREFIX, + seed: int = 1234, + num_tokens: int = TEST_DATASET_TOKENS, + characters: str = TEST_CHARACTERS, + vocab_size: int = TEST_VOCAB_SIZE, ): if not TOKENIZER_FILE.is_file(): import transformers @@ -228,6 +235,28 @@ def get_test_dataset( GPTMemmapDataset.write_dataset(prefix, documents) +def get_test_concatenated_memmap_dataset( + path: pathlib.Path, + num_files: int, + seed: int = 1234, + num_tokens: int = TEST_DATASET_TOKENS, + characters: str = TEST_CHARACTERS, + vocab_size: int = TEST_VOCAB_SIZE, + seed_shift: int = 55, +): + index_file = path / "index.txt" + if not index_file.is_file(): + for i in range(num_files): + get_test_dataset( + prefix=path / f"dataset_{i}", + seed=seed + i * seed_shift, + num_tokens=num_tokens, + characters=characters, + vocab_size=vocab_size, + ) + index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)]) + + def run_test_script( name: str, script: list[str], diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 969db0e0..394553e5 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,67 +1,667 @@ import pathlib +import typing import numpy as np +import pytest +from fast_llm.config import NoAutoValidate +from 
fast_llm.data.config import TokenizerConfig from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData -from fast_llm.data.dataset.gpt.config import GPTDatasetConfig +from fast_llm.data.dataset.gpt.config import ( + GPTBlendedDatasetConfig, + GPTConcatenatedDatasetConfig, + GPTConcatenatedMemmapConfig, + GPTDatasetSliceConfig, + GPTFimSampledDatasetConfig, + GPTMemmapDatasetConfig, + GPTRandomDatasetConfig, + GPTSampledDatasetConfig, + GPTSamplingConfig, +) +from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.schedule.config import BatchConfig from fast_llm.utils import Assert -from tests.common import DATASET_PREFIX, TEST_RESULTS_PATH, TEST_VOCAB_SIZE, get_test_dataset +from tests.common import ( + DATASET_CACHE, + DATASET_PREFIX, + DATASET_SAMPLING_CACHE, + TEST_VOCAB_SIZE, + TOKENIZER_PATH, + get_test_concatenated_memmap_dataset, + get_test_dataset, +) -def get_test_data( - config: dict, - samples_per_phase: dict[PhaseType, int], +def get_sampling_config( + num_samples: int, + *, + seed: int = 54983, cache_directory: pathlib.Path | None = None, + distributed: Distributed = Distributed(DistributedConfig(), use_cpu=True), + phase=PhaseType.training, sequence_length: int = 512, -): - # TODO: Update - distributed_config = DistributedConfig() - distributed = Distributed(distributed_config, use_cpu=True) - data = GPTData(GPTDataConfig.from_dict(config), distributed_config, TEST_VOCAB_SIZE, sequence_length) - data.setup(distributed, PhaseSplits[int](samples_per_phase), cache_directory) - return data + vocab_size=TEST_VOCAB_SIZE, + tokenizer: Tokenizer | None = None, +) -> GPTSamplingConfig: + # Config with convenient defaults. 
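+    # Tests typically override only `num_samples`, `sequence_length` and `tokenizer`; the other
+    # defaults match the deterministic test datasets generated in `tests/common.py`.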
+ return GPTSamplingConfig( + num_samples=num_samples, + seed=seed, + cache_directory=cache_directory, + distributed=distributed, + phase=phase, + sequence_length=sequence_length, + vocab_size=vocab_size, + tokenizer=tokenizer, + ) -DATASET_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" +def _get_dataset_config[T: GPTSampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T: + dataset_config = GPTSampledDatasetConfig.from_dict(config) + Assert.custom(isinstance, dataset_config, cls) + return typing.cast(cls, dataset_config) -def get_dataset( +def get_test_data_and_samples( config: dict, samples_per_phase: dict[PhaseType, int], + seed: int = 54983, cache_directory: pathlib.Path | None = None, sequence_length: int = 512, + vocab_size=TEST_VOCAB_SIZE, ): - # TODO: Update - return get_test_data({"dataset": config}, samples_per_phase, cache_directory, sequence_length)._datasets + distributed_config = DistributedConfig(seed=seed) + distributed = Distributed(distributed_config, use_cpu=True) + data = GPTData(GPTDataConfig.from_dict(config), distributed_config, vocab_size, sequence_length) + data.setup(distributed, samples_per_phase, cache_directory) + with NoAutoValidate(): + batch_config = BatchConfig(batch_size=1, sequence_length=sequence_length) + batch_config.setup(distributed_config) + batch_config.validate() + samples = { + phase: [batch[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)] + for phase, samples in samples_per_phase.items() + } + return data, samples + + +_DATASET_PREFIX_MIX_1 = DATASET_PREFIX.with_name("blended_mix_1") +_DATASET_PREFIX_MIX_CONCATENATED_MEMMAP = DATASET_CACHE / "concatenated_memmap" + + +def _get_test_dataset_mix_1(): + return get_test_dataset(prefix=_DATASET_PREFIX_MIX_1, seed=2345) + + +def _get_test_dataset_concatenated_memmap(): + return get_test_concatenated_memmap_dataset(_DATASET_PREFIX_MIX_CONCATENATED_MEMMAP, 4) + + +RANDOM_DATASET_EXPECTED_SAMPLES = [ + [3954, 4105, 6766, 859, 5494, 1675, 1303, 6913], + [1654, 5701, 32, 1662, 7053, 3487, 1861, 1502], + [5409, 6240, 5504, 7458, 7667, 3955, 3151, 3912], + [5640, 6131, 7750, 2699, 1349, 2585, 7113, 6981], +] + + +def test_gpt_random_dataset(): + # Make sure the random dataset works and check for unintended changes in behavior. + sampled = _get_dataset_config({"type": "random"}, GPTRandomDatasetConfig).build_and_sample( + get_sampling_config(4, sequence_length=7) + ) + Assert.eq(len(sampled), 4) + Assert.all_equal( + np.stack([sampled[i] for i in range(4)]), + np.array(RANDOM_DATASET_EXPECTED_SAMPLES), + ) + + +def test_gpt_random_data(): + _, samples = get_test_data_and_samples( + { + "datasets": { + "Training": { + "type": "random", + } + } + }, + {PhaseType.training: 4}, + sequence_length=7, + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(RANDOM_DATASET_EXPECTED_SAMPLES), + ) + + +def test_gpt_random_data_legacy(): + _, samples = get_test_data_and_samples({"format": "random"}, {PhaseType.training: 4}, sequence_length=7) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(RANDOM_DATASET_EXPECTED_SAMPLES), + ) + + +# Most documents are too long to write here, we test a few known short ones. 
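+# The expected lengths, token counts and samples below were captured from the deterministic test
+# dataset; they serve as regression checks against unintended changes in data handling.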
+MEMMAP_DATASET_EXPECTED_LENGTH = 6153 +MEMMAP_DATASET_EXPECTED_TOKENS = 508327 +MEMMAP_DATASET_EXPECTED_SAMPLES = { + 9: [], + 10: [80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], + 13: [78, 727, 74, 317, 1358, 89], + 15: [78], +} + + +@pytest.mark.parametrize("cache_directory", (None, pathlib.Path(DATASET_SAMPLING_CACHE) / "test_memmap")) +def test_gpt_memmap(cache_directory): + # Make sure the memmap dataset works and check for unintended changes in behavior. + get_test_dataset() + dataset = _get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build() + Assert.eq(len(dataset), MEMMAP_DATASET_EXPECTED_LENGTH) + sizes = dataset.get_document_sizes() + Assert.eq(sizes.sum(), MEMMAP_DATASET_EXPECTED_TOKENS) + Assert.all_equal([len(dataset.get(i)) for i in range(100)], sizes[:100]) + for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): + Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16)) + + +GPT_SAMPLED_EXPECTED_SAMPLES = [ + [1725, 74, 207, 1635, 4440, 2774], + [359, 489, 4266, 2052, 5351, 80], + [374, 7534, 87, 1073, 79, 480], + [8008, 498, 71, 727, 80, 315], + [2210, 8179, 73, 2582, 897, 1178], + [409, 5091, 328, 1378, 5483, 88], + [83, 4457, 3316, 333, 489, 317], + [330, 155, 2449, 1136, 1106, 5370], +] + + +def test_gpt_sampled(): + # Make sure the memmap dataset works and check for unintended changes in behavior. + get_test_dataset() + sampled = _get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build_and_sample( + get_sampling_config(8, sequence_length=5) + ) + Assert.eq(len(sampled), 8) + Assert.all_equal( + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_SAMPLED_EXPECTED_SAMPLES), + ) + + +def test_gpt_sampled_data(): + get_test_dataset() + _, samples = get_test_data_and_samples( + { + "datasets": { + "Training": { + "type": "memmap", + "path": DATASET_PREFIX, + } + } + }, + {PhaseType.training: 8}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_SAMPLED_EXPECTED_SAMPLES), + ) + + +def test_gpt_sampled_data_legacy(): + _, samples = get_test_data_and_samples( + {"format": "list", "path": [str(DATASET_PREFIX)], "split": [1, 0, 0]}, + {PhaseType.training: 8}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_SAMPLED_EXPECTED_SAMPLES), + ) + + +GPT_CONCATENATED_EXPECTED_SAMPLES = [ + [243, 498, 7172, 777, 306, 74], + [821, 6042, 89, 977, 4797, 499], + [387, 74, 330, 328, 1858, 484], + [7722, 3069, 819, 4266, 304, 80], + [80, 634, 4913, 373, 207, 1046], + [72, 65, 5570, 73, 2210, 5514], + [7983, 977, 4147, 4739, 890, 386], + [5375, 275, 69, 771, 593, 8171], +] + + +def test_gpt_concatenate(): + # Make sure the dataset concatenation works and check for unintended changes in behavior. 
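+    # The same memmap prefix is repeated three times, so lengths, token counts and per-document
+    # contents must repeat with period MEMMAP_DATASET_EXPECTED_LENGTH.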
+ get_test_dataset() + dataset = _get_dataset_config( + {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]}, + GPTConcatenatedDatasetConfig, + ).build() + Assert.eq(len(dataset), 3 * MEMMAP_DATASET_EXPECTED_LENGTH) + sizes = dataset.get_document_sizes() + Assert.eq(sizes.sum(), 3 * MEMMAP_DATASET_EXPECTED_TOKENS) + for i in range(3): + begin = i * MEMMAP_DATASET_EXPECTED_LENGTH + Assert.all_equal([len(dataset.get(begin + i)) for i in range(100)], sizes[begin : begin + 100]) + for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): + Assert.all_equal(dataset.get(begin + i), np.array(sample, dtype=np.uint16)) + sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) + Assert.eq(len(sampled), 8) + Assert.all_equal( + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_CONCATENATED_EXPECTED_SAMPLES), + ) + + +def test_gpt_concatenate_data(): + _, samples = get_test_data_and_samples( + { + "datasets": { + "Training": { + "type": "concatenated", + "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)], + } + } + }, + {PhaseType.training: 8}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_CONCATENATED_EXPECTED_SAMPLES), + ) + + +GPT_SLICE_EXPECTED_TRAINING_SAMPLES = [ + [2625, 76, 2625, 2639, 74, 243], + [207, 481, 5546, 74, 414, 498], + [74, 333, 1963, 310, 5337, 3628], + [79, 2361, 80, 2012, 84, 480], +] + +GPT_SLICE_EXPECTED_VALIDATION_SAMPLES = [ + [2352, 3687, 2311, 4900, 542, 3732], + [2551, 5283, 900, 3140, 328, 68], + [7979, 2283, 329, 727, 2740, 2818], + [4117, 8056, 79, 1798, 243, 498], + [243, 542, 387, 6476, 6686, 785], + [95, 6641, 207, 279, 2304, 602], + [89, 4446, 947, 293, 947, 1544], + [243, 3712, 86, 476, 80, 2547], +] + + +def test_gpt_slice(): + # Make sure dataset splitting works and check for unintended changes in behavior. 
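+    # `begin` and `end` are fractions of the document count: with 6153 documents, 0.0015-0.003
+    # selects documents 9 through 17, hence the expected length of 9.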
+ get_test_dataset() + # samples[9:18] + dataset = _get_dataset_config( + {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003}, + GPTDatasetSliceConfig, + ).build() + Assert.eq(len(dataset), 9) + sizes = dataset.get_document_sizes() + Assert.all_equal([len(dataset.get(i)) for i in range(9)], sizes[:9]) + for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): + Assert.all_equal(dataset.get(i - 9), np.array(sample, dtype=np.uint16)) + sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) + Assert.eq(len(sampled), 8) + Assert.all_equal( + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), + ) -def test_dummy_dataset(): - # TODO: Update - dataset = GPTDatasetConfig.from_dict({"type": "dummy"}) +def test_gpt_slice_data(): + _, samples = get_test_data_and_samples( + { + "datasets": { + "Training": { + "type": "slice", + "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "begin": 0, + "end": 0.0015, + }, + "Validation": { + "type": "slice", + "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "begin": 0.0015, + "end": 0.003, + }, + "Test": { + "type": "slice", + "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "begin": 0.003, + "end": 1, + }, + } + }, + {PhaseType.training: 4, PhaseType.validation: 8, PhaseType.test: 5}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.validation]), + np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES), + ) + + +def test_gpt_slice_data_legacy(): + get_test_dataset() + _, samples = get_test_data_and_samples( + {"format": "list", "path": [str(DATASET_PREFIX)], "split": [0.0015, 0.0015, 0.997]}, + {PhaseType.training: 4, PhaseType.validation: 8, PhaseType.test: 5}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.validation]), + np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES), + ) + + +COMPOSED_DATASET_EXPECTED_LENGTH = 24806 +COMPOSED_DATASET_EXPECTED_TOKENS = 2033639 + +COMPOSED_DATASET_EXPECTED_SAMPLES = { + **MEMMAP_DATASET_EXPECTED_SAMPLES, + 6930: [65, 2327], + 11962: [7078, 2713, 1431], + 15958: [207], + 19362: [69], + 24098: [555, 668, 70], +} + + +GPT_COMPOSED_EXPECTED_SAMPLES = [ + [1411, 819, 6791, 7022, 285, 249], + [329, 328, 512, 1985, 3069, 7838], + [5158, 1023, 8171, 798, 1431, 313], + [1073, 3917, 275, 480, 74, 1752], + [207, 317, 269, 6662, 4357, 498], + [74, 310, 277, 7091, 668, 367], + [7828, 480, 89, 116, 4604, 69], + [79, 6042, 577, 225, 207, 207], +] + + +def test_gpt_compose(): + # Make sure dataset splitting works and check for unintended changes in behavior. 
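+    # The `concatenated_memmap` dataset is built from the index file written by
+    # `get_test_concatenated_memmap_dataset` (see `tests/common.py`).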
+ _get_test_dataset_concatenated_memmap() + # samples[9:18] + dataset = _get_dataset_config( + {"type": "concatenated_memmap", "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP}, + GPTConcatenatedMemmapConfig, + ).build() + Assert.eq(len(dataset), COMPOSED_DATASET_EXPECTED_LENGTH) + sizes = dataset.get_document_sizes() + Assert.eq(sizes.sum(), COMPOSED_DATASET_EXPECTED_TOKENS) + Assert.all_equal([len(dataset.get(i)) for i in range(0, len(dataset), 20)], sizes[::20]) + for i, sample in COMPOSED_DATASET_EXPECTED_SAMPLES.items(): + Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16)) + sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) + Assert.eq(len(sampled), 8) + print(np.stack([sampled[i] for i in range(8)]).tolist()) + Assert.all_equal( + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_COMPOSED_EXPECTED_SAMPLES), + ) + + +def test_gpt_composed_data(): + _get_test_dataset_concatenated_memmap() + _, samples = get_test_data_and_samples( + { + "datasets": { + "Training": { + "type": "composed", + "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP, + } + } + }, + {PhaseType.training: 8}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_COMPOSED_EXPECTED_SAMPLES), + ) + + +GPT_BLENDED_EXPECTED_SAMPLES = [ + [1725, 74, 207, 1635, 4440, 2774], + [2066, 207, 6436, 2360, 2210, 6633], + [359, 489, 4266, 2052, 5351, 80], + [374, 7534, 87, 1073, 79, 480], + [8008, 498, 71, 727, 80, 315], + [555, 3042, 83, 207, 498, 3373], + [2210, 8179, 73, 2582, 897, 1178], + [409, 5091, 328, 1378, 5483, 88], +] + + +def test_gpt_blended(): + # Make sure dataset blending works and check for unintended changes in behavior. + get_test_dataset() + _get_test_dataset_mix_1() + sampled = _get_dataset_config( + { + "type": "blended", + "datasets": [ + {"type": "memmap", "path": DATASET_PREFIX}, + {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, + ], + "weights": [0.75, 0.25], + }, + GPTBlendedDatasetConfig, + ).build_and_sample(get_sampling_config(8, sequence_length=5)) + Assert.eq(len(sampled), 8) + Assert.all_equal( + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_BLENDED_EXPECTED_SAMPLES), + ) + + +def test_gpt_blended_data(): + get_test_dataset() + _get_test_dataset_mix_1() + _, samples = get_test_data_and_samples( + { + "datasets": { + "Training": { + "type": "blended", + "datasets": [ + {"type": "memmap", "path": DATASET_PREFIX}, + {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, + ], + "weights": [0.75, 0.25], + } + } + }, + {PhaseType.training: 8}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_BLENDED_EXPECTED_SAMPLES), + ) + + +GPT_BLENDED_LEGACY_EXPECTED_SAMPLES = [ + [1725, 74, 207, 1635, 4440, 2774], + [328, 80, 263, 890, 1797, 88], + [359, 489, 4266, 2052, 5351, 80], + [374, 7534, 87, 1073, 79, 480], + [8008, 498, 71, 727, 80, 315], + [1852, 71, 776, 7878, 7390, 80], + [2210, 8179, 73, 2582, 897, 1178], + [409, 5091, 328, 1378, 5483, 88], +] - datasets = get_test_datasets( - {"type": "dummy"}, - {PhaseType.training: 7, PhaseType.test: 4}, + +def test_gpt_blended_data_legacy(): + get_test_dataset() + _get_test_dataset_mix_1() + _, samples = get_test_data_and_samples( + { + "format": "list", + "path": ["0.75", str(DATASET_PREFIX), "0.25", str(_DATASET_PREFIX_MIX_1)], + "split": [1, 0, 0], + }, + {PhaseType.training: 8}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_BLENDED_LEGACY_EXPECTED_SAMPLES), ) - 
Assert.eq(datasets.keys(), {PhaseType.training, PhaseType.test}) - train = datasets[PhaseType.training] - Assert.eq(len(train), 7) - assert all(np.all(train[i] == train._dataset._dummy_sample) for i in range(7)) - test = datasets[PhaseType.test] - Assert.eq(len(test), 4) - assert all(np.all(test[i] == test._dataset._dummy_sample) for i in range(4)) -def test_memmap_dataset(): - # TODO: Update +GPT_BLENDED_MIXED_EXPECTED_SAMPLES = [ + [1725, 74, 207, 1635, 4440, 2774], + [916, 6683, 7685, 1277, 5106, 378], + [359, 489, 4266, 2052, 5351, 80], + [3359, 6803, 780, 4561, 669, 7878], + [374, 7534, 87, 1073, 79, 480], + [8008, 498, 71, 727, 80, 315], + [6920, 2218, 2921, 3963, 7606, 6904], + [2210, 8179, 73, 2582, 897, 1178], +] + + +def test_gpt_blended_mixed(): + # Make sure dataset blending works and check for unintended changes in behavior. get_test_dataset() - dataset = get_test_datasets( - {"type": "memmap", "path": DATASET_PREFIX}, - {PhaseType.training: 1}, + sampled = _get_dataset_config( + { + "type": "blended", + "datasets": [ + {"type": "memmap", "path": DATASET_PREFIX}, + {"type": "random"}, + ], + "weights": [0.6, 0.4], + }, + GPTBlendedDatasetConfig, + ).build_and_sample(get_sampling_config(8, sequence_length=5)) + Assert.eq(len(sampled), 8) + Assert.all_equal( + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES), + ) + + +def test_gpt_blended_mixed_data(): + _, samples = get_test_data_and_samples( + { + "datasets": { + "Training": { + "type": "blended", + "datasets": [{"type": "memmap", "path": DATASET_PREFIX}, {"type": "random"}], + "weights": [0.6, 0.4], + } + } + }, + {PhaseType.training: 8}, sequence_length=5, - )[PhaseType.training] - Assert.eq(len(dataset), 5) - raise AssertionError() + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES), + ) + + +GPT_FIM_EXPECTED_SAMPLES = [ + [1725, 74, 207, 1635, 4440, 2774], + [359, 489, 4266, 2052, 5351, 80], + [86, 89, 22255, 1073, 79, 480], + [8008, 498, 71, 727, 80, 315], + [2210, 8179, 73, 2582, 897, 1178], + [86, 89, 88, 87, 409, 70], + [86, 83, 744, 89, 64, 333], + [86, 89, 1461, 87, 330, 7876], +] + + +def test_gpt_fim(): + # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. + get_test_dataset() + # The test tokenizer doesn't have fim tokens, so we work around it. + sampling_config = get_sampling_config( + 8, sequence_length=5, tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH})) + ) + sampled = _get_dataset_config( + { + "type": "fim", + "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "rate": 0.5, + "prefix_token": "w", + "middle_token": "x", + "pad_token": "y", + "suffix_token": "z", + }, + GPTFimSampledDatasetConfig, + ).build_and_sample(sampling_config) + Assert.eq(len(sampled), 8) + # TODO: Does this output make sense? 
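+    # With `rate=0.5`, roughly half of the samples are rewritten in fill-in-the-middle form using
+    # the single-character stand-ins (w, x, y, z) above as the FIM special tokens.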
+ Assert.all_equal( + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_FIM_EXPECTED_SAMPLES), + ) + + +def test_gpt_fim_data(): + _, samples = get_test_data_and_samples( + { + "datasets": { + "Training": { + "type": "fim", + "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "rate": 0.5, + "prefix_token": "w", + "middle_token": "x", + "pad_token": "y", + "suffix_token": "z", + } + }, + "tokenizer": {"path": TOKENIZER_PATH}, + }, + {PhaseType.training: 8}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_FIM_EXPECTED_SAMPLES), + ) + + +def test_gpt_fim_data_legacy(): + _, samples = get_test_data_and_samples( + { + "format": "list", + "path": [str(DATASET_PREFIX)], + "fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"}, + "tokenizer": {"path": TOKENIZER_PATH}, + "split": [1, 0, 0], + }, + {PhaseType.training: 8}, + sequence_length=5, + ) + Assert.all_equal( + np.stack(samples[PhaseType.training]), + np.array(GPT_FIM_EXPECTED_SAMPLES), + ) diff --git a/tests/test_simple.py b/tests/test_simple.py index aa5e842d..85727bca 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -32,6 +32,32 @@ def test_model_dp2(): run_test_script(f"test_{TEST_MODEL}_dp2", CONFIG_COMMON, num_gpus=2, compare=f"test_{TEST_MODEL}") +@pytest.mark.slow +def test_model_dp2_timeout(): + # Test sampling timeout + # TODO: Find a better way to test this + run_test_script( + f"test_{TEST_MODEL}_dp2", + CONFIG_COMMON + + [ + # Use a short timeout + "model.distributed.timeout=4", + # Make a dataset that would timeout under the distributed timeout + 'data.datasets.Training={"type":"test_slow"}', + "data.datasets.Training.type=test_slow", + "data.datasets.Training.sleep=6", + # Use a bigger timeout for the dataset. + "training.timeout=10", + # Remove testing clutter. 
+ f"model.multi_stage.debug_param_init=0", + f"model.multi_stage.debug_layer_outputs=0", + f"model.multi_stage.debug_layer_gradients=0", + f"model.multi_stage.debug_all_param_gradients=0", + ], + num_gpus=2, + ) + + @pytest.mark.slow @pytest.mark.depends(on=["test_model"]) def test_model_tp2(): From 158a6f752890e86d5f83f4277392a4bbb5360407 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 27 Jan 2025 19:53:29 -0500 Subject: [PATCH 18/19] Fix merge --- fast_llm/data/data/abstract.py | 1 + fast_llm/data/data/gpt/data.py | 14 +- fast_llm/data/dataset/gpt/config.py | 218 ++++++++++++------ fast_llm/data/dataset/gpt/fim.py | 18 +- fast_llm/data/dataset/gpt/indexed.py | 20 +- fast_llm/data/dataset/gpt/memmap.py | 10 +- fast_llm/data/dataset/indexed.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 4 + fast_llm/profile.py | 46 ++-- fast_llm/utils.py | 28 ++- 10 files changed, 231 insertions(+), 130 deletions(-) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index e737bf7a..8e691b31 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -26,6 +26,7 @@ def setup( distributed: "Distributed", samples_per_phase: dict[PhaseType, int], cache_directory: pathlib.Path, + timeout: float | None = None, ) -> None: self._distributed = distributed self._samples_per_phase = samples_per_phase diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index abc403e1..14ca9689 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -6,6 +6,7 @@ import torch import torch.utils.data +from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset @@ -31,7 +32,6 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): _datasets: dict[PhaseType, SampledDataset] _tokenizer: Tokenizer | None - _distributed: Distributed _is_setup: bool = False def __init__( @@ -49,19 +49,12 @@ def __init__( self._vocab_size = vocab_size self._max_sequence_length = max_sequence_length - @property - def vocab_size(self) -> int: - return self._vocab_size - - @property - def max_sequence_length(self) -> int: - return self._max_sequence_length - def setup( self, distributed: "Distributed", samples_per_phase: dict[PhaseType, int], cache_directory: pathlib.Path, + timeout: float | None = None, ) -> None: """ Load the datasets, and prepare or load the samplings. 
@@ -84,7 +77,6 @@ def setup( num_samples=samples_per_phase[phase], seed=self._distributed_config.seed, cache_directory=self._cache_directory, - verbose=True, distributed=distributed, phase=phase, sequence_length=self._max_sequence_length, @@ -93,6 +85,8 @@ def setup( ) dataset = self._config.datasets[phase].build_and_sample(sampling_config) self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) + + safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True @property diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 1c3db56e..da8eb3ca 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -2,14 +2,15 @@ import enum import json import pathlib +import time import typing +import warnings from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import ( BlendedDatasetConfig, ConcatenatedDatasetConfig, - DatasetConfig, DatasetSliceConfig, IndexedDatasetConfig, SamplableDatasetConfig, @@ -20,9 +21,9 @@ from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum if typing.TYPE_CHECKING: - from fast_llm.data.dataset.gpt.dummy import GPTDummySampledDataset from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset + from fast_llm.data.dataset.gpt.random import GPTRandomDataset, GPTRandomSampledDataset from fast_llm.data.tokenizer import Tokenizer @@ -35,12 +36,13 @@ class GPTSamplingConfig(SamplingConfig): @config_class() -class GPTDatasetConfig(DatasetConfig): +class GPTSampledDatasetConfig(SampledDatasetConfig): + # TODO: Generalize dynamic types? - _registry: typing.ClassVar[Registry[str, type["GPTDatasetConfig"]]] = Registry[str, type["GPTDatasetConfig"]]( - "gpt_dataset_class", {} - ) - type_: typing.ClassVar[type["GPTDatasetConfig"] | None] = None + _registry: typing.ClassVar[Registry[str, type["GPTSampledDatasetConfig"]]] = Registry[ + str, type["GPTDatasetConfig"] + ]("gpt_dataset_class", {}) + type_: typing.ClassVar[str | None] = None type: str | None = Field( default=None, desc="The type of dataset.", @@ -48,9 +50,10 @@ class GPTDatasetConfig(DatasetConfig): ) def _validate(self) -> None: - if self.type is not None: - # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.eq(self.type, self.type_) + if self.type is None: + self.type = self.type_ + # Should be handled in `from_dict`, but can fail if instantiating directly. + Assert.eq(self.type, self.__class__.type_) super()._validate() @classmethod @@ -64,6 +67,10 @@ def _from_dict( if type_ is None: actual_cls = cls else: + if type_ not in cls._registry: + raise ValueError( + f"Unknown {cls._registry.name} type {type_}." 
f" Available types: {list(cls._registry.keys())}" + ) actual_cls = cls._registry[type_] Assert.custom(issubclass, actual_cls, cls) if actual_cls == cls: @@ -71,18 +78,20 @@ def _from_dict( else: return actual_cls._from_dict(default, strict=strict, flat=flat) - def __init_subclass__(cls, type_: str | None = None, **kwargs) -> None: - if type_ is not None: - GPTDatasetConfig._registry[type_] = cls - cls.type_ = type_ + def __init_subclass__(cls) -> None: + if cls._abstract and cls.type_ is not None: + # Abstract classes should not have a `type_` + raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.") + if cls.type_ is not None: + if cls.type_ in cls._registry: + raise ValueError( + f"Registry {cls._registry.name} already contains type {cls.type_}." + f" Make sure all classes either have a unique or `None` type." + ) + GPTSampledDatasetConfig._registry[cls.type_] = cls super().__init_subclass__() -@config_class() -class GPTSampledDatasetConfig(SampledDatasetConfig, GPTDatasetConfig): - pass - - @config_class() class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig): pass @@ -95,24 +104,25 @@ def build(self) -> "GPTIndexedDataset": @config_class() -class GPTDummyDatasetConfig(GPTSampledDatasetConfig, type_="dummy"): - # TODO: Can't make it a samplable dataset because necessary info is in sampling config. - _abstract = False +class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "random" name: str = Field( default="dummy", desc="The name of the dataset.", hint=FieldHint.core, ) - def build_and_sample(self, config: GPTSamplingConfig) -> "GPTDummySampledDataset": - from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset + def build(self) -> "GPTRandomDataset": + from fast_llm.data.dataset.gpt.random import GPTRandomDataset - return GPTDummyDataset(self.name, config.sequence_length, config.vocab_size).sample(config) + return GPTRandomDataset(self.name) @config_class() -class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig, type_="memmap"): - _abstract = False +class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "memmap" path: pathlib.Path = Field( default=None, desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", @@ -126,8 +136,9 @@ def build(self) -> "GPTMemmapDataset": @config_class() -class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig, type_="concatenated"): - _abstract = False +class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "concatenated" datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() def build(self) -> "GPTConcatenatedDataset": @@ -137,8 +148,9 @@ def build(self) -> "GPTConcatenatedDataset": @config_class() -class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig, type_="slice"): - _abstract = False +class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "slice" dataset: GPTIndexedDatasetConfig = FieldUpdate() def build(self) -> "GPTDatasetSlice": @@ -148,28 +160,45 @@ def build(self) -> "GPTDatasetSlice": @config_class() -class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig, type_="blended"): - _abstract = False +class 
GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "blended" datasets: list[GPTSampledDatasetConfig] = FieldUpdate() -class LegacyDatasetSource(str, enum.Enum): - """ - An enum for the different ways to load datasets. - """ - - list = "list" - file = "file" - random = "random" - +@config_class() +class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "concatenated_memmap" + path: pathlib.Path = Field( + default=None, + desc="The path to a dataset directory.", + hint=FieldHint.core, + ) -def _validate_split(value: list[int]) -> list[int]: - Assert.leq(len(value), 3) - return value + [0] * (len(value) - 3) + def build(self) -> "GPTConcatenatedDataset": + pass + assert self.path.is_dir() + index_path = self.path / "index.txt" -def _validate_path(value: str | list[str]) -> list[str]: - return [value] if isinstance(value, str) else value + if index_path.is_file(): + prefixes = [self.path / line.strip() for line in index_path.open("r").readlines()] + else: + warnings.warn( + f"The dataset path {self.path} points to a directory." + " The dataset will be indexed automatically, which may be unsafe." + " We recommend using an index file instead." + ) + prefixes = [ + path.with_suffix("") + for path in self.path.iterdir() + if path.suffix == ".idx" and path.is_file() and path.with_suffix(".bin").is_file() + ] + dataset_config = GPTConcatenatedDatasetConfig.from_dict( + {"datasets": [{"type": "memmap", "path": prefix} for prefix in prefixes]} + ) + return dataset_config.build() @config_class() @@ -218,14 +247,37 @@ class FimConfig(Config): desc="TODO.", hint=FieldHint.feature, ) + prefix_token: str = Field( + default="", + desc="TODO.", + hint=FieldHint.feature, + ) + middle_token: str = Field( + default="", + desc="TODO.", + hint=FieldHint.feature, + ) + pad_token: str = Field( + default="", + desc="TODO.", + hint=FieldHint.feature, + ) + suffix_token: str = Field( + default="", + desc="TODO.", + hint=FieldHint.feature, + ) @config_class() -class FimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig, type_="fim"): +class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): """ Configuration for FIM. """ + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "fim" + dataset: GPTSampledDatasetConfig = Field( default=None, desc="The dataset to wrap with fim.", @@ -236,25 +288,33 @@ def build_and_sample( self, config: GPTSamplingConfig, ) -> SampledDataset: - from fast_llm.data.dataset.gpt.fim import FimDataset + from fast_llm.data.dataset.gpt.fim import GPTFimDataset + + return GPTFimDataset(self, self.dataset.build_and_sample(config), config) + + +class LegacyDatasetSource(str, enum.Enum): + """ + An enum for the different ways to load datasets. + """ + + list = "list" + file = "file" + random = "random" - return FimDataset(self, self.dataset.build_and_sample(config), config) + +def _validate_split(value: list[int]) -> list[int]: + Assert.leq(len(value), 3) + return value + [0] * (len(value) - 3) + + +def _validate_path(value: str | list[str]) -> list[str]: + return [value] if isinstance(value, str) else value @config_class() class GPTLegacyConfig(Config): - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ): - # TODO v0.3: Remove. 
- cls._handle_renamed_field(default, "split", "ratio") - return super()._from_dict(default, strict, flat) - - ratio: list[float] = Field( + split: list[float] = Field( default_factory=lambda: [969, 30, 1], desc="Split ratio for train, valid and test datasets.", hint=FieldHint.deprecated, @@ -279,14 +339,15 @@ def _from_dict( @config_class() -class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig, type_="legacy"): - _abstract = False +class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "legacy" def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset: if self.format == LegacyDatasetSource.random: Assert.eq(len(self.path), 0) - dataset_config = GPTDummyDatasetConfig() + dataset_config = GPTRandomDatasetConfig() else: if self.format == LegacyDatasetSource.file: Assert.eq(len(self.path), 1) @@ -311,7 +372,7 @@ def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset: else: raise NotImplementedError(self.format) - phase_splits = padded_cumsum(self.ratio) + phase_splits = padded_cumsum(self.split) phase_index = { PhaseType.training: 0, PhaseType.validation: 1, @@ -332,14 +393,37 @@ def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset: name="blended", datasets=dataset_configs, weights=dataset_weights, + legacy=True, ) if len(dataset_configs) > 1 else dataset_configs[0] ) if self.fim.rate > 0: - dataset_config = FimSampledDatasetConfig.from_dict( + dataset_config = GPTFimSampledDatasetConfig.from_dict( self.fim, {"dataset": dataset_config}, ) return dataset_config.build_and_sample(config) + + +@config_class() +class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): + """ + A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. + """ + + # TODO: This belongs to a testing plugin. + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "test_slow" + sleep: float = Field( + default=1, + desc="Sleep time during build, in seconds.", + hint=FieldHint.core, + ) + + def build_and_sample(self, config: SamplingConfig) -> "GPTRandomSampledDataset": + assert config.distributed.config.world_size > 1 + if config.distributed.config.rank == 0: + time.sleep(self.sleep) + return GPTRandomDatasetConfig().build_and_sample(config) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 0ed76d80..1ffb3bfc 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -4,13 +4,8 @@ from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingConfig from fast_llm.engine.distributed.config import MAX_SEED -FIM_PREFIX = "" -FIM_MIDDLE = "" -FIM_PAD = "" -FIM_SUFFIX = "" - -class FimDataset(SampledDataset): +class GPTFimDataset(SampledDataset): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. 
Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py @@ -24,10 +19,13 @@ def __init__( ): self._config = config self._dataset = dataset - self._sampling_config = sampling_config + self._seed = sampling_config.seed self._tokenizer = sampling_config.tokenizer + if self._tokenizer is None: + raise ValueError("Fim requires a tokenizer") self._suffix_tok_id, self._prefix_tok_id, self._middle_tok_id, self._pad_tok_id = ( - self._tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD] + self._tokenizer.vocab[tok] + for tok in [config.suffix_token, config.prefix_token, config.middle_token, config.pad_token] ) self.fim_split_sample = ( self._tokenizer.vocab[self._config.split_sample] if self._config.split_sample is not None else None @@ -37,9 +35,7 @@ def __len__(self) -> int: return len(self._dataset) def __getitem__(self, idx: int) -> np.ndarray: - sample = self._fim( - self._dataset[idx], np.random.RandomState(seed=(self._sampling_config.seed + idx) % MAX_SEED) - ) + sample = self._fim(self._dataset[idx], np.random.RandomState(seed=(self._seed + idx) % MAX_SEED)) return sample @property diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 833c569e..2c158bff 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -11,20 +11,6 @@ class GPTIndexedDataset(IndexedDataset): - """ - A GPT dataset containing a list of samples. - """ - - # def get(self, index: int, offset: int = 0, length: int | None = None): - # pass - - # def __len__(self) -> int: - # """ - # Number of documents in the dataset. - # Can be calculated from document sizes but may be overridden if there is a better method. - # """ - # return len(self.get_document_sizes()) - @abc.abstractmethod def get_document_sizes(self) -> np.ndarray: """ @@ -39,7 +25,7 @@ def sample(self, config: GPTSamplingConfig) -> "GPTSampledIndexedDataset": return GPTSampledIndexedDataset(self, config) -class GPTDatasetSlice(DatasetSlice, GPTIndexedDataset): +class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): """ A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. 
""" @@ -51,7 +37,9 @@ def get_document_sizes(self) -> np.ndarray: return self._dataset.get_document_sizes()[self._begin : self._end] -class GPTConcatenatedDataset(ConcatenatedDataset, GPTIndexedDataset): +class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( + ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset +): _datasets: list[GPTIndexedDataset] def get_document_sizes(self) -> np.ndarray: diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index cb531b56..ec56be4a 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -58,10 +58,12 @@ def __setstate__(self, state: tuple[str, pathlib.Path]): self._init(*state) def __del__(self): - self._bin_buffer_mmap._mmap.close() # noqa - del self._bin_buffer_mmap - self._index_bin_buffer_mmap._mmap.close() # noqa - del self._index_bin_buffer_mmap + if hasattr(self, "_bin_buffer_mmap"): + self._bin_buffer_mmap._mmap.close() # noqa + del self._bin_buffer_mmap + if hasattr(self, "_index_bin_buffer"): + self._index_bin_buffer_mmap._mmap.close() # noqa + del self._index_bin_buffer_mmap def get(self, idx, offset=0, length=None) -> np.ndarray: return np.frombuffer( diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index b9226724..8a652dda 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -24,7 +24,7 @@ def __len__(self) -> int: """ -class DatasetSlice(IndexedDataset): +class DatasetSlice[IndexedDatasetType: IndexedDataset](IndexedDataset): def __init__( self, diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index d35c3c4a..ce3ec0e6 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -166,6 +166,10 @@ def run(self) -> None: output_file = self._config.output_path / "fast_llm_dataset.json" json.dump({"datasets": dataset_dicts}, output_file.open("w")) + # Create an index file on rank 0 + index_file = self._config.output_path / "index.txt" + index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts]) + # Finalize distributed processing if self._config.distributed.world_size > 1: torch.distributed.barrier() diff --git a/fast_llm/profile.py b/fast_llm/profile.py index 3ddf0a8c..a0fc3946 100644 --- a/fast_llm/profile.py +++ b/fast_llm/profile.py @@ -143,10 +143,11 @@ def trace_fn( column_width=config.table_width, header=f"Trace for step {step}", ) - if config.log: - logger.info(table) - else: - run.open_artifact(f"profile_trace_step_{step}").write(table) + if table: + if config.log: + logger.info(table) + else: + run.open_artifact(f"profile_trace_step_{step}").write(table) if config.averages: table = build_average_table( @@ -156,13 +157,16 @@ def trace_fn( column_width=config.table_width, header=f"Averages for step {step}", ) - if config.log: - logger.info(table) - else: - run.open_artifact(f"profile_averages_step_{step}").write(table) + if table: + if config.log: + logger.info(table) + else: + run.open_artifact(f"profile_averages_step_{step}").write(table) if config.export: - profiler.export_chrome_trace(str(run.open_artifact(f"profile_chrome_step_{step}", mode=None))) + # Suppress empty profile, mainly the annoying one at the end. + if _get_events(profiler, cuda=config.cuda): + profiler.export_chrome_trace(str(run.open_artifact(f"profile_chrome_step_{step}", mode=None))) # Store results for future use. 
profiler.bc_profile_result = profiler.profiler.function_events @@ -195,7 +199,7 @@ def trace_fn(
     "input_shapes": "Input Shapes",
     "source_loc": "Source Location",
     "node_id": "Node ID",
-    "total_flops": "Total xflops",
+    "total_flops": "Total tflops",
 }
 
 _CPU_TRACE_COLUMNS = {"name", "cpu_self", "cpu_total", "start_time", "end_time"}
@@ -220,13 +224,17 @@ def trace_fn(
     )
 
 
+def _get_events(profiler: "torch.profiler.profile", *, cuda: bool = True):
+    var_name = f"self_{'device' if cuda else 'cpu'}_time_total"
+    return [evt for evt in profiler.profiler.function_events if getattr(evt, var_name) > 0]
+
+
 def build_trace_table(
     profiler: "torch.profiler.profile", *, cuda: bool = True, cpu: bool = False, column_width=80, header="Trace"
 ) -> str:
-    var_name = f"self_{'cuda' if cuda else 'cpu'}_time_total"
-    events = [evt for evt in profiler.profiler.function_events if getattr(evt, var_name) > 0]
+    var_name = f"self_{'device' if cuda else 'cpu'}_time_total"
     return _build_table(
-        events,
+        _get_events(profiler, cuda=cuda),
         (_CPU_TRACE_COLUMNS if cpu else set()) | (_CUDA_TRACE_COLUMNS if cuda else set()),
         name_column_width=column_width,
         filter_by=None if cuda and cpu else var_name,
@@ -289,7 +297,9 @@ def _build_table(
 
     result = []
     sum_self_cpu_time_total = sum(event.self_cpu_time_total for event in events)
-    sum_self_cuda_time_total = sum(event.self_cuda_time_total for event in events)  # if evt.device_type == DeviceType.
+    sum_self_device_time_total = sum(
+        event.self_device_time_total for event in events
+    )  # if evt.device_type == DeviceType.
 
     if header is not None:
         result.extend(["=" * line_length, header])
@@ -321,9 +331,9 @@ def _build_table(
         if "cpu_avg" in columns:
             row_values.append(_format_time_us(evt.cpu_time))
         if "cuda" in columns:
-            row_values.append(_format_time_us(evt.self_cuda_time_total))
+            row_values.append(_format_time_us(evt.self_device_time_total))
         if "cuda_percent" in columns:
-            row_values.append(_format_time_share(evt.self_cuda_time_total, sum_self_cuda_time_total))
+            row_values.append(_format_time_share(evt.self_device_time_total, sum_self_device_time_total))
         if "cuda_total" in columns:
             row_values.append(_format_time_us(evt.cuda_time_total))
         if "cuda_avg" in columns:
@@ -340,8 +350,8 @@ def _build_table(
             result.append(header_sep)
     if sum_self_cpu_time_total > 0:
         result.append(f"CPU time total: {_format_time_ms(sum_self_cpu_time_total)}")
-    if sum_self_cuda_time_total > 0:
-        result.append(f"CUDA time total: {_format_time_ms(sum_self_cuda_time_total)}")
+    if sum_self_device_time_total > 0:
+        result.append(f"CUDA time total: {_format_time_ms(sum_self_device_time_total)}")
     result.append("")
 
     return "\n".join(result)
diff --git a/fast_llm/utils.py b/fast_llm/utils.py
index 44e2d586..fe9207c9 100644
--- a/fast_llm/utils.py
+++ b/fast_llm/utils.py
@@ -144,6 +144,10 @@ def rms_close(x, y, threshold):
     def all_equal(x, y):
         import torch
 
+        # Make it work for numpy arrays.
+        x = torch.as_tensor(x)
+        y = torch.as_tensor(y)
+
         neq = x != y
         if neq.any().item():  # noqa
             index = torch.where(neq)  # noqa
@@ -156,9 +160,13 @@ def all_different(x, y):
     def all_different(x, y):
         import torch
 
+        # Make it work for numpy arrays.
+ x = torch.as_tensor(x) + y = torch.as_tensor(y) + eq = x == y if eq.any().item(): # noqa - index = torch.where(eq) # noqa + index = torch.where(torch.as_tensor(eq)) # noqa raise AssertionError( f"Tensors have {index[0].numel()} unexpected matching entries out of " f"{x.numel()}: {x[index]} != {y[index]} at index {torch.stack(index, -1)}" @@ -178,6 +186,7 @@ def not_custom(fn, *args, **kwargs): class Registry[KeyType, ValueType]: + # TODO: Inherit from dict instead? def __init__(self, name: str, data: dict[KeyType, ValueType]): self._name = name self._data = data.copy() @@ -195,8 +204,21 @@ def __setitem__(self, key: KeyType, value: ValueType): def keys(self) -> list[KeyType]: return list(self._data) - def __contains__(self, item: ValueType) -> bool: - return item in self._data + def __contains__(self, key: KeyType) -> bool: + return key in self._data + + def __iter__(self) -> typing.Iterator[KeyType]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def items(self): + return self._data.items() + + @property + def name(self) -> str: + return self._name class LazyRegistry[KeyType, ValueType](Registry[KeyType, ValueType]): From a4e288b751d88373e5dc1d04c37994f102bf3a9f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 27 Jan 2025 19:54:42 -0500 Subject: [PATCH 19/19] Fix merge --- fast_llm/data/config.py | 3 ++- fast_llm/data/dataset/config.py | 38 +++++++++++++++++++------- fast_llm/data/dataset/gpt/dummy.py | 40 ---------------------------- fast_llm/data/dataset/gpt/random.py | 41 +++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 51 deletions(-) delete mode 100644 fast_llm/data/dataset/gpt/dummy.py create mode 100644 fast_llm/data/dataset/gpt/random.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 32675749..1586d370 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,4 +1,5 @@ import enum +import pathlib from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.utils import Assert @@ -28,7 +29,7 @@ class TokenizerConfig(Config): hint=FieldHint.deprecated, valid=check_field(Assert.eq, TokenizerFromFile), ) - path: str | None = Field( + path: pathlib.Path | None = Field( default=None, desc="Path to the tokenizer file.", hint=FieldHint.core, diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 0cdb90ba..253ea1fa 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -14,22 +14,22 @@ from fast_llm.engine.distributed.distributed import Distributed -@config_class() -class DatasetConfig(Config): - _abstract = True - - @dataclasses.dataclass(kw_only=True) class SamplingConfig: # TODO: Have a separate configuration (subset?) for `build`? num_samples: int seed: int cache_directory: pathlib.Path | None - verbose: bool + # TODO: This prevents the sampling config from being pickled in multiprocessing. 
distributed: "Distributed" phase: PhaseType +@config_class() +class DatasetConfig(Config): + _abstract: typing.ClassVar[bool] = True + + @config_class() class SampledDatasetConfig(DatasetConfig): """ @@ -140,15 +140,21 @@ class BlendedDatasetConfig(SampledDatasetConfig): default_factory=list, desc="The datasets to blend.", hint=FieldHint.core, - valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)), ) weights: list[float] = Field( default_factory=list, desc="The blending weight of each dataset.", hint=FieldHint.core, ) + legacy: bool = Field( + default=False, + desc="Use the legacy formulas for sub-dataset seeds and sample sizes.", + hint=FieldHint.deprecated, + ) - def __post_init__(self) -> None: + def _validate(self) -> None: + super()._validate() + Assert.geq(len(self.datasets), 2) Assert.eq(len(self.datasets), len(self.weights)) def build_and_sample( @@ -158,12 +164,24 @@ def build_and_sample( from fast_llm.data.dataset.blended import BlendedDataset # Build and sample the datasets. + # TODO: Vary the seed? + # Add 5 times the standard deviation (of a binomial distribution) + # so the probability of sampling more than this amount during blending is negligible. + sampled_datasets = [ dataset.build_and_sample( # Blending is deterministic and the error will never be higher than 1. - dataclasses.replace(config, num_samples=math.ceil(weight * config.num_samples) + 1), + dataclasses.replace( + config, + num_samples=( + math.ceil(weight * (config.num_samples + 5 * (config.num_samples * (1 - weight)) ** 0.5)) + if self.legacy + else math.ceil(weight * config.num_samples) + 1 + ), + seed=config.seed + i * (0 if self.legacy else 697), + ), ) - for dataset, weight in zip(self.datasets, self.weights, strict=True) + for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) ] # Blend the datasets. return BlendedDataset( diff --git a/fast_llm/data/dataset/gpt/dummy.py b/fast_llm/data/dataset/gpt/dummy.py deleted file mode 100644 index 484d811c..00000000 --- a/fast_llm/data/dataset/gpt/dummy.py +++ /dev/null @@ -1,40 +0,0 @@ -import numpy as np - -from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig - - -class GPTDummyDataset(SamplableDataset): - """ - A dummy dataset that always returns the same random sample, for debugging purposes. 
- """ - - def __init__(self, name: str, sequence_length: int, vocab_size: int): - self._dummy_sample = np.random.randint(0, vocab_size, size=(sequence_length + 1,), dtype=np.int64) - self._name = name - - def sample(self, config: GPTSamplingConfig) -> "GPTDummySampledDataset": - return GPTDummySampledDataset(self, config) - - def get(self) -> np.ndarray: - return self._dummy_sample - - @property - def name(self) -> str: - return self._name - - -class GPTDummySampledDataset(SampledDataset): - def __init__(self, dataset: GPTDummyDataset, config: GPTSamplingConfig): - self._config = config - self._dataset = dataset - - def __len__(self) -> int: - return self._config.num_samples - - def __getitem__(self, idx) -> np.ndarray: - return self._dataset.get() - - @property - def name(self) -> str: - return self._dataset.name diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py new file mode 100644 index 00000000..142dca71 --- /dev/null +++ b/fast_llm/data/dataset/gpt/random.py @@ -0,0 +1,41 @@ +import numpy as np + +from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig + + +class GPTRandomDataset(SamplableDataset): + """ + A dummy dataset that always returns the same random sample, for debugging purposes. + """ + + def __init__(self, name: str): + self._name = name + + def sample(self, config: GPTSamplingConfig) -> "GPTRandomSampledDataset": + return GPTRandomSampledDataset(config, f"{self.name}_sampled") + + @property + def name(self) -> str: + return self._name + + +class GPTRandomSampledDataset(SampledDataset): + def __init__(self, config: GPTSamplingConfig, name: str): + self._name = name + self._seed = config.seed + self._sequence_length = config.sequence_length + self._vocab_size = config.vocab_size + self._num_samples = config.num_samples + + def __len__(self) -> int: + return self._num_samples + + def __getitem__(self, idx) -> np.ndarray: + return np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( + 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + ) + + @property + def name(self) -> str: + return self._name