Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix for tp >= 2 and pp > 2 in autoitercount #1296

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,24 +183,35 @@ def update_iterations(neox_args, data_loaders):
to do as many iterations as possible while ensuring that each example is seen *at most* train_epochs
times.
"""
if neox_args.train_iters is not None:
if (not neox_args.do_train) or (neox_args.train_iters is not None):
pass
elif neox_args.train_iters is None and neox_args.train_epochs is None:
print_rank_0(
"ERROR:Failed to specify either train_epochs or train_iters in config file"
)
else:
train_dataloader = data_loaders["train"]
train_epochs = neox_args.train_epochs
gradient_accumulation_steps = neox_args.gradient_accumulation_steps
global_rank = torch.distributed.get_rank()

train_iterations = (
len(train_dataloader) * train_epochs
) // gradient_accumulation_steps
if global_rank == 0:
train_dataloader = data_loaders["train"]
train_epochs = neox_args.train_epochs
gradient_accumulation_steps = neox_args.gradient_accumulation_steps

train_dataloader_len = len(train_dataloader)
train_iterations = (
train_dataloader_len * train_epochs
) // gradient_accumulation_steps

train_iters_tensor = torch.cuda.LongTensor([train_iterations])
else:
train_iters_tensor = torch.cuda.LongTensor([0])

torch.distributed.broadcast(train_iters_tensor, src=0)

neox_args.train_iters = train_iters_tensor[0].item()

neox_args.train_iters = train_iterations
print_rank_0(
f"Training for a total of {train_iterations} iterations, corresponding to {train_epochs} epochs."
f"Training for a total of {neox_args.train_iters} iterations, corresponding to {neox_args.train_epochs} epochs."
)


Expand Down
Loading