Suppose I have two different implementations of the same function. Here is what I've tried:
However, this still seems to be slower than autodifferentiating through f2 directly. Additionally, I seem to run into an assertion error when z is complex, even after setting holomorphic=True in the jacobian call.
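For reference, here is a minimal sketch of the baseline the question compares against: a holomorphic Jacobian taken directly through a single differentiable implementation. The body of f2 (sin(z) * z) and the sample values of z are hypothetical stand-ins, since the question's own code isn't shown. Note that holomorphic=True expects complex-valued inputs and outputs.

```python
import jax
import jax.numpy as jnp


def f2(z):
    # Hypothetical stand-in for the question's f2: any complex-analytic function works.
    return jnp.sin(z) * z


# holomorphic=True expects complex inputs (and complex outputs).
z = jnp.array([1.0 + 2.0j, 0.5 - 1.0j])

# jax.jacobian is reverse mode (jacrev). Each output element depends only on
# the matching input element, so this Jacobian is diagonal.
J = jax.jacobian(f2, holomorphic=True)(z)

# The diagonal should match the analytic derivative cos(z) * z + sin(z).
print(jnp.diag(J))
print(jnp.cos(z) * z + jnp.sin(z))
```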
I think that you might be better off using something like:

    import jax

    @jax.custom_jvp
    def f1(*args):
        return ...

    @f1.defjvp
    def f1_jvp(primals, tangents):
        # Primal and tangent outputs both come from f2.
        return jax.jvp(f2, primals, tangents)

or, if you really don't ever want to use f2 for the primals:

    @f1.defjvp
    def f1_jvp(primals, tangents):
        # Keep f1's own primal output; take only the tangents from f2.
        _, out_tangents = jax.jvp(f2, primals, tangents)
        return f1(*primals), out_tangents

The first one will have exactly the same computational performance as f2 when differentiated, whereas the second one might perform some extra work, since the primal outputs of f2's forward pass are discarded and f1 is evaluated separately.

If this doesn't work, please feel free to post a minimal reproducer with the errors you're seeing!
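To make both patterns concrete, here is a small runnable sketch. The specific functions are hypothetical: f2 is plain jnp.sqrt, and newton_sqrt computes the same values (for positive inputs) with Newton's method inside lax.while_loop, which reverse-mode autodiff cannot differentiate through on its own. f1 uses the first rule (everything from f2), and f1_b uses the second rule (own primal, tangents from f2).

```python
import jax
import jax.numpy as jnp
from jax import lax


def f2(x):
    # Hypothetical differentiable implementation: plain jnp.sqrt.
    return jnp.sqrt(x)


def newton_sqrt(x):
    # Hypothetical alternative implementation of the same function (x > 0):
    # a fixed number of Newton steps inside lax.while_loop. Reverse-mode
    # autodiff does not support while_loop, so a custom JVP rule is needed.
    def cond(state):
        i, _ = state
        return i < 30

    def body(state):
        i, y = state
        return i + 1, 0.5 * (y + x / y)

    _, y = lax.while_loop(cond, body, (0, jnp.ones_like(x)))
    return y


# First pattern: f2 supplies both the primal and the tangent outputs.
f1 = jax.custom_jvp(newton_sqrt)


@f1.defjvp
def f1_jvp(primals, tangents):
    return jax.jvp(f2, primals, tangents)


# Second pattern: keep newton_sqrt's primal output, take only tangents from f2.
f1_b = jax.custom_jvp(newton_sqrt)


@f1_b.defjvp
def f1_b_jvp(primals, tangents):
    _, out_tangents = jax.jvp(f2, primals, tangents)
    return f1_b(*primals), out_tangents


x = jnp.array([1.0, 4.0, 9.0])
print(f1(x))                  # undifferentiated call runs newton_sqrt
print(jax.jacobian(f1)(x))    # reverse mode, routed through f2's JVP
print(jax.jacobian(f1_b)(x))
```

Both Jacobians should be diagonal with entries 1 / (2 * sqrt(x)). Calling jax.jacobian(newton_sqrt) directly would fail instead, because it needs reverse-mode differentiation through lax.while_loop.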