Skip to content

Fix sampling, add timeouts for test subprocess and data loaders #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fast_llm/data/data/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from fast_llm.config import Configurable
from fast_llm.data.data.config import DataConfig
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.schedule.config import BatchConfig

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -45,5 +45,6 @@ def get_iterator(
consumed_samples: int,
num_workers: int,
prefetch_factor: int | None = None,
timeout: float = 60,
) -> typing.Iterator[typing.Any]:
pass
1 change: 1 addition & 0 deletions fast_llm/data/data/gpt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def get_iterator(
consumed_samples: int,
num_workers: int,
prefetch_factor: int | None = None,
timeout: float = 60,
) -> typing.Iterator[typing.Any]:
assert self._is_setup

Expand Down
66 changes: 35 additions & 31 deletions fast_llm/data/dataset/gpt/sampled.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __getitem__(self, item: typing.Any) -> np.ndarray:

def _lazy_load(self):
if self._array is None:
assert self.exists()
assert self.exists(), self._path
self._array = np.load(self._path, mmap_mode="r")


Expand Down Expand Up @@ -178,25 +178,26 @@ def _sample(self) -> None:
"truncate_documents": self._truncate_documents,
"config": self._config.to_serialized(),
}
self._load_yaml_data(yaml_data)
if self._truncate_documents:
yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs

if self._yaml_path is not None:
if self._yaml_path.is_file():
loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
unshuffled_tokens = loaded_yaml_data.pop("unshuffled_tokens", None)
if unshuffled_tokens is not None:
self._unshuffled_tokens = unshuffled_tokens
if loaded_yaml_data != yaml_data:
raise RuntimeError(
f"Invalid dataset cache for dataset {self.name}."
" If this is due to an intended configuration change,"
" please delete the cache before continuing."
f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}"
f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}"
)
# Dataset is already sampled, skip.
logger.info(f"Using existing sampling for dataset {self.name}")
return
if self._yaml_path is not None and self._yaml_path.is_file():
loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
self._load_yaml_data(yaml_data)
if not self._truncate_documents:
del loaded_yaml_data["unshuffled_tokens"]

if loaded_yaml_data != yaml_data:
raise RuntimeError(
f"Invalid dataset cache for dataset {self.name}."
" If this is due to an intended configuration change,"
" please delete the cache before continuing."
f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}"
f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}"
)
# Dataset is already sampled, skip.
logger.info(f"Using existing sampling for dataset {self.name}")
return

if shuffled_documents > 1e8:
warnings.warn(
Expand Down Expand Up @@ -255,33 +256,32 @@ def _sample(self) -> None:
# Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation.
# Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))`
if unshuffled_epochs > 0:
token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum(
token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum(
document_sizes,
offset=0,
# TODO: Allowing for max 100% extra tokens for padding, is that enough?
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
)
if self._truncate_documents:
num_tokens_unshuffled = tokens_per_epoch * unshuffled_epochs
self._token_cumsum_unshuffled.save(token_cumsum_unshuffled)
else:
num_tokens_unshuffled = 0
self._unshuffled_tokens = num_tokens_unshuffled
unshuffled_tokens = 0

if not self._truncate_documents:
yaml_data["unshuffled_tokens"] = unshuffled_tokens
self._load_yaml_data(yaml_data)
if self._yaml_path is not None:
yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
yaml.safe_dump(yaml_data, self._yaml_path.open("w"))

if shuffled_epochs > 0:
token_cumsum_shuffled, num_tokens_shuffled = self._get_token_cumsum(
token_cumsum_shuffled, _ = self._get_token_cumsum(
document_sizes[
# Torch indexing only works with int32 or int64
document_shuffling.to(
dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32
)
],
offset=num_tokens_unshuffled,
offset=self._unshuffled_tokens,
# TODO: Allowing for max 100% extra tokens for padding, is that enough?
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
)
Expand Down Expand Up @@ -432,10 +432,14 @@ def _lazy_load(self):

def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
self._documents_per_epoch = data["dataset"]["documents_per_epoch"]
if unshuffled_tokens := data.get("unshuffled_tokens") is not None:
self._unshuffled_tokens = unshuffled_tokens
else:
self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]

if "unshuffled_tokens" not in data:
# Backward compatibility
# TODO v0.x: Remove
assert self._truncate_documents
data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"]

self._unshuffled_tokens = data["unshuffled_tokens"]
self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch


Expand Down
1 change: 1 addition & 0 deletions fast_llm/engine/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def _get_data_iterator(
consumed_samples=completed_steps * self._config.batch.batch_size,
num_workers=self._config.training.num_workers,
prefetch_factor=prefetch_factor,
timeout=self._config.training.timeout,
)

def _prepare_training_state(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def run_test_script(
if num_gpus == 1 and not is_megatron:
CliTrainingConfig.parse_and_run(script)
else:
completed_proc = subprocess.run(command, env=env)
completed_proc = subprocess.run(command, env=env, timeout=60)
if completed_proc.returncode:
raise RuntimeError(f"Process failed with return code {completed_proc.returncode}")
if compare:
Expand Down