 from fast_llm.data.data.abstract import Data
 from fast_llm.data.data.gpt.config import GPTDataConfig
 from fast_llm.data.dataset.abstract import SampledDataset
-from fast_llm.data.dataset.gpt.config import GPTSamplingData
+from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters
 from fast_llm.data.dataset.gpt.sampled import GPTSample
 from fast_llm.data.dataset.monitor import DatasetMonitor
 from fast_llm.data.iterator import SampledDatasetIterator
@@ -34,15 +34,13 @@ class GPTBatch:
     sequence_lengths: list[torch.Tensor] | None = None


-def gpt_data_collate_fn(
-    batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool
-) -> GPTBatch:
+def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
     stacked_ids = np.stack([sample.token_ids for sample in batch])
     stacked_spans = None
     sequence_lengths = None
-    if use_loss_masking_spans:
+    if sampling_parameters.use_loss_masking_spans:
         stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
-    if not cross_document_attention:
+    if not sampling_parameters.cross_document_attention:
         sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
     return GPTBatch(
         token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
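For reference, a minimal, self-contained sketch of the new collate contract. It uses stand-in types rather than the real `GPTSample`/`GPTSamplingParameters` classes and models only the fields this hunk touches: the per-batch switches (`use_loss_masking_spans`, `cross_document_attention`) are now read from a single parameters object instead of separate keyword arguments.

```python
# Hedged sketch with stand-in types (not the fast_llm classes).
import dataclasses

import numpy as np
import torch


@dataclasses.dataclass
class SampleStub:  # stand-in for GPTSample
    token_ids: np.ndarray
    sequence_lengths: list[int]


@dataclasses.dataclass
class SamplingParametersStub:  # stand-in for GPTSamplingParameters
    use_loss_masking_spans: bool = False
    cross_document_attention: bool = True


def collate(batch: list[SampleStub], sampling_parameters: SamplingParametersStub):
    # Stack token ids into one tensor; only attach per-document lengths when
    # cross-document attention is disabled, mirroring gpt_data_collate_fn.
    token_ids = torch.from_numpy(np.stack([sample.token_ids for sample in batch]))
    sequence_lengths = None
    if not sampling_parameters.cross_document_attention:
        sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
    return token_ids, sequence_lengths


batch = [SampleStub(np.arange(8, dtype=np.int64), [4, 4]) for _ in range(2)]
ids, lengths = collate(batch, SamplingParametersStub(cross_document_attention=False))
print(ids.shape, lengths)  # torch.Size([2, 8]) [tensor([4, 4]), tensor([4, 4])]
```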
@@ -57,51 +55,47 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):
     """

     _datasets: dict[str, SampledDataset]
+    _sampling_parameters: dict[str, GPTSamplingParameters]
     _tokenizer: Tokenizer | None
     _is_setup: bool = False

     def __init__(
         self,
         config: GPTDataConfig,
         distributed_config: DistributedConfig,
-        vocab_size: int,
-        max_sequence_length: int,
-        cross_document_attention: bool = True,
     ):
         """
         Create the data and gather some basic information on the dataset(s).
         Should be `setup` before use.
         """
         super().__init__(config, distributed_config)
-        self._vocab_size = vocab_size
-        self._max_sequence_length = max_sequence_length
-        self._cross_document_attention = cross_document_attention

     def setup(
         self,
         distributed: "Distributed",
-        samples_per_dataset: dict[str, int],
+        sampling_parameters: dict[str, GPTSamplingParameters],
         cache_directory: pathlib.Path,
         timeout: float | None = None,
     ) -> None:
         """
         Load the datasets, and prepare or load the samplings.
         This may take a while and a significant amount of cpu memory.
         """
+        super().setup(distributed, sampling_parameters, cache_directory)
+
         # Check and raise an error if a used dataset is not defined.
-        for dataset_name in samples_per_dataset.keys():
+        for dataset_name in self._sampling_parameters.keys():
             if dataset_name not in self._config.datasets:
                 raise ValueError(f"Dataset {dataset_name} not found.")

         # Check and warn if there are defined datasets that are not used.
-        unused_datasets = self._config.datasets.keys() - samples_per_dataset.keys()
+        unused_datasets = self._config.datasets.keys() - self._sampling_parameters.keys()
         if unused_datasets:
             warnings.warn(
                 f"The following datasets are defined but not used: {', '.join(unused_datasets)}. "
                 "Ensure this is intentional, or update the configuration accordingly."
             )

-        super().setup(distributed, samples_per_dataset, cache_directory)
         log_main_rank(f"Preparing dataset. This may take several minutes.")
         self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer)

@@ -110,19 +104,19 @@ def setup(
             warnings.warn(f"Using the dataset directory for the index cache.")

         self._datasets = {}
-        for dataset_name, num_samples in samples_per_dataset.items():
-            if num_samples > 0:
+        for dataset_name, sampling_parameters in self._sampling_parameters.items():
+            if self._tokenizer is not None:
+                # TODO: Too constraining?
+                Assert.eq(self._tokenizer.vocab_size, sampling_parameters.vocab_size)
+            if sampling_parameters.num_samples > 0:
                 sampling = GPTSamplingData(
-                    num_samples=num_samples,
                     config=self._config.sampling,
+                    parameters=sampling_parameters,
                     cache_directory=self._cache_directory,
                     distributed=distributed,
                     dataset_name=dataset_name,
-                    sequence_length=self._max_sequence_length,
-                    vocab_size=self._vocab_size,
                     tokenizer=self._tokenizer,
                     truncate_documents=self._config.truncate_documents,
-                    cross_document_attention=self._cross_document_attention,
                 )
                 dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
                 self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
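The caller-facing change is that `setup` now receives one `GPTSamplingParameters` per dataset instead of a flat `samples_per_dataset` mapping. Below is a hedged sketch of the shape of that dictionary, using a stand-in dataclass with only the fields this diff reads (`num_samples`, `sequence_length`, `vocab_size`, plus the two collate switches); the dataset names and values are illustrative, and the real class may require additional fields.

```python
# Hypothetical shape of the per-dataset parameters dict passed to setup();
# field names come from this diff, dataset names and values are made up.
import dataclasses


@dataclasses.dataclass
class SamplingParametersStub:  # stand-in for GPTSamplingParameters
    num_samples: int
    sequence_length: int
    vocab_size: int
    use_loss_masking_spans: bool = False
    cross_document_attention: bool = True


sampling_parameters = {
    "training": SamplingParametersStub(num_samples=100_000, sequence_length=2048, vocab_size=32_000),
    "validation": SamplingParametersStub(num_samples=1_000, sequence_length=2048, vocab_size=32_000),
}

# setup() iterates over this mapping: datasets with num_samples == 0 are
# skipped, and each remaining dataset is sampled with its own parameters.
for name, parameters in sampling_parameters.items():
    if parameters.num_samples > 0:
        print(f"would sample {parameters.num_samples} sequences of length {parameters.sequence_length} from {name!r}")
```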
@@ -152,7 +146,8 @@ def get_iterator(
         dataset_name = dataset_name.lower()

         Assert.incl(dataset_name, self._datasets)
-        Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length)
+        sampling_parameters = self._sampling_parameters[dataset_name]
+        Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
         log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")
         return iter(
             torch.utils.data.DataLoader(
@@ -169,8 +164,7 @@ def get_iterator(
                 pin_memory=True,
                 collate_fn=partial(
                     gpt_data_collate_fn,
-                    use_loss_masking_spans=self._config.sampling.use_loss_masking_spans,
-                    cross_document_attention=self._cross_document_attention,
+                    sampling_parameters=sampling_parameters,
                 ),
                 multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
             )
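As a usage note, binding the whole parameters object into the collate function via `functools.partial` keeps the `DataLoader` call stable as new sampling options are added, and gives each dataset's iterator its own switches rather than one global setting. A hedged, self-contained sketch of that binding pattern (stand-in types only, not the fast_llm classes):

```python
# Hedged sketch: binding a parameters object into a DataLoader collate_fn
# with functools.partial, mirroring the change above.
import dataclasses
from functools import partial

import torch


@dataclasses.dataclass
class SamplingParametersStub:  # stand-in for GPTSamplingParameters
    cross_document_attention: bool = True


def collate(batch, sampling_parameters: SamplingParametersStub):
    # Trivial stand-in collate: stack the items and surface the bound switch.
    return torch.stack(batch), sampling_parameters.cross_document_attention


dataset = [torch.arange(4) for _ in range(8)]
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    collate_fn=partial(collate, sampling_parameters=SamplingParametersStub(cross_document_attention=False)),
)
for token_ids, cross_document_attention in loader:
    print(token_ids.shape, cross_document_attention)  # torch.Size([4, 4]) False
```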