-
Notifications
You must be signed in to change notification settings - Fork 94
Updates JAX automatic parallelism section to reduce the complexity of explicit mode sharding. #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
@yashk2810 for a review |
2. **JAX, take the wheel!** Automatic parallelism is great, but sometimes the compiler does something crazy. Explicit sharding lets you write single-device code like usual, but have JAX handle sharding propagation (not the compiler). This means JAX can ask you for clarification when it's unclear what you want. | ||
3. **Just let me write what I mean, damnit!** While compilers are nice, they sometimes do the wrong thing and add communication you don't intend. Sometimes we want to be explicit about exactly what communication you intend to run. | ||
1. **Compiler, take the wheel!** Let the compiler automatically partition arrays and decide what communication to add to facilitate a given program. This lets you take a program that runs on a single device and automatically run it on thousands without changing anything. | ||
2. **Just let me write what I mean, damnit!** While compilers are nice, they sometimes do the wrong thing and add communication you don't intend. Sometimes we want to be explicit about exactly what communication you intend to run. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should only mention Explicit and Manual mode here if you only want 2 sections.
Mentioning Compiler and Manual doesn't seem right to me.
This section in JAX docs: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html#using-a-mixture-of-sharding-modes does a pretty good job at explaining all modes. Maybe we should copy it here as I originally did in the rewrite :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't feel like Explicit vs. Manual is a very important distinction TBH (even though I love shit). From the user's perspective, either they write in global vs. local model. That feels like the important thing.
- subsections: | ||
- name: "Auto sharding mode" | ||
- name: “Explicit sharding mode” | ||
- name: "jax.jit: the automatic parallelism solution" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really like how jax.jit
is being painted as the automatic parallelism solution. Can we go the other way around and paint it as a more explicit parallelism solution?
Or maybe keep the language more neutral?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you say more? It seems like it's the thing that automatically puts stuff on multiple devices with little to no intervention.
At @mattjj's suggestion, the distinction between auto and manual mode may be too great. This change reduces this overhead by pulling this distinction deeper in the section.