[WIP] made whole adaptive workflow jax/jit #400
Conversation
Merging in recent changes from upstream.
Merging jit changes into pymbar4.
Made adaptive for pymbar jittable.
import jax
from jax.scipy.special import logsumexp
from jax.ops import index_update, index
from jax.config import config; config.update("jax_enable_x64", True)
This will slow things down dramatically, especially on the GPU. Do you really need 64-bit precision?
Probably yes - the algorithm doesn't seem to converge well with 32-bit floats. But I can poke around and see if I can isolate the particular problems in convergence.
Still need to work on this . . .
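For reference, a quick standalone check of which precision the arrays actually carry when testing convergence; nothing here is pymbar code, and x64 is left at its default (disabled) for the comparison.

import numpy as np
import jax.numpy as jnp

u_kn = np.random.rand(5, 100)       # float64 on the NumPy side
print(jnp.asarray(u_kn).dtype)      # float32: JAX silently downcasts when jax_enable_x64 is off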
jit_core_adaptive = jax.jit(core_adaptive, static_argnums=(0, 1, 2, 3, 4, 5))
When you call jax.jit, the first invocation will be exceedingly slow because it's compiling the JIT kernels. So if you're benchmarking, you should call this more than once to amortize out the jit time.
Furthermore, if the arguments have shapes that contain (None,) in one of the dimensions, a recompilation is required for each specialization of (None,) into a known shape, due to an XLA requirement.
Ah, this explains why the partial jitification works better: it ends up calling the different compiled functions multiple times, though only 7-8 times in many cases.
I will test running mbar initialization several times with different data sets in the same script to see if that helps.
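For reference, a minimal sketch of how the compile cost can be amortized when benchmarking; the step function below is a placeholder, not the pymbar kernel.

import time
import jax
import jax.numpy as jnp

@jax.jit
def step(f_k, u_kn):
    # Placeholder body standing in for one adaptive update.
    return f_k - 0.01 * jnp.sum(u_kn, axis=1)

f_k, u_kn = jnp.zeros(10), jnp.ones((10, 1000))

step(f_k, u_kn).block_until_ready()          # first call pays the XLA compile cost

start = time.perf_counter()
for _ in range(100):
    f_k = step(f_k, u_kn)
f_k.block_until_ready()                      # flush async dispatch before reading the clock
print("per-call time:", (time.perf_counter() - start) / 100)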
I also noticed you're declaring every arg to be static, which means that unless you're passing in the exact same args, you'll be triggering a recompilation. Why did you have to declare them to be static to start with?
U_kn probably shouldn't be static if you're trying to pass in results from different runs?
I had to declare it static because otherwise, gnorm_sci and gnorm_nr end up as traced arrays instead of concrete arrays (since they are functions of traced arrays), and jit refuses to compile the conditional comparing the two, which decides the branch to take, unless they are declared static.
@proteneer had a good suggestion for handling the conditional:
import jax
import jax.numpy as jnp

# can't be JIT'd
def foo(a, b, cond_a, cond_b):
    if cond_a < cond_b:
        return a+b
    else:
        return a-b

# foo_jit = jax.jit(foo) # fails
print(foo(0., 1., 2., 3.))

# can be JIT'd
def bar(a, b, cond_a, cond_b):
    return jnp.where(cond_a < cond_b, a+b, a-b)

bar_jit = jax.jit(bar)
print(bar_jit(0., 1., 2., 3.))
So, if we accelerate just the inner loop, then we can pull the conditional out of the accelerated code with little loss of speed: all of the conditionals just compare floats and assign values, so there is little to gain from accelerating them.
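A minimal sketch of that split, with hypothetical placeholders standing in for the real self-consistent and Newton-Raphson updates and for the gradient norm (none of this is the actual pymbar code):

import jax
import jax.numpy as jnp

@jax.jit
def sci_step(f_k, u_kn, N_k):
    return f_k - 0.1 * jnp.mean(u_kn, axis=1)    # placeholder self-consistent update

@jax.jit
def nr_step(f_k, u_kn, N_k):
    return f_k - 0.01 * jnp.sum(u_kn, axis=1)    # placeholder Newton-Raphson update

def gradient_norm(f_k, u_kn, N_k):
    return float(jnp.linalg.norm(f_k))           # placeholder gnorm; float() makes it concrete

def adaptive_choice(f_k, u_kn, N_k):
    f_sci = sci_step(f_k, u_kn, N_k)
    f_nr = nr_step(f_k, u_kn, N_k)
    gnorm_sci = gradient_norm(f_sci, u_kn, N_k)
    gnorm_nr = gradient_norm(f_nr, u_kn, N_k)
    # Ordinary Python branch on concrete floats; no static_argnums required.
    return f_sci if gnorm_sci < gnorm_nr else f_nr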
# Perform Newton-Raphson iterations (with sci computed on the way)
for iteration in range(0, maximum_iterations):
Can you write this as jax.lax.while_loop? Just jit-compiling this for loop will be a nightmare for XLA.
can you write this as jax.lax.while_loop?
I can take a look, though since the adaptive loop is generally only called once in pymbar, it won't be as useful to do this.
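For reference, a minimal sketch of the jax.lax.while_loop pattern being suggested; the update inside body and the carry layout are placeholders, not the pymbar algorithm.

import jax
import jax.numpy as jnp
from jax import lax

def solve(f_k0, tol=1e-12, maximum_iterations=250):
    # carry = (iteration count, current f_k, last change norm); the whole loop lives in XLA.
    def cond(carry):
        i, _, gnorm = carry
        return jnp.logical_and(i < maximum_iterations, gnorm > tol)

    def body(carry):
        i, f_k, _ = carry
        f_new = 0.5 * f_k                                # placeholder for the adaptive update
        return i + 1, f_new, jnp.linalg.norm(f_new - f_k)

    init = (jnp.array(0), f_k0, jnp.array(jnp.inf, dtype=f_k0.dtype))
    _, f_k, _ = lax.while_loop(cond, body, init)
    return f_k

solve_jit = jax.jit(solve)
print(solve_jit(jnp.ones(10)))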
Maybe it makes sense, then, to just jit the body of this function? (i.e. move everything inside the for loop to a separate function and jit just that)
Ah, yes, that makes a lot of sense given the constraints. I'll try that.
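A minimal sketch of that structure, with a placeholder update standing in for everything currently inside the for loop:

import jax
import jax.numpy as jnp

@jax.jit
def adaptive_step(f_k, u_kn, N_k):
    # Placeholder for the work currently inside the for loop.
    f_new = f_k - 0.1 * jnp.mean(u_kn, axis=1)
    return f_new, jnp.max(jnp.abs(f_new - f_k))

def adaptive(f_k, u_kn, N_k, tol=1e-12, maximum_iterations=250):
    for _ in range(maximum_iterations):
        f_k, delta = adaptive_step(f_k, u_kn, N_k)
        if float(delta) < tol:       # plain Python convergence check, outside the jit
            break
    return f_k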
obj = math.fsum(log_denominator_n) - N_k.dot(f_k)

return obj, grad

def jax_mbar_hessian(u_kn, N_k, f_k):

    jNk = 1.0*N_k
what's the point of 1.0*?
N_k is an int array; jit (or maybe JAX, I can't recall) complains if it's not converted to float first. I wasn't able to find a direct float-conversion function that made jit happy, so I just multiplied by 1.0 to do the conversion automatically.
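For reference, a more explicit alternative to the 1.0* trick (generic NumPy/JAX casts, assuming x64 stays enabled as in the imports above):

import numpy as np
import jax.numpy as jnp

N_k = np.array([100, 100, 50])                 # integer sample counts

jNk = N_k.astype(np.float64)                   # NumPy-side cast
jNk = jnp.asarray(N_k, dtype=jnp.float64)      # cast while moving onto the device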
Closing in favor of one that just accelerates the inner loop.
Converted the entire loop in adaptive to jax/jit.
See issue #340 for discussion of timings.
Intended for inspection, not for merging right now.
Failing because jax/jit isn't loaded via conda (it's not clear it can be right now; the conda-forge install on my machine didn't work, so I had to pip install).