@@ -65,7 +65,7 @@ def __getitem__(self, item: typing.Any) -> np.ndarray:

     def _lazy_load(self):
         if self._array is None:
-            assert self.exists()
+            assert self.exists(), self._path
            self._array = np.load(self._path, mmap_mode="r")

@@ -178,25 +178,26 @@ def _sample(self) -> None:
             "truncate_documents": self._truncate_documents,
             "config": self._config.to_serialized(),
         }
-        self._load_yaml_data(yaml_data)
+        if self._truncate_documents:
+            yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs

-        if self._yaml_path is not None:
-            if self._yaml_path.is_file():
-                loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
-                unshuffled_tokens = loaded_yaml_data.pop("unshuffled_tokens", None)
-                if unshuffled_tokens is not None:
-                    self._unshuffled_tokens = unshuffled_tokens
-                if loaded_yaml_data != yaml_data:
-                    raise RuntimeError(
-                        f"Invalid dataset cache for dataset {self.name}."
-                        " If this is due to an intended configuration change,"
-                        " please delete the cache before continuing."
-                        f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}"
-                        f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}"
-                    )
-                # Dataset is already sampled, skip.
-                logger.info(f"Using existing sampling for dataset {self.name}")
-                return
+        if self._yaml_path is not None and self._yaml_path.is_file():
+            loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
+            self._load_yaml_data(yaml_data)
+            if not self._truncate_documents:
+                del loaded_yaml_data["unshuffled_tokens"]
+
+            if loaded_yaml_data != yaml_data:
+                raise RuntimeError(
+                    f"Invalid dataset cache for dataset {self.name}."
+                    " If this is due to an intended configuration change,"
+                    " please delete the cache before continuing."
+                    f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}"
+                    f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}"
+                )
+            # Dataset is already sampled, skip.
+            logger.info(f"Using existing sampling for dataset {self.name}")
+            return

         if shuffled_documents > 1e8:
             warnings.warn(
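For reference, here is a minimal standalone sketch of the cache-check flow this hunk sets up (not the project's code; the function name and dict layout are illustrative). The key point is that when documents are padded rather than truncated, the exact token count is only known after sampling, so `unshuffled_tokens` is excluded from the comparison and filled in later.

```python
import pathlib
import yaml


def cached_sampling_is_valid(yaml_path: pathlib.Path | None, yaml_data: dict, truncate_documents: bool) -> bool:
    """Return True if a previously saved sampling matches the current config and can be reused."""
    if yaml_path is None or not yaml_path.is_file():
        return False
    loaded_yaml_data = yaml.safe_load(yaml_path.open("r"))
    if not truncate_documents:
        # With padding, the token count is only known after sampling, so it is not comparable here.
        loaded_yaml_data.pop("unshuffled_tokens", None)
    if loaded_yaml_data != yaml_data:
        raise RuntimeError("Cached sampling config does not match the current config; delete the cache to resample.")
    return True
```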
@@ -255,33 +256,32 @@ def _sample(self) -> None:
         # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation.
         # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))`
         if unshuffled_epochs > 0:
-            token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum(
+            token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum(
                 document_sizes,
                 offset=0,
                 # TODO: Allowing for max 100% extra tokens for padding, is that enough?
                 dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
             )
-            if self._truncate_documents:
-                num_tokens_unshuffled = tokens_per_epoch * unshuffled_epochs
             self._token_cumsum_unshuffled.save(token_cumsum_unshuffled)
         else:
-            num_tokens_unshuffled = 0
-        self._unshuffled_tokens = num_tokens_unshuffled
+            unshuffled_tokens = 0

+        if not self._truncate_documents:
+            yaml_data["unshuffled_tokens"] = unshuffled_tokens
+        self._load_yaml_data(yaml_data)
         if self._yaml_path is not None:
-            yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
             self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
             yaml.safe_dump(yaml_data, self._yaml_path.open("w"))

         if shuffled_epochs > 0:
-            token_cumsum_shuffled, num_tokens_shuffled = self._get_token_cumsum(
+            token_cumsum_shuffled, _ = self._get_token_cumsum(
                 document_sizes[
                     # Torch indexing only works with int32 or int64
                     document_shuffling.to(
                         dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32
                     )
                 ],
-                offset=num_tokens_unshuffled,
+                offset=self._unshuffled_tokens,
                 # TODO: Allowing for max 100% extra tokens for padding, is that enough?
                 dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
             )
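To make the rename easier to follow, here is a rough sketch of what `_get_token_cumsum` is assumed to return, based on the `torch.hstack((0, ...cumsum()[::TOKEN_CUMSUM_RATE]))` comment above (the real method also handles unsigned dtypes and padding): a downsampled cumulative sum of document sizes plus the total token count. That second return value, now named `unshuffled_tokens`, goes into `yaml_data` when padding is enabled, and the shuffled pass reads its offset back from `self._unshuffled_tokens`.

```python
import numpy as np

TOKEN_CUMSUM_RATE = 10  # illustrative value; the real constant lives in the module


def token_cumsum_sketch(document_sizes: np.ndarray, offset: int, rate: int = TOKEN_CUMSUM_RATE):
    """Downsampled cumulative sum of document sizes, shifted by `offset`, plus the total token count."""
    full_cumsum = document_sizes.cumsum()
    total_tokens = int(full_cumsum[-1]) if document_sizes.size else 0
    # Matches the comment above: hstack((0, cumsum()[::rate])), i.e. keep every `rate`-th entry.
    return offset + np.concatenate(([0], full_cumsum[::rate])), total_tokens
```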
@@ -432,10 +432,14 @@ def _lazy_load(self):

     def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
         self._documents_per_epoch = data["dataset"]["documents_per_epoch"]
-        if unshuffled_tokens := data.get("unshuffled_tokens") is not None:
-            self._unshuffled_tokens = unshuffled_tokens
-        else:
-            self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
+
+        if "unshuffled_tokens" not in data:
+            # Backward compatibility
+            # TODO v0.x: Remove
+            assert self._truncate_documents
+            data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"]
+
+        self._unshuffled_tokens = data["unshuffled_tokens"]
         self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch

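As a standalone illustration of the fallback added here (key layout assumed from the diff; not the actual method): caches written before this change carry no `unshuffled_tokens` entry, which was only well-defined when documents were truncated.

```python
def resolve_unshuffled_tokens(data: dict) -> int:
    """Sketch of the backward-compatibility fallback for old sampling caches."""
    if "unshuffled_tokens" not in data:
        # Legacy cache: reconstruct the count from epochs (only valid when documents were truncated).
        return data["tokens_per_epoch"] * data["unshuffled_epochs"]
    return data["unshuffled_tokens"]


# Example: a legacy cache vs. a new, padding-aware cache (values are made up).
legacy_cache = {"tokens_per_epoch": 50_000, "unshuffled_epochs": 2}
new_cache = {**legacy_cache, "unshuffled_tokens": 80_000}
assert resolve_unshuffled_tokens(legacy_cache) == 100_000
assert resolve_unshuffled_tokens(new_cache) == 80_000
```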