
Commit 8ccf58d

Fix sampling, add timeouts for test subprocess and data loaders (#221)
1 parent 58b6f8a commit 8ccf58d

File tree

5 files changed, +40 −33 lines changed


fast_llm/data/data/abstract.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
 
 from fast_llm.config import Configurable
 from fast_llm.data.data.config import DataConfig
-from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
+from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.engine.schedule.config import BatchConfig
 
 if typing.TYPE_CHECKING:

@@ -45,5 +45,6 @@ def get_iterator(
         consumed_samples: int,
         num_workers: int,
         prefetch_factor: int | None = None,
+        timeout: float = 60,
     ) -> typing.Iterator[typing.Any]:
         pass

fast_llm/data/data/gpt/data.py

Lines changed: 1 addition & 0 deletions
@@ -143,6 +143,7 @@ def get_iterator(
         consumed_samples: int,
         num_workers: int,
         prefetch_factor: int | None = None,
+        timeout: float = 60,
     ) -> typing.Iterator[typing.Any]:
         assert self._is_setup
 
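The new `timeout: float = 60` parameter threads a per-iterator timeout through the data API; the hunks above only show the signature change, not where the value is consumed. A common destination for such a value is the `timeout` argument of `torch.utils.data.DataLoader`, which bounds how long the main process waits for a batch from a worker before raising instead of hanging indefinitely. A minimal sketch under that assumption (`make_iterator` is a hypothetical helper, not fast_llm code):

# Minimal sketch (not the fast_llm implementation): wiring a per-iterator
# timeout into torch.utils.data.DataLoader. The `timeout` argument bounds how
# long the main process waits for a batch from a worker before raising,
# instead of hanging forever on a stuck worker.
import typing

import torch.utils.data


def make_iterator(
    dataset: torch.utils.data.Dataset,
    batch_size: int,
    num_workers: int,
    prefetch_factor: int | None = None,
    timeout: float = 60,
) -> typing.Iterator[typing.Any]:
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # prefetch_factor and timeout are only meaningful with worker processes.
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        timeout=timeout if num_workers > 0 else 0,
    )
    return iter(loader)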
fast_llm/data/dataset/gpt/sampled.py

Lines changed: 35 additions & 31 deletions
@@ -65,7 +65,7 @@ def __getitem__(self, item: typing.Any) -> np.ndarray:
 
     def _lazy_load(self):
         if self._array is None:
-            assert self.exists()
+            assert self.exists(), self._path
             self._array = np.load(self._path, mmap_mode="r")
 
 
@@ -178,25 +178,26 @@ def _sample(self) -> None:
             "truncate_documents": self._truncate_documents,
             "config": self._config.to_serialized(),
         }
-        self._load_yaml_data(yaml_data)
+        if self._truncate_documents:
+            yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs
 
-        if self._yaml_path is not None:
-            if self._yaml_path.is_file():
-                loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
-                unshuffled_tokens = loaded_yaml_data.pop("unshuffled_tokens", None)
-                if unshuffled_tokens is not None:
-                    self._unshuffled_tokens = unshuffled_tokens
-                if loaded_yaml_data != yaml_data:
-                    raise RuntimeError(
-                        f"Invalid dataset cache for dataset {self.name}."
-                        " If this is due to an intended configuration change,"
-                        " please delete the cache before continuing."
-                        f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}"
-                        f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}"
-                    )
-                # Dataset is already sampled, skip.
-                logger.info(f"Using existing sampling for dataset {self.name}")
-                return
+        if self._yaml_path is not None and self._yaml_path.is_file():
+            loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
+            self._load_yaml_data(yaml_data)
+            if not self._truncate_documents:
+                del loaded_yaml_data["unshuffled_tokens"]
+
+            if loaded_yaml_data != yaml_data:
+                raise RuntimeError(
+                    f"Invalid dataset cache for dataset {self.name}."
+                    " If this is due to an intended configuration change,"
+                    " please delete the cache before continuing."
+                    f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}"
+                    f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}"
+                )
+            # Dataset is already sampled, skip.
+            logger.info(f"Using existing sampling for dataset {self.name}")
+            return
 
         if shuffled_documents > 1e8:
             warnings.warn(
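With this rework, `unshuffled_tokens` is written into `yaml_data` up front when documents are truncated and deleted from the cached copy otherwise, so cache validation reduces to a plain dict comparison between the current sampling parameters and what was written last time. A standalone sketch of that round-trip pattern (hypothetical `check_cache` helper and toy parameters, not the real `yaml_data` layout):

# Illustrative only: stale-cache detection by round-tripping the sampling
# parameters through YAML and comparing dicts, as in _sample() above.
import pathlib

import yaml


def check_cache(cache_path: pathlib.Path, current: dict) -> bool:
    """Return True if a valid cache exists, raise if it is stale."""
    if not cache_path.is_file():
        return False
    cached = yaml.safe_load(cache_path.open("r"))
    if cached != current:
        raise RuntimeError(
            "Invalid dataset cache."
            f"\nCurrent config:\n{yaml.safe_dump(current)}"
            f"\nCached config:\n{yaml.safe_dump(cached)}"
        )
    return True

# A changed parameter (e.g. a different sequence length) makes the dicts
# differ, so the stale cache is rejected instead of being reused silently.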
@@ -255,33 +256,32 @@ def _sample(self) -> None:
         # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation.
         # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))`
         if unshuffled_epochs > 0:
-            token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum(
+            token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum(
                 document_sizes,
                 offset=0,
                 # TODO: Allowing for max 100% extra tokens for padding, is that enough?
                 dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
             )
-            if self._truncate_documents:
-                num_tokens_unshuffled = tokens_per_epoch * unshuffled_epochs
             self._token_cumsum_unshuffled.save(token_cumsum_unshuffled)
         else:
-            num_tokens_unshuffled = 0
-        self._unshuffled_tokens = num_tokens_unshuffled
+            unshuffled_tokens = 0
 
+        if not self._truncate_documents:
+            yaml_data["unshuffled_tokens"] = unshuffled_tokens
+        self._load_yaml_data(yaml_data)
         if self._yaml_path is not None:
-            yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
             self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
             yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
 
         if shuffled_epochs > 0:
-            token_cumsum_shuffled, num_tokens_shuffled = self._get_token_cumsum(
+            token_cumsum_shuffled, _ = self._get_token_cumsum(
                 document_sizes[
                     # Torch indexing only works with int32 or int64
                     document_shuffling.to(
                         dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32
                     )
                 ],
-                offset=num_tokens_unshuffled,
+                offset=self._unshuffled_tokens,
                 # TODO: Allowing for max 100% extra tokens for padding, is that enough?
                 dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
             )
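The rename from `num_tokens_unshuffled` to `unshuffled_tokens` and the switch to `offset=self._unshuffled_tokens` make the shuffled cumulative sums start exactly where the unshuffled tokens end, using the value recorded via `_load_yaml_data` rather than a locally recomputed one. A toy numpy illustration of that offset-cumsum idea (not the fast_llm `_get_token_cumsum` implementation, which also subsamples by `TOKEN_CUMSUM_RATE`):

# Toy illustration: the shuffled documents' cumulative token counts are offset
# by the number of unshuffled tokens so that token indices form one contiguous
# range across both parts.
import numpy as np

document_sizes = np.array([5, 3, 7, 2])           # tokens per document
unshuffled_tokens = int(document_sizes.sum())     # unshuffled epoch(s) come first

rng = np.random.default_rng(0)
shuffling = rng.permutation(len(document_sizes))  # shuffled epoch order

# Cumulative token count at each shuffled document boundary, starting where
# the unshuffled tokens end (cf. `offset=self._unshuffled_tokens` above).
shuffled_cumsum = unshuffled_tokens + np.concatenate(
    ([0], document_sizes[shuffling].cumsum())
)
print(shuffled_cumsum)  # starts at 17 (the offset) and ends at 34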
@@ -432,10 +432,14 @@ def _lazy_load(self):
 
     def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
         self._documents_per_epoch = data["dataset"]["documents_per_epoch"]
-        if unshuffled_tokens := data.get("unshuffled_tokens") is not None:
-            self._unshuffled_tokens = unshuffled_tokens
-        else:
-            self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
+
+        if "unshuffled_tokens" not in data:
+            # Backward compatibility
+            # TODO v0.x: Remove
+            assert self._truncate_documents
+            data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"]
+
+        self._unshuffled_tokens = data["unshuffled_tokens"]
         self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch
 
 
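The `_load_yaml_data` change also keeps caches written before this commit usable: if the YAML lacks `unshuffled_tokens`, the value is reconstructed from the epoch counts instead of failing. A toy sketch of that migration pattern (field layout simplified, values invented):

# Illustrative only: back-filling a key that older cache files do not have,
# mirroring the backward-compatibility branch in _load_yaml_data above.
old_cache = {
    "unshuffled_epochs": 2,
    "tokens_per_epoch": 250_000,
    # "unshuffled_tokens" missing: cache written by an older version
}

if "unshuffled_tokens" not in old_cache:
    old_cache["unshuffled_tokens"] = old_cache["tokens_per_epoch"] * old_cache["unshuffled_epochs"]

assert old_cache["unshuffled_tokens"] == 500_000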
fast_llm/engine/training/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -416,6 +416,7 @@ def _get_data_iterator(
             consumed_samples=completed_steps * self._config.batch.batch_size,
             num_workers=self._config.training.num_workers,
             prefetch_factor=prefetch_factor,
+            timeout=self._config.training.timeout,
         )
 
     def _prepare_training_state(self) -> None:

tests/common.py

Lines changed: 1 addition & 1 deletion
@@ -394,7 +394,7 @@ def run_test_script(
     if num_gpus == 1 and not is_megatron:
         CliTrainingConfig.parse_and_run(script)
     else:
-        completed_proc = subprocess.run(command, env=env)
+        completed_proc = subprocess.run(command, env=env, timeout=60)
         if completed_proc.returncode:
             raise RuntimeError(f"Process failed with return code {completed_proc.returncode}")
     if compare:
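`timeout=60` turns a hung multi-GPU test into a fast failure: if the child process has not exited within 60 seconds, `subprocess.run` kills it and raises `subprocess.TimeoutExpired` instead of blocking the suite indefinitely. A minimal standalone illustration (the `sleep` command stands in for a stuck test process):

# Standalone illustration of the timeout= behaviour: the child is killed once
# the deadline passes and TimeoutExpired is raised in the parent.
import subprocess

try:
    subprocess.run(["sleep", "5"], timeout=1)
except subprocess.TimeoutExpired as error:
    print(f"Timed out after {error.timeout}s: {error.cmd}")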
