
How to set a function's vjp to be identical to another's? #28511

Closed Answered by dfm
ahuang314 asked this question in Q&A

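I think you might be better off defining a custom JVP for f1 rather than a custom VJP. JAX derives reverse mode (the VJP) from a custom JVP rule automatically, so jax.grad and jax.vjp of f1 will then use f2's derivative. Something like: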

```python
import jax

@jax.custom_jvp
def f1(*args):
  return ...

@f1.defjvp
def f1_jvp(primals, tangents):
  # Use f2 for both the primal output and the tangent output.
  return jax.jvp(f2, primals, tangents)
```
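A minimal runnable sketch of this first approach, assuming a hypothetical f2 = jnp.sin and a hypothetical f1 whose body wraps the same computation in jax.lax.stop_gradient, so that ordinary autodiff through f1 would see a zero derivative. With the custom JVP in place, jax.grad(f1) returns f2's derivative (cos) instead.

```python
import jax
import jax.numpy as jnp

# Hypothetical reference function whose derivative we want to reuse.
def f2(x):
    return jnp.sin(x)

# Hypothetical f1: same values as f2, but plain autodiff through this body
# would give a zero gradient because of stop_gradient.
@jax.custom_jvp
def f1(x):
    return jax.lax.stop_gradient(jnp.sin(x))

@f1.defjvp
def f1_jvp(primals, tangents):
    # Delegate both the primal output and the tangent output to f2.
    return jax.jvp(f2, primals, tangents)

x = 0.5
value = f1(x)             # evaluated with f1's own body
slope = jax.grad(f1)(x)   # reverse mode, derived from the custom JVP rule
print(value, slope)
```

Because jax.grad is built on the JVP rule here, no separate VJP definition is needed for reverse mode to match f2.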

Or, if you never want f2 to compute the primal output:

```python
@f1.defjvp
def f1_jvp(primals, tangents):
  # Take only the tangents from f2; compute the primal output with f1 itself.
  _, out_tangents = jax.jvp(f2, primals, tangents)
  return f1(*primals), out_tangents
```
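A minimal runnable sketch of this second approach, where f1 and f2 deliberately compute different values. As a swapped-in illustration (not from this thread), it assumes a hypothetical f1 = jnp.round and a hypothetical surrogate f2 = identity, which is the familiar "straight-through" pattern: the forward pass rounds, while the VJP behaves like that of the identity.

```python
import jax
import jax.numpy as jnp

# Hypothetical surrogate whose derivative we want to borrow: the identity.
def f2(x):
    return x

# Hypothetical forward function: rounding, whose own derivative is zero.
@jax.custom_jvp
def f1(x):
    return jnp.round(x)

@f1.defjvp
def f1_jvp(primals, tangents):
    # Use f2 only for the tangent and discard its primal output.
    _, out_tangents = jax.jvp(f2, primals, tangents)
    # Compute the primal output with f1 itself.
    return f1(*primals), out_tangents

x = jnp.array([0.2, 1.7, -2.6])
y, f1_vjp = jax.vjp(f1, x)           # y comes from f1 (rounded values)
(x_bar,) = f1_vjp(jnp.ones_like(y))  # cotangent comes from f2's derivative
print(y, x_bar)
```

Without the custom rule, the cotangent here would be all zeros, since JAX treats the derivative of rounding as zero.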

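The first version has exactly the same cost as f2 when differentiated, because both the primal and tangent outputs come from a single jax.jvp of f2 (this assumes f1 and f2 compute the same values). The second version may do some extra work, since it discards the primal output of f2's forward pass and recomputes it with f1.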

If this doesn't work, please feel free to post a minimal reproducer with the errors you're seeing!

Replies: 1 comment 1 reply


dfm
May 5, 2025
Collaborator

1 reply
@ahuang314

Answer selected by ahuang314
Category
Q&A
2 participants