
Use JAX autograd from PyTorch #8805

@tengyifei

Description


🚀 Feature

Instead of relying on PyTorch autograd and activation checkpointing, we will investigate using jax.grad, jax.vjp, jax.remat, etc. to control how the activations of a PyTorch model are rematerialized during the backward pass.
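To make the idea concrete, here is a minimal JAX-only sketch. It leaves out the PyTorch bridge, which is the part this issue proposes to investigate. `block`, `loss` and the shapes are hypothetical. The sketch shows how jax.vjp splits a rematerialized (jax.checkpoint, also available as jax.remat) computation into a forward call and a backward closure. A PyTorch custom autograd Function would presumably need to wrap exactly that split.

```python
import jax
import jax.numpy as jnp

# Hypothetical block standing in for one layer of a model.
def block(w, x):
    return jnp.tanh(x @ w)

# jax.checkpoint (alias: jax.remat) does not keep the block's intermediates
# from the forward pass; they are recomputed when the backward pass runs.
remat_block = jax.checkpoint(block)

def loss(w, x):
    return jnp.sum(remat_block(w, x) ** 2)

w = jnp.ones((4, 3)) * 0.1
x = jnp.arange(8.0).reshape(2, 4)

# Forward pass: jax.vjp returns the output plus a closure for the backward
# pass, analogous to the forward()/backward() split of torch.autograd.Function.
out, vjp_fn = jax.vjp(loss, w, x)

# Backward pass: feed the upstream gradient (1.0 for a scalar loss).
grad_w, grad_x = vjp_fn(jnp.ones_like(out))

# Reference: the same gradients computed directly with jax.grad.
ref_w, ref_x = jax.grad(loss, argnums=(0, 1))(w, x)
```

In a real integration, the backward closure (or the residuals it captures) would have to be stashed between PyTorch's forward and backward calls. How best to do that is an open design question here.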

Motivation

JAX remat gives finer-grained control than PyTorch autograd checkpointing. For example, in JAX we can name individual intermediate tensors and then selectively save or offload them by name. PyTorch has no equivalent way to name a tensor.
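A minimal sketch of the tensor-naming mechanism on the JAX side, assuming a hypothetical two-layer `mlp`. Intermediates are tagged with jax.ad_checkpoint.checkpoint_name, and a name-based policy from jax.checkpoint_policies decides which tagged tensors are saved for the backward pass. Offloading uses a different policy and is not shown.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Hypothetical two-layer model; intermediates are tagged with names
# that a remat policy can refer to.
def mlp(w1, w2, x):
    h = checkpoint_name(jnp.sin(x @ w1), "hidden")
    z = checkpoint_name(jnp.cos(h @ w2), "out")
    return jnp.sum(z)

# Keep only the tensor named "hidden" as a residual; everything else
# (including "out") is recomputed during the backward pass.
policy = jax.checkpoint_policies.save_only_these_names("hidden")
remat_mlp = jax.checkpoint(mlp, policy=policy)

w1 = jnp.full((3, 5), 0.2)
w2 = jnp.full((5, 2), 0.3)
x = jnp.arange(6.0).reshape(2, 3)

# The remat policy changes memory use, not the math: gradients match.
g_remat = jax.grad(remat_mlp, argnums=(0, 1))(w1, w2, x)
g_plain = jax.grad(mlp, argnums=(0, 1))(w1, w2, x)
```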

Pitch

Something like the notebook at https://github.com/tengyifei/playground/blob/master/torch-jax-autograd.ipynb, combined with #8781 and torchax.

cc @qihqi
