
Sampling parameters, generalize batch config. #230


Merged
3 commits merged on Apr 16, 2025
7 changes: 4 additions & 3 deletions fast_llm/data/data/abstract.py
@@ -4,6 +4,7 @@

from fast_llm.config import Configurable
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.schedule.config import BatchConfig

@@ -13,7 +14,7 @@

class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC):
_distributed: "Distributed"
_samples_per_dataset: dict[str, int]
_sampling_parameters: dict[str, SamplingParameters]
_cache_directory: pathlib.Path | None

def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None:
@@ -24,12 +25,12 @@ def __init__(self, config: DataConfig, distributed_config: DistributedConfig) ->
def setup(
self,
distributed: "Distributed",
samples_per_dataset: dict[str, int],
sampling_parameters: dict[str, SamplingParameters],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
self._distributed = distributed
self._samples_per_dataset = samples_per_dataset
self._sampling_parameters = sampling_parameters
self._cache_directory = cache_directory

@property
44 changes: 19 additions & 25 deletions fast_llm/data/data/gpt/data.py
@@ -13,7 +13,7 @@
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingData
from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
@@ -34,15 +34,13 @@ class GPTBatch:
sequence_lengths: list[torch.Tensor] | None = None


def gpt_data_collate_fn(
batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool
) -> GPTBatch:
def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
stacked_ids = np.stack([sample.token_ids for sample in batch])
stacked_spans = None
sequence_lengths = None
if use_loss_masking_spans:
if sampling_parameters.use_loss_masking_spans:
stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
if not cross_document_attention:
if not sampling_parameters.cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
return GPTBatch(
token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
@@ -57,51 +55,47 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):
"""

_datasets: dict[str, SampledDataset]
_sampling_parameters: dict[str, GPTSamplingParameters]
_tokenizer: Tokenizer | None
_is_setup: bool = False

def __init__(
self,
config: GPTDataConfig,
distributed_config: DistributedConfig,
vocab_size: int,
max_sequence_length: int,
cross_document_attention: bool = True,
):
"""
Create the data and gather some basic information on the dataset(s).
Should be `setup` before use.
"""
super().__init__(config, distributed_config)
self._vocab_size = vocab_size
self._max_sequence_length = max_sequence_length
self._cross_document_attention = cross_document_attention

def setup(
self,
distributed: "Distributed",
samples_per_dataset: dict[str, int],
sampling_parameters: dict[str, GPTSamplingParameters],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
"""
Load the datasets, and prepare or load the samplings.
This may take a while and a significant amount of cpu memory.
"""
super().setup(distributed, sampling_parameters, cache_directory)

# Check and raise an error if a used dataset is not defined.
for dataset_name in samples_per_dataset.keys():
for dataset_name in self._sampling_parameters.keys():
if dataset_name not in self._config.datasets:
raise ValueError(f"Dataset {dataset_name} not found.")

# Check and warn if there are defined datasets that are not used.
unused_datasets = self._config.datasets.keys() - samples_per_dataset.keys()
unused_datasets = self._config.datasets.keys() - self._sampling_parameters.keys()
if unused_datasets:
warnings.warn(
f"The following datasets are defined but not used: {', '.join(unused_datasets)}. "
"Ensure this is intentional, or update the configuration accordingly."
)

super().setup(distributed, samples_per_dataset, cache_directory)
log_main_rank(f"Preparing dataset. This may take several minutes.")
self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer)

@@ -110,19 +104,19 @@ def setup(
warnings.warn(f"Using the dataset directory for the index cache.")

self._datasets = {}
for dataset_name, num_samples in samples_per_dataset.items():
if num_samples > 0:
for dataset_name, sampling_parameters in self._sampling_parameters.items():
if self._tokenizer is not None:
# TODO: Too constraining?
Assert.eq(self._tokenizer.vocab_size, sampling_parameters.vocab_size)
if sampling_parameters.num_samples > 0:
sampling = GPTSamplingData(
num_samples=num_samples,
config=self._config.sampling,
parameters=sampling_parameters,
cache_directory=self._cache_directory,
distributed=distributed,
dataset_name=dataset_name,
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
truncate_documents=self._config.truncate_documents,
cross_document_attention=self._cross_document_attention,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
@@ -152,7 +146,8 @@ def get_iterator(
dataset_name = dataset_name.lower()

Assert.incl(dataset_name, self._datasets)
Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length)
sampling_parameters = self._sampling_parameters[dataset_name]
Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")
return iter(
torch.utils.data.DataLoader(
@@ -169,8 +164,7 @@ def get_iterator(
pin_memory=True,
collate_fn=partial(
gpt_data_collate_fn,
use_loss_masking_spans=self._config.sampling.use_loss_masking_spans,
cross_document_attention=self._cross_document_attention,
sampling_parameters=sampling_parameters,
),
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
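For context, here is a minimal sketch (not part of the diff) of the new calling convention: per-dataset sampling requirements now arrive as `GPTSamplingParameters` objects instead of a plain sample count plus `GPTData.__init__` arguments. The dataset name, cache path, and numeric values below are hypothetical, and `data` and `distributed` are assumed to be existing `GPTData` and `Distributed` instances.

```python
import pathlib

from fast_llm.data.dataset.gpt.config import GPTSamplingParameters

# Hypothetical per-dataset requirements, normally derived from the trainer/model config.
sampling_parameters = {
    "training": GPTSamplingParameters(
        num_samples=100_000,
        sequence_length=2048,
        vocab_size=32_000,
        use_loss_masking_spans=False,
        cross_document_attention=True,
    ),
}

# `data` (GPTData) and `distributed` (Distributed) are assumed to exist already.
data.setup(distributed, sampling_parameters, pathlib.Path("./index_cache"))
```

The same `GPTSamplingParameters` object is later handed to `gpt_data_collate_fn` through `functools.partial`, so the loss-masking and cross-document-attention flags no longer need to be threaded through separately.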
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/blended.py
@@ -30,7 +30,7 @@ def __init__(
Assert.eq(len(datasets), len(weights))
self._datasets = datasets
self._weights = np.array(normalize_probabilities(weights))
self._num_samples = sampling_config.num_samples
self._num_samples = sampling_config.parameters.num_samples

def __len__(self) -> int:
return self._num_samples
37 changes: 32 additions & 5 deletions fast_llm/data/dataset/config.py
@@ -16,18 +16,36 @@

@config_class()
class SamplingConfig(Config):
"""
A dataset-dependent configuration for sampling.
"""

seed: int = Field(
default=784569,
desc="Seed for random sampling.",
hint=FieldHint.feature,
)


@dataclasses.dataclass(kw_only=True)
class SamplingParameters:
"""
Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model.
"""

num_samples: int


@dataclasses.dataclass(kw_only=True)
class SamplingData:
"""
Holds all the necessary information for sampling, including dataset-dependent ones (`SamplingConfig`),
usage-dependent ones (`SamplingParameters`), and others set by the `Data`.
"""

# TODO: Have a separate configuration (subset?) for `build`?
config: SamplingConfig
num_samples: int
parameters: SamplingParameters
cache_directory: pathlib.Path | None
# TODO: This prevents the sampling config from being pickled in multiprocessing.
distributed: "Distributed"
@@ -213,10 +231,19 @@ def build_and_sample(
# Blending is deterministic and the error will never be higher than 1.
dataclasses.replace(
sampling,
num_samples=(
math.ceil(weight * (sampling.num_samples + 5 * (sampling.num_samples * (1 - weight)) ** 0.5))
if self.legacy
else math.ceil(weight * sampling.num_samples) + 1
parameters=dataclasses.replace(
sampling.parameters,
num_samples=(
math.ceil(
weight
* (
sampling.parameters.num_samples
+ 5 * (sampling.parameters.num_samples * (1 - weight)) ** 0.5
)
)
if self.legacy
else math.ceil(weight * sampling.parameters.num_samples) + 1
),
),
# TODO: Seed may not be unique for nested blended datasets.
config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}),
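The rescaling in `build_and_sample` above decides how many samples each blended component must provide, now expressed through a nested `dataclasses.replace` on `parameters`. A small sketch of the arithmetic with illustrative values (the helper function and numbers are not from the PR):

```python
import math


def blended_num_samples(num_samples: int, weight: float, legacy: bool) -> int:
    """Mirror of the per-component sample count computed in `build_and_sample`."""
    if legacy:
        # Legacy behavior: oversample with a margin that grows with the sampling noise.
        return math.ceil(weight * (num_samples + 5 * (num_samples * (1 - weight)) ** 0.5))
    # Deterministic blending: the error is never higher than 1, so one extra sample suffices.
    return math.ceil(weight * num_samples) + 1


# With 100,000 requested samples and a component weight of 0.3:
print(blended_num_samples(100_000, 0.3, legacy=True))   # 30397
print(blended_num_samples(100_000, 0.3, legacy=False))  # 30001
```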
34 changes: 24 additions & 10 deletions fast_llm/data/dataset/gpt/config.py
@@ -20,6 +20,7 @@
SampledDatasetUpdateConfig,
SamplingConfig,
SamplingData,
SamplingParameters,
)
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum
@@ -45,34 +46,47 @@ class ShufflingType(str, enum.Enum):

@config_class()
class GPTSamplingConfig(SamplingConfig):
"""
A dataset-dependent configuration for sampling.
"""

gpu: bool = Field(
default=True,
desc="Enable fast sampling on GPU."
" Note that random sampling works differently on GPU,"
" so the sample won't match the CPU equivalent.",
hint=FieldHint.feature,
)
use_loss_masking_spans: bool = Field(
default=False,
desc="Read loss masking spans from the dataset.",
hint=FieldHint.feature,
)
shuffle: ShufflingType = Field(
default=ShufflingType.epoch,
desc="Shuffling strategy.",
hint=FieldHint.feature,
)


@dataclasses.dataclass
class GPTSamplingData(SamplingData):
config: GPTSamplingConfig
# TODO: Sort these out
@dataclasses.dataclass(kw_only=True)
class GPTSamplingParameters(SamplingParameters):
"""
Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model.
"""

sequence_length: int
vocab_size: int
use_loss_masking_spans: bool = False
cross_document_attention: bool = True


@dataclasses.dataclass(kw_only=True)
class GPTSamplingData(SamplingData):
"""
Holds all the necessary information for sampling, including dataset-dependent ones (`GPTSamplingConfig`),
usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`.
"""

config: GPTSamplingConfig
parameters: GPTSamplingParameters
tokenizer: "Tokenizer"
truncate_documents: bool = True
cross_document_attention: bool = True


@config_class()
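To illustrate the new split (a sketch with hypothetical values, not code from the PR): `GPTSamplingConfig` keeps the dataset-level knobs (seed, `gpu`, `shuffle`), `GPTSamplingParameters` carries the quantities determined by the trainer or model, and `GPTSamplingData` bundles both together with runtime objects such as the tokenizer and `Distributed` instance. This also makes it easy to rescale only the sample count, as the blending code does.

```python
import dataclasses

from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters

parameters = GPTSamplingParameters(
    num_samples=10_000,
    sequence_length=1024,
    vocab_size=32_000,
    use_loss_masking_spans=False,
    cross_document_attention=True,
)

# `config`, `cache_directory`, `distributed`, and `tokenizer` are assumed to exist in the caller
# (the config would normally come from the data config's `sampling` section).
sampling = GPTSamplingData(
    config=config,
    parameters=parameters,
    cache_directory=cache_directory,
    distributed=distributed,
    dataset_name="my_dataset",
    tokenizer=tokenizer,
    truncate_documents=True,
)

# Rescale only the sample count, leaving every other parameter untouched.
scaled = dataclasses.replace(
    sampling,
    parameters=dataclasses.replace(parameters, num_samples=3_000),
)
```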
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/gpt/fim.py
@@ -18,7 +18,7 @@ def __init__(
dataset: SampledDataset,
sampling: GPTSamplingData,
):
if sampling.config.use_loss_masking_spans:
if sampling.parameters.use_loss_masking_spans:
raise NotImplementedError("FIM is currently not compatible with loss masking.")
self._config = config
self._dataset = dataset
6 changes: 3 additions & 3 deletions fast_llm/data/dataset/gpt/random.py
@@ -25,9 +25,9 @@ class GPTRandomSampledDataset(SampledDataset):
def __init__(self, sampling: GPTSamplingData, name: str):
self._name = name
self._seed = sampling.config.seed
self._sequence_length = sampling.sequence_length
self._vocab_size = sampling.vocab_size
self._num_samples = sampling.num_samples
self._sequence_length = sampling.parameters.sequence_length
self._vocab_size = sampling.parameters.vocab_size
self._num_samples = sampling.parameters.num_samples

def __len__(self) -> int:
return self._num_samples