Trainer batch size auto scaling #14200
I am very nervous about adding that kind of auto-scaling feature to the Trainer. Note that a CUDA out-of-memory error does not always leave the process in a recoverable state: in a notebook, the kernel is in an unrecoverable state after the OOM error. So for now, my sense is that such a feature would be more painful for the user than beneficial, and I would leave the tuning of the batch size to the user.
Thanks very much for the feedback!
Perhaps instead of shrinking batch_size, it could work in the other direction. If gradient_accumulation_steps is > 1, the first few steps could monitor the memory footprint and combine steps when the system sees there is enough capacity. Is that similarly dangerous? Again, a very rough PoC just to illustrate the behavior:

```python
import torch
import transformers


class BatchAutoScaleTrainer(transformers.Trainer):
    def __init__(self, *args, **kwds):
        self._should_study = True
        super().__init__(*args, **kwds)
        # Start with one example per device and try to grow from there.
        self.mini_bs = max(1, self.args.n_gpu)

    def _minibatch(self, bs, mbs, batch):
        # Split a batch of size `bs` into micro-batches of size `mbs`.
        kl = batch.keys()
        return (dict(zip(kl, (v[i:i + mbs] for v in batch.values())))
                for i in range(0, bs, mbs))

    def _ministudy(self, bs, mbs):
        # Estimate whether the next larger divisor of `bs` would still fit
        # in GPU memory; if so, adopt it and keep studying.
        if mbs < bs:
            for i in range(mbs + 1, bs + 1):
                if bs % i == 0:
                    nbs = i
                    break
            d = self.args.device
            est = torch.cuda.memory_reserved(d) / mbs * nbs
            if est < torch.cuda.get_device_properties(d).total_memory:
                self.mini_bs = nbs
                return
        transformers.trainer.logger.info(
            f'{__class__.__name__}: mini_bs={self.mini_bs}')
        self._should_study = False

    def training_step(self, model, inputs):
        bs = len(next(iter(inputs.values())))
        mbs = min(bs, self.mini_bs)
        ts = super().training_step
        loss = torch.stack(tuple(
            ts(model, batch)
            for batch in self._minibatch(bs, mbs, inputs)
        )).mean()
        if self._should_study:
            self._ministudy(bs, mbs)
        return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        bs = len(next(iter(inputs.values())))
        mbs = self.mini_bs
        ps = super().prediction_step
        loss, logits, labels = zip(*(
            ps(model, batch, prediction_loss_only, ignore_keys)
            for batch in self._minibatch(bs, mbs, inputs)
        ))
        if prediction_loss_only:
            return (torch.stack(loss).mean(), None, None)
        return (torch.stack(loss).mean(), torch.cat(logits), torch.cat(labels))
```
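For reference, a hypothetical usage sketch, assuming `model`, `training_args`, and the datasets are set up as for a regular `transformers.Trainer` (the names below are placeholders):

```python
# Drop-in replacement for transformers.Trainer in an otherwise standard setup.
trainer = BatchAutoScaleTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```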
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Maybe of interest to you @tlby https://github.com/rentruewang/koila |
@LysandreJik Indeed, thanks for the note. rentruewang/koila#12 is a hopeful sign. |
🚀 Feature request
Since Trainer handles both batch_size and gradient_accumulation_steps, it seems like it could detect some out-of-memory situations and handle those scenarios automatically.
Motivation
I've been experimenting with model search (model_type, vocab_size, num_hidden_layers, hidden_size), and it's been somewhat difficult to manage the correct batch size for each variant. To avoid trial and error and maintaining configuration tables, what I've been doing is detecting memory exhaustion and adapting the training arguments on the fly. It's imperfect, but I wonder if there's an official way to achieve this kind of behavior.
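For illustration, a minimal sketch of that catch-and-retry idea; the helper name and the halving policy here are illustrative assumptions, not the actual PoC, and it only relies on standard PyTorch behavior (CUDA OOM surfacing as a RuntimeError):

```python
import torch

def run_step_with_backoff(trainer_step, model, inputs, min_batch=1):
    """Retry a step with a smaller batch when CUDA runs out of memory.

    `trainer_step` is any callable taking (model, inputs); `inputs` is a dict
    of tensors whose first dimension is the batch dimension.
    """
    batch = inputs
    while True:
        try:
            return trainer_step(model, batch)
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise
            bs = len(next(iter(batch.values())))
            if bs <= min_batch:
                raise
            torch.cuda.empty_cache()  # release cached blocks before retrying
            # Keep only the first half of the batch; a real implementation
            # would split and accumulate instead of dropping examples.
            batch = {k: v[: bs // 2] for k, v in batch.items()}
```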
Your contribution
This is just a PoC; I'm sure there are several environments where this might be problematic. In particular, CPU training on Linux is quite likely to trigger the OOM killer, in which case the entire process is simply wiped from memory. Nevertheless, this strategy seems helpful at least some of the time.
Any chance something like this might be integrated with the Trainer?