Commit 01b71c9

Sampling parameters, generalize batch config. (#230)
1 parent 7a74af0 commit 01b71c9

File tree: 19 files changed, +283 -172 lines


fast_llm/data/data/abstract.py

+4 -3

@@ -4,6 +4,7 @@
 
 from fast_llm.config import Configurable
 from fast_llm.data.data.config import DataConfig
+from fast_llm.data.dataset.config import SamplingParameters
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.engine.schedule.config import BatchConfig
 
@@ -13,7 +14,7 @@
 
 class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC):
     _distributed: "Distributed"
-    _samples_per_dataset: dict[str, int]
+    _sampling_parameters: dict[str, SamplingParameters]
     _cache_directory: pathlib.Path | None
 
     def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None:
@@ -24,12 +25,12 @@ def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None:
     def setup(
         self,
         distributed: "Distributed",
-        samples_per_dataset: dict[str, int],
+        sampling_parameters: dict[str, SamplingParameters],
         cache_directory: pathlib.Path,
         timeout: float | None = None,
     ) -> None:
         self._distributed = distributed
-        self._samples_per_dataset = samples_per_dataset
+        self._sampling_parameters = sampling_parameters
         self._cache_directory = cache_directory
 
     @property
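
The core of the change is visible above: `Data.setup` now receives a `dict[str, SamplingParameters]` instead of a `dict[str, int]` of sample counts. A minimal sketch of what this means for a caller follows; the dataset names and counts are illustrative placeholders, not taken from this commit.

# Hedged sketch of the new argument type for Data.setup.
# Dataset names and sample counts are illustrative placeholders.
from fast_llm.data.dataset.config import SamplingParameters

# Before this commit, setup received bare per-dataset sample counts:
samples_per_dataset = {"training": 100_000, "validation": 1_000}

# After this commit, the same information travels inside SamplingParameters,
# which subclasses can extend with further usage-dependent fields:
sampling_parameters = {
    name: SamplingParameters(num_samples=num_samples)
    for name, num_samples in samples_per_dataset.items()
}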

fast_llm/data/data/gpt/data.py

+19 -25

@@ -13,7 +13,7 @@
 from fast_llm.data.data.abstract import Data
 from fast_llm.data.data.gpt.config import GPTDataConfig
 from fast_llm.data.dataset.abstract import SampledDataset
-from fast_llm.data.dataset.gpt.config import GPTSamplingData
+from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters
 from fast_llm.data.dataset.gpt.sampled import GPTSample
 from fast_llm.data.dataset.monitor import DatasetMonitor
 from fast_llm.data.iterator import SampledDatasetIterator
@@ -34,15 +34,13 @@ class GPTBatch:
     sequence_lengths: list[torch.Tensor] | None = None
 
 
-def gpt_data_collate_fn(
-    batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool
-) -> GPTBatch:
+def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
     stacked_ids = np.stack([sample.token_ids for sample in batch])
     stacked_spans = None
     sequence_lengths = None
-    if use_loss_masking_spans:
+    if sampling_parameters.use_loss_masking_spans:
         stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
-    if not cross_document_attention:
+    if not sampling_parameters.cross_document_attention:
         sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
     return GPTBatch(
         token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
@@ -57,51 +55,47 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):
     """
 
     _datasets: dict[str, SampledDataset]
+    _sampling_parameters: dict[str, GPTSamplingParameters]
     _tokenizer: Tokenizer | None
     _is_setup: bool = False
 
     def __init__(
         self,
         config: GPTDataConfig,
         distributed_config: DistributedConfig,
-        vocab_size: int,
-        max_sequence_length: int,
-        cross_document_attention: bool = True,
     ):
         """
         Create the data and gather some basic information on the dataset(s).
         Should be `setup` before use.
         """
         super().__init__(config, distributed_config)
-        self._vocab_size = vocab_size
-        self._max_sequence_length = max_sequence_length
-        self._cross_document_attention = cross_document_attention
 
     def setup(
         self,
         distributed: "Distributed",
-        samples_per_dataset: dict[str, int],
+        sampling_parameters: dict[str, GPTSamplingParameters],
         cache_directory: pathlib.Path,
         timeout: float | None = None,
     ) -> None:
         """
         Load the datasets, and prepare or load the samplings.
         This may take a while and a significant amount of cpu memory.
         """
+        super().setup(distributed, sampling_parameters, cache_directory)
+
         # Check and raise an error if a used dataset is not defined.
-        for dataset_name in samples_per_dataset.keys():
+        for dataset_name in self._sampling_parameters.keys():
             if dataset_name not in self._config.datasets:
                 raise ValueError(f"Dataset {dataset_name} not found.")
 
         # Check and warn if there are defined datasets that are not used.
-        unused_datasets = self._config.datasets.keys() - samples_per_dataset.keys()
+        unused_datasets = self._config.datasets.keys() - self._sampling_parameters.keys()
         if unused_datasets:
             warnings.warn(
                 f"The following datasets are defined but not used: {', '.join(unused_datasets)}. "
                 "Ensure this is intentional, or update the configuration accordingly."
             )
 
-        super().setup(distributed, samples_per_dataset, cache_directory)
         log_main_rank(f"Preparing dataset. This may take several minutes.")
         self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer)
 
@@ -110,19 +104,19 @@ def setup(
             warnings.warn(f"Using the dataset directory for the index cache.")
 
         self._datasets = {}
-        for dataset_name, num_samples in samples_per_dataset.items():
-            if num_samples > 0:
+        for dataset_name, sampling_parameters in self._sampling_parameters.items():
+            if self._tokenizer is not None:
+                # TODO: Too constraining?
+                Assert.eq(self._tokenizer.vocab_size, sampling_parameters.vocab_size)
+            if sampling_parameters.num_samples > 0:
                 sampling = GPTSamplingData(
-                    num_samples=num_samples,
                     config=self._config.sampling,
+                    parameters=sampling_parameters,
                     cache_directory=self._cache_directory,
                     distributed=distributed,
                     dataset_name=dataset_name,
-                    sequence_length=self._max_sequence_length,
-                    vocab_size=self._vocab_size,
                     tokenizer=self._tokenizer,
                     truncate_documents=self._config.truncate_documents,
-                    cross_document_attention=self._cross_document_attention,
                 )
                 dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
                 self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
@@ -152,7 +146,8 @@ def get_iterator(
         dataset_name = dataset_name.lower()
 
         Assert.incl(dataset_name, self._datasets)
-        Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length)
+        sampling_parameters = self._sampling_parameters[dataset_name]
+        Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
         log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")
         return iter(
             torch.utils.data.DataLoader(
@@ -169,8 +164,7 @@ def get_iterator(
                 pin_memory=True,
                 collate_fn=partial(
                     gpt_data_collate_fn,
-                    use_loss_masking_spans=self._config.sampling.use_loss_masking_spans,
-                    cross_document_attention=self._cross_document_attention,
+                    sampling_parameters=sampling_parameters,
                 ),
                 multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
             )
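
With this change, everything the collate function needs travels in one `GPTSamplingParameters` object, so the `DataLoader` binding in `get_iterator` shrinks to a single keyword argument. A hedged sketch of that binding; the parameter values are illustrative, not from any real configuration.

# Sketch of the new collate binding; parameter values are illustrative.
from functools import partial

from fast_llm.data.data.gpt.data import gpt_data_collate_fn
from fast_llm.data.dataset.gpt.config import GPTSamplingParameters

sampling_parameters = GPTSamplingParameters(
    num_samples=1000,
    sequence_length=2048,
    vocab_size=32000,
    use_loss_masking_spans=False,
    cross_document_attention=True,
)

# One object replaces the two separate flags passed before this commit.
collate_fn = partial(gpt_data_collate_fn, sampling_parameters=sampling_parameters)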

fast_llm/data/dataset/blended.py

+1 -1

@@ -30,7 +30,7 @@ def __init__(
         Assert.eq(len(datasets), len(weights))
         self._datasets = datasets
         self._weights = np.array(normalize_probabilities(weights))
-        self._num_samples = sampling_config.num_samples
+        self._num_samples = sampling_config.parameters.num_samples
 
     def __len__(self) -> int:
         return self._num_samples

fast_llm/data/dataset/config.py

+32 -5

@@ -16,18 +16,36 @@
 
 @config_class()
 class SamplingConfig(Config):
+    """
+    A dataset-dependent configuration for sampling.
+    """
+
     seed: int = Field(
         default=784569,
         desc="Seed for random sampling.",
         hint=FieldHint.feature,
     )
 
 
+@dataclasses.dataclass(kw_only=True)
+class SamplingParameters:
+    """
+    Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model.
+    """
+
+    num_samples: int
+
+
 @dataclasses.dataclass(kw_only=True)
 class SamplingData:
+    """
+    Holds all the necessary information for sampling, including dataset-dependent ones (`SamplingConfig`),
+    usage-dependent ones (`SamplingParameters`), and others set by the `Data`.
+    """
+
     # TODO: Have a separate configuration (subset?) for `build`?
     config: SamplingConfig
-    num_samples: int
+    parameters: SamplingParameters
     cache_directory: pathlib.Path | None
     # TODO: This prevents the sampling config from being pickled in multiprocessing.
     distributed: "Distributed"
@@ -213,10 +231,19 @@ def build_and_sample(
                 # Blending is deterministic and the error will never be higher than 1.
                 dataclasses.replace(
                     sampling,
-                    num_samples=(
-                        math.ceil(weight * (sampling.num_samples + 5 * (sampling.num_samples * (1 - weight)) ** 0.5))
-                        if self.legacy
-                        else math.ceil(weight * sampling.num_samples) + 1
+                    parameters=dataclasses.replace(
+                        sampling.parameters,
+                        num_samples=(
+                            math.ceil(
+                                weight
+                                * (
+                                    sampling.parameters.num_samples
+                                    + 5 * (sampling.parameters.num_samples * (1 - weight)) ** 0.5
+                                )
+                            )
+                            if self.legacy
+                            else math.ceil(weight * sampling.parameters.num_samples) + 1
+                        ),
                     ),
                     # TODO: Seed may not be unique for nested blended datasets.
                     config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}),
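
The reworked expression only changes where `num_samples` lives (`sampling.parameters.num_samples` instead of `sampling.num_samples`); the two formulas themselves are unchanged. A small numeric check of both branches, with an illustrative weight and sample count:

# Numeric check of the two branches above; weight and num_samples are illustrative.
import math

num_samples = 1000
weight = 0.25

# Legacy blending: oversample with a ~5-sigma margin on the per-dataset count.
legacy = math.ceil(weight * (num_samples + 5 * (num_samples * (1 - weight)) ** 0.5))

# Current blending: deterministic per the comment above, so one extra sample suffices.
current = math.ceil(weight * num_samples) + 1

print(legacy, current)  # 285 251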

fast_llm/data/dataset/gpt/config.py

+24 -10

@@ -20,6 +20,7 @@
     SampledDatasetUpdateConfig,
     SamplingConfig,
     SamplingData,
+    SamplingParameters,
 )
 from fast_llm.engine.distributed.config import PhaseType
 from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum
@@ -45,34 +46,47 @@ class ShufflingType(str, enum.Enum):
 
 @config_class()
 class GPTSamplingConfig(SamplingConfig):
+    """
+    A dataset-dependent configuration for sampling.
+    """
+
     gpu: bool = Field(
         default=True,
         desc="Enable fast sampling on GPU."
         " Note that random sampling works differently on GPU,"
         " so the sample won't match the CPU equivalent.",
         hint=FieldHint.feature,
     )
-    use_loss_masking_spans: bool = Field(
-        default=False,
-        desc="Read loss masking spans from the dataset.",
-        hint=FieldHint.feature,
-    )
     shuffle: ShufflingType = Field(
         default=ShufflingType.epoch,
         desc="Shuffling strategy.",
         hint=FieldHint.feature,
     )
 
 
-@dataclasses.dataclass
-class GPTSamplingData(SamplingData):
-    config: GPTSamplingConfig
-    # TODO: Sort these out
+@dataclasses.dataclass(kw_only=True)
+class GPTSamplingParameters(SamplingParameters):
+    """
+    Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model.
+    """
+
     sequence_length: int
     vocab_size: int
+    use_loss_masking_spans: bool = False
+    cross_document_attention: bool = True
+
+
+@dataclasses.dataclass(kw_only=True)
+class GPTSamplingData(SamplingData):
+    """
+    Holds all the necessary information for sampling, including dataset-dependent ones (`GPTSamplingConfig`),
+    usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`.
+    """
+
+    config: GPTSamplingConfig
+    parameters: GPTSamplingParameters
     tokenizer: "Tokenizer"
     truncate_documents: bool = True
-    cross_document_attention: bool = True
 
 
 @config_class()
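
The flags that moved into `GPTSamplingParameters` keep their previous defaults, so callers only supply the required usage-dependent values. A short hedged sketch, with illustrative numbers:

# Sketch of the defaults carried over by the move; values are illustrative.
from fast_llm.data.dataset.gpt.config import GPTSamplingParameters

parameters = GPTSamplingParameters(num_samples=10_000, sequence_length=4096, vocab_size=49152)

# `use_loss_masking_spans` (formerly a GPTSamplingConfig field) and
# `cross_document_attention` (formerly on GPTSamplingData) now live here.
assert parameters.use_loss_masking_spans is False
assert parameters.cross_document_attention is True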

fast_llm/data/dataset/gpt/fim.py

+1 -1

@@ -18,7 +18,7 @@ def __init__(
         dataset: SampledDataset,
         sampling: GPTSamplingData,
     ):
-        if sampling.config.use_loss_masking_spans:
+        if sampling.parameters.use_loss_masking_spans:
             raise NotImplementedError("FIM is currently not compatible with loss masking.")
         self._config = config
         self._dataset = dataset

fast_llm/data/dataset/gpt/random.py

+3 -3

@@ -25,9 +25,9 @@ class GPTRandomSampledDataset(SampledDataset):
     def __init__(self, sampling: GPTSamplingData, name: str):
         self._name = name
         self._seed = sampling.config.seed
-        self._sequence_length = sampling.sequence_length
-        self._vocab_size = sampling.vocab_size
-        self._num_samples = sampling.num_samples
+        self._sequence_length = sampling.parameters.sequence_length
+        self._vocab_size = sampling.parameters.vocab_size
+        self._num_samples = sampling.parameters.num_samples
 
     def __len__(self) -> int:
         return self._num_samples
