🚀 Feature
Instead of using PyTorch autograd and checkpointing, we'll investigate using jax.grad, jax.vjp, jax.remat, etc. to control the rematerialization of a PyTorch model.
Motivation
JAX remat gives finer-grained control over rematerialization than PyTorch autograd checkpointing. For example, in JAX we can name individual tensors and then selectively save or offload them by name. PyTorch does not support naming a tensor in this way.
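To make the motivation concrete, here is a minimal pure-JAX sketch of naming an intermediate tensor and saving only that tensor during rematerialization. It is not taken from the linked notebook or from torchax. The function `mlp`, the name `"hidden"`, and the shapes are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def mlp(w1, w2, x):
    # Hypothetical two-layer model, used only to illustrate the idea.
    h = jnp.tanh(x @ w1)
    # Tag the intermediate so a checkpoint policy can refer to it by name.
    h = checkpoint_name(h, "hidden")
    return jnp.sum(jnp.tanh(h @ w2))


# Keep only the residual named "hidden"; other intermediates are
# recomputed during the backward pass.
policy = jax.checkpoint_policies.save_only_these_names("hidden")
mlp_remat = jax.checkpoint(mlp, policy=policy)  # jax.remat is an alias

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
w1 = jax.random.normal(k1, (4, 8))
w2 = jax.random.normal(k2, (8, 2))
x = jax.random.normal(k3, (3, 4))

# The policy changes what is stored, not the gradient values.
grads_remat = jax.grad(mlp_remat, argnums=(0, 1))(w1, w2, x)
grads_plain = jax.grad(mlp, argnums=(0, 1))(w1, w2, x)
```

The gradients from `mlp_remat` and `mlp` should agree up to floating-point tolerance, since the policy only trades memory for recomputation. Recent JAX versions also provide an offloading variant of this policy in `jax.checkpoint_policies`; how either policy would map onto tensors of a PyTorch model is the open question this issue proposes to investigate.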
Pitch
Something like https://github.com/tengyifei/playground/blob/master/torch-jax-autograd.ipynb combined with #8781 and torchax.
cc @qihqi