
[WIP] made whole adaptive workflow jax/jit #400

Closed · wants to merge 24 commits

Conversation

@mrshirts (Collaborator) commented Jun 29, 2020
Converted the entire loop in adaptive to jax/jit.
See issue #340 for a discussion of timings.

Intended for inspection, not merging for now.

Failing because jax/jit isn't loaded via conda (it's not clear it can be right now; the conda-forge install didn't work on my machine, so I had to pip install).

import jax
from jax.scipy.special import logsumexp
from jax.ops import index_update, index
from jax.config import config; config.update("jax_enable_x64", True)

This will slow things down dramatically, especially on the GPU. Do you really need 64-bit precision?

@mrshirts (Collaborator Author) Jun 29, 2020

Probably yes: the algorithm doesn't seem to converge well with 32-bit floats. But I can poke around and see if I can isolate the particular problems in convergence.
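A minimal sketch of what the flag changes (the array here is illustrative, not the pymbar data): without jax_enable_x64, JAX silently downcasts to float32, which is the likely source of the convergence trouble mentioned above.

import jax.numpy as jnp
from jax.config import config

# Same flag the PR enables in its imports above.
config.update("jax_enable_x64", True)

f_k = jnp.zeros(10)   # illustrative array
print(f_k.dtype)      # float64 with the flag enabled, float32 without it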

@mrshirts (Collaborator Author)

Still need to work on this...


jit_core_adaptive = jax.jit(core_adaptive, static_argnums=(0, 1, 2, 3, 4, 5))


When you call jax.jit, the first invocation will be exceedingly slow because it's compiling the JIT kernels. So if you're benchmarking, you should call this more than once to amortize out the jit time.

Furthermore, if the arguments have shapes that contain (None,), a recompilation is required for each specialization of (None,) into a known shape, due to an XLA requirement.
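A minimal benchmarking sketch along those lines (the function, names, and shapes are illustrative, not the pymbar kernel): the first call pays the trace/compile cost, and block_until_ready() is needed because JAX dispatches work asynchronously.

import time
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def step(u_kn, f_k):
    # stand-in for the accelerated body
    return logsumexp(f_k - u_kn.T, axis=1)

step_jit = jax.jit(step)

u_kn = jnp.ones((5, 1000))   # fixed, fully known shapes: one compilation only
f_k = jnp.zeros(5)

step_jit(u_kn, f_k).block_until_ready()        # first call pays the trace/compile cost

start = time.perf_counter()
for _ in range(100):
    step_jit(u_kn, f_k).block_until_ready()    # subsequent calls reuse the compiled kernel
print((time.perf_counter() - start) / 100)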

@mrshirts (Collaborator Author)

Ah, this explains why the partial jitification works better: it's then calling the different functions multiple times, though only 7-8 times in many cases.

@mrshirts (Collaborator Author)

I will test running mbar initialization several times with different data sets in the same script to see if that helps.


I also noticed you're declaring every arg to be static, which means that unless you're passing in the exact same args you'll be triggering a recompilation. Why did you have to declare them static to start with?
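A minimal sketch of how that recompilation can be observed (illustrative function, not the pymbar code): the Python print fires only while JAX traces, so every new static value shows up as a fresh trace and compile.

import jax

def scaled(x, n):
    print("tracing with n =", n)   # runs only during tracing
    return x * n

scaled_jit = jax.jit(scaled, static_argnums=(1,))

scaled_jit(1.0, 3)   # prints: first compilation for n=3
scaled_jit(2.0, 3)   # silent: cached compilation is reused
scaled_jit(1.0, 4)   # prints again: new static value forces a recompile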


U_kn probably shouldn't be static if you're trying to pass in results from different runs?

@mrshirts (Collaborator Author)

I had to declare them static because otherwise gnorm_sci and gnorm_nr end up as traced arrays instead of concrete arrays (since they are functions of traced arrays), and jit refuses to compile the conditional comparing the two, which decides which branch to take.

@mrshirts (Collaborator Author) Jun 29, 2020

@proteneer had a good suggestion for handling the conditional:

import jax
import jax.numpy as jnp
# can't be JIT'd
def foo(a, b, cond_a, cond_b):
    if cond_a < cond_b:
        return a+b
    else:
        return a-b
# foo_jit = jax.jit(foo) # fails
print(foo(0., 1., 2., 3.))
# can be JIT'd
def bar(a, b, cond_a, cond_b):
    return jnp.where(cond_a < cond_b, a+b, a-b)
bar_jit = jax.jit(bar)
print(bar_jit(0., 1., 2., 3.))
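One caveat worth noting when applying this pattern: under jit, jnp.where evaluates both branches and then selects between the results, so it is only a drop-in replacement when both expressions are cheap and safe to compute, as with the float comparisons discussed here; jax.lax.cond is the alternative when only one branch should actually execute.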

@mrshirts (Collaborator Author)

So, if we accelerate just the inner loop, we can pull the conditional out of the accelerated code with little loss of performance: all of the conditionals are just comparing floats and assigning values, so there is little to gain from accelerating them.


# Perform Newton-Raphson iterations (with sci computed on the way)
for iteration in range(0, maximum_iterations):


Can you write this as jax.lax.while_loop? Just jit-compiling this for loop will be a nightmare for XLA.
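For comparison, a minimal jax.lax.while_loop sketch (the update step below is a placeholder, not the actual Newton-Raphson/self-consistent iteration): the loop carry must keep a fixed structure and dtypes, and the stopping condition is expressed on traced values.

import jax
import jax.numpy as jnp
from jax import lax

def solve(f_k0, tol=1e-10, max_iter=250):
    def cond(state):
        i, f_k, delta = state
        return jnp.logical_and(i < max_iter, delta > tol)

    def body(state):
        i, f_k, _ = state
        f_new = 0.5 * f_k                       # placeholder update, not the real step
        delta = jnp.max(jnp.abs(f_new - f_k))
        return i + 1, f_new, delta

    init = (0, f_k0, jnp.asarray(jnp.inf, dtype=f_k0.dtype))
    _, f_k, _ = lax.while_loop(cond, body, init)
    return f_k

print(jax.jit(solve)(jnp.ones(5)))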

@mrshirts (Collaborator Author)

can you write this as jax.lax.while_loop?

I can take a look, though since the adaptive loop is generally only called once in pymbar, it won't be as useful to do this.


maybe it makes sense then to just jit the body of this function? (i.e. move everything inside the for loop to a separate function and just jit that)

@mrshirts (Collaborator Author)

Ah, yes, that makes a lot of sense given the constraints. I'll try that.
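A sketch of that structure (the kernel below is an illustrative self-consistent-style update, not the actual pymbar routine): only the per-iteration update is jitted, while the convergence check stays in plain Python on concrete floats, so nothing has to be declared static.

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

@jax.jit
def self_consistent_step(u_kn, N_k, f_k):
    # one self-consistent update (stand-in for the pymbar inner-loop kernel)
    log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1)
    return -logsumexp(-log_denominator_n - u_kn, axis=1)

def solve(u_kn, N_k, f_k, tol=1e-12, max_iter=250):
    for _ in range(max_iter):
        f_new = self_consistent_step(u_kn, N_k, f_k)
        f_new = f_new - f_new[0]                     # pin the arbitrary offset
        # ordinary Python conditional on concrete floats: nothing needs to be static
        if float(jnp.max(jnp.abs(f_new - f_k))) < tol:
            return f_new
        f_k = f_new
    return f_k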


    obj = math.fsum(log_denominator_n) - N_k.dot(f_k)

    return obj, grad

def jax_mbar_hessian(u_kn, N_k, f_k):

    jNk = 1.0 * N_k


what's the point of 1.0*?

@mrshirts (Collaborator Author)

N_k is an int array, and jit (or maybe JAX, I can't recall) complains if it's not converted to float first. I wasn't able to find a direct conversion-to-float function that made jit happy, so I just multiplied by 1.0 to do the conversion automatically.
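An alternative to the 1.0* trick, assuming the intent is simply an explicit cast to float (the array values below are illustrative):

import jax.numpy as jnp

N_k = jnp.array([100, 250, 250])            # illustrative integer sample counts
jNk = jnp.asarray(N_k, dtype=jnp.float64)   # explicit cast (with jax_enable_x64 on, as in this PR)
# or, equivalently:
jNk = N_k.astype(jnp.float64)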

@mrshirts (Collaborator Author)
Closing in favor of one that just accelerates the inner loop.

@mrshirts closed this Jun 30, 2020