
Getting jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED error while performing inference on gemma-2b on TPU #85

Open
adityarajsahu opened this issue Feb 7, 2025 · 6 comments

Comments

@adityarajsahu

I ran the whole Colab script on my TPU server: https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/sampling_tutorial.ipynb#scrollTo=tqbJ1SUcESaN

The script ran properly the first time, but when I ran it again some time later, I got the following error:

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 15.65G. That was not possible. There are 11.08G free.; (0x0x0_HBM0)

Can anyone please tell me the exact cause of this error and how to fix it?

@Gopi-Uppari

Hi @adityarajsahu,

I reproduced this issue. The crash on the second run is likely caused by residual memory from the previous run not being released. This is a common issue with large models and memory-intensive processes in environments like Google Colab.
To resolve it, batch your inputs to reduce memory usage: if the input batch is too large, split it into smaller batches and process them sequentially, as sketched below. Please refer to this gist, where you will find the modified code.
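A minimal sketch of that pattern, assuming a `sampler` callable like the one built in the sampling tutorial (`sample_in_batches` and the `batch_size` / `total_generation_steps` values here are illustrative, not taken from the gist):

```python
def sample_in_batches(sampler, prompts, batch_size=2, total_generation_steps=100):
    """Run the sampler over `prompts` a few at a time to cap peak HBM usage."""
    outputs = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i + batch_size]
        # Each call only materializes activations for `batch_size` prompts.
        out = sampler(input_strings=batch,
                      total_generation_steps=total_generation_steps)
        outputs.extend(out.text)
    return outputs
```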

Thank you.

@adityarajsahu
Author

Hi @Gopi-Uppari,
Thanks for the help; I will try the modified code. Apart from that, is there any way to ensure all memory is released after the process terminates?

@Gopi-Uppari

Hi @adityarajsahu,

To ensure all memory is released after the process terminates, you can restart the runtime/session in Colab.

Thank you.

@adityarajsahu
Author

Hi @Gopi-Uppari,

I am working on a GCP TPU VM, not Colab.

@Conchylicultor
Collaborator

You can call x.delete() on an array to explicitly release its memory on the TPU. Use it inside jax.tree.map to release a whole tree of arrays.

Otherwise, JAX will release the memory once there are no remaining Python references to the array.
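For example, a minimal sketch of both approaches (the `x` and `params` values here are placeholders for whatever device arrays you are holding):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))        # placeholder device array
params = {"w": jnp.ones((8, 8))}  # placeholder pytree of arrays

# Explicitly free a single array's device buffer; using `x` afterwards raises.
x.delete()

# Free every leaf of a pytree (e.g. model params) in one call.
jax.tree.map(lambda a: a.delete(), params)

# Otherwise, once no Python references remain, JAX frees the memory itself.
del params
```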

@Gopi-Uppari

Hi @adityarajsahu,

Could you please confirm whether this issue is resolved for you by the above comments? Please feel free to close the issue if it is.

Thank you.
