Trainer batch size auto scaling #14200

Closed

tlby opened this issue Oct 28, 2021 · 7 comments

@tlby
Contributor

tlby commented Oct 28, 2021

🚀 Feature request

Since Trainer handles both batch_size and gradient_accumulation_steps, it seems like it could detect some out-of-memory situations and recover from them automatically.
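
To make the trade-off concrete, a hypothetical example (the numbers are made up): the effective batch size is per_device_train_batch_size * gradient_accumulation_steps, so a run that OOMs at 32 x 1 could be retried at 16 x 2 or 8 x 4 without changing the optimization.

# Hypothetical rescalings that keep the effective batch size constant.
effective = 32 * 1  # per_device_train_batch_size * gradient_accumulation_steps
for tbs, gas in [(32, 1), (16, 2), (8, 4), (4, 8)]:
    assert tbs * gas == effective  # same gradient per optimizer step, smaller peak memory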

Motivation

I've been experimenting with model search (model_type, vocab_size, num_hidden_layers, hidden_size), and it has been somewhat difficult to manage the correct batch size for each variant. To avoid trial & error and maintaining configuration tables, I've been detecting memory exhaustion and adapting the training arguments on the fly. It's imperfect, but I wonder if there's an official way to achieve this kind of behavior.

Your contribution

This is just a PoC; I'm sure there are environments where it would be problematic. In particular, CPU training on Linux is quite likely to trigger the OOM killer, in which case the entire process is simply wiped from memory. Nevertheless, this strategy seems helpful at least some of the time.

import gc

import torch
import transformers


class BatchAutoScaleTrainer(transformers.Trainer):
    '''Try to detect application crashes due to CUDA/CPU OOMs and
       rescale the batch size.  An antiprime batch_size gives best results.
       Inspired by PyTorchLightning/pytorch-lightning#1638
    '''
    def _shrink_bs(self):
        # gradient_accumulation_steps is used by both .train() and
        # .evaluate(), so we need a setting that suits both.
        tbs = self.args.per_device_train_batch_size
        ebs = self.args.per_device_eval_batch_size
        gas = self.args.gradient_accumulation_steps
        # Find the next accumulation factor that divides both batch sizes
        # and rescale so the effective batch size (bs * gas) is unchanged.
        for i in range(gas + 1, min(tbs, ebs) + 1):
            if tbs % i or ebs % i:
                continue
            self.args.per_device_train_batch_size = (tbs * gas) // i
            self.args.per_device_eval_batch_size = (ebs * gas) // i
            self.args.gradient_accumulation_steps = i
            return True
        return False  # could not shrink any further

    def _is_oom(self, err):
        # shamelessly stolen from https://github.com/PyTorchLightning/pytorch-lightning/pull/1638/files#diff-5200c11792b86d6a07ea64820e126897aa2e3b7d3d295c92c19b141de6950afeR29-R32
        return len(err.args) == 1 and (
            "CUDA out of memory." in err.args[0]
            or "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in err.args[0]
            or "DefaultCPUAllocator: can't allocate memory" in err.args[0]
            or "CUDA error: CUBLAS_STATUS_ALLOC_FAILED " in err.args[0]
        )

    def _auto_scale_batch_size(self, code):
        # Keep retrying with a smaller batch size until the call succeeds
        # or there is no smaller batch size left to try.
        while True:
            try:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                return code()
            except RuntimeError as err:
                if self._is_oom(err) and self._shrink_bs():
                    continue
                raise

    def train(self, *args, **kwds):
        train = super().train
        return self._auto_scale_batch_size(
            lambda: train(*args, **kwds))

    def evaluate(self, *args, **kwds):
        evaluate = super().evaluate
        return self._auto_scale_batch_size(
            lambda: evaluate(*args, **kwds))

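For reference, a minimal drop-in usage sketch; the model, datasets, and TrainingArguments values here are placeholders, not part of the proposal:

# Hypothetical usage; model, train_dataset and eval_dataset are placeholders.
args = transformers.TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=32,  # antiprime, so it can rescale to 16x2, 8x4, ...
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
)
trainer = BatchAutoScaleTrainer(
    model=model,                     # any PreTrainedModel
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()                      # retries with a smaller batch size whenever an OOM is caught
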
Any chance something like this might be integrated with the Trainer?

@sgugger
Collaborator

sgugger commented Oct 28, 2021

I am very nervous about adding that kind of auto-scaling feature to the Trainer. Note that the _is_oom test, for instance, will catch far more CUDA errors than just OOMs: having the wrong number of labels in your model will trigger an error with CUBLAS_STATUS_ALLOC_FAILED on most environments.
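
For what it's worth, a narrower check is conceivable; a sketch, assuming a PyTorch version that exposes torch.cuda.OutOfMemoryError (it would still miss the CPU allocator case):

# Sketch of a narrower check, assuming a PyTorch build that exposes
# torch.cuda.OutOfMemoryError (a RuntimeError subclass).
import torch

def is_cuda_oom(err):
    # Catches only genuine CUDA OOMs; CPU allocator failures would still
    # need a string match as in the PoC above.
    return isinstance(err, torch.cuda.OutOfMemoryError)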

In a notebook, the kernel is in an unrecoverable state after the try/except (and torch.cuda.empty_cache() does not help), so this wouldn't work either.

So for now, my sense is that such a feature would cause users more pain than benefit, and I would leave the tuning of the batch size to them.

@tlby
Contributor Author

tlby commented Oct 28, 2021

Thanks very much for the feedback!

@tlby
Contributor Author

tlby commented Nov 2, 2021

Perhaps instead of shrinking batch_size, it could work in the other direction: if gradient_accumulation_steps is > 1, the first few steps could monitor the memory footprint and combine accumulation steps once the system sees there is enough capacity. Is that similarly dangerous?

Again, a very rough PoC just to illustrate the behavior:

import torch
import transformers


class BatchAutoScaleTrainer(transformers.Trainer):
    def __init__(self, *args, **kwds):
        self._should_study = True
        super().__init__(*args, **kwds)
        # Start from the smallest useful micro-batch (one sample per device).
        self.mini_bs = max(1, self.args.n_gpu)

    def _minibatch(self, bs, mbs, batch):
        # Split a batch dict of tensors into micro-batches of size mbs.
        kl = batch.keys()
        return (dict(zip(kl,
            (_[i:i + mbs] for _ in batch.values()),
        )) for i in range(0, bs, mbs))

    def _ministudy(self, bs, mbs):
        if mbs < bs:
            # Next micro-batch size that evenly divides the full batch.
            for i in range(mbs + 1, bs + 1):
                if bs % i == 0:
                    nbs = i
                    break
            # Extrapolate current memory use to the larger micro-batch and
            # keep growing while the estimate still fits on the device.
            d = torch.cuda.current_device()
            est = torch.cuda.memory_reserved(d) / mbs * nbs
            if est < torch.cuda.get_device_properties(d).total_memory:
                self.mini_bs = nbs
                return
        transformers.trainer.logger.info(
            f'{__class__.__name__}: mini_bs={self.mini_bs}')
        self._should_study = False

    def training_step(self, model, inputs):
        bs = len(next(iter(inputs.values())))
        mbs = min(bs, self.mini_bs)
        ts = super().training_step
        # Run the step on each micro-batch and average the losses.
        loss = torch.stack(tuple(
            ts(model, batch)
            for batch in self._minibatch(bs, mbs, inputs)
        )).mean()
        if self._should_study:
            self._ministudy(bs, mbs)
        return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys):
        bs = len(next(iter(inputs.values())))
        mbs = min(bs, self.mini_bs)
        ps = super().prediction_step
        loss, logits, labels = zip(*(
            ps(model, batch, prediction_loss_only, ignore_keys)
            for batch in self._minibatch(bs, mbs, inputs)
        ))
        if prediction_loss_only:
            return (torch.stack(loss).mean(), None, None)
        return (torch.stack(loss).mean(), torch.cat(logits), torch.cat(labels))
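
To make the growth criterion concrete, a standalone sketch of the estimate used in _ministudy above, with hypothetical numbers:

# Hypothetical numbers illustrating the growth test in _ministudy above.
reserved = 3 * 2**30        # bytes currently reserved at micro-batch size 4
total = 16 * 2**30          # total device memory from get_device_properties
mbs, nbs = 4, 8             # current and next candidate micro-batch size
est = reserved / mbs * nbs  # linear extrapolation: ~6 GiB expected at size 8
assert est < total          # so _ministudy would grow mini_bs from 4 to 8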

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed Dec 6, 2021
@LysandreJik
Member

Maybe of interest to you @tlby https://github.com/rentruewang/koila

@tlby
Contributor Author

tlby commented Dec 8, 2021

@LysandreJik Indeed, thanks for the note. rentruewang/koila#12 is a hopeful sign.

@daniel-bogdoll
Contributor

https://pytorch-lightning.readthedocs.io/en/1.1.1/training_tricks.html#auto-scaling-of-batch-size might also be interesting.
