Updates JAX automatic parallelism section to reduce the complexity of explicit mode sharding. #79
base: main
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
@yashk2810 for a review
2. **JAX, take the wheel!** Automatic parallelism is great, but sometimes the compiler does something crazy. Explicit sharding lets you write single-device code like usual, but have JAX handle sharding propagation (not the compiler). This means JAX can ask you for clarification when it's unclear what you want.
3. **Just let me write what I mean, damnit!** While compilers are nice, they sometimes do the wrong thing and add communication you don't intend. Sometimes we want to be explicit about exactly what communication you intend to run.
1. **Compiler, take the wheel!** Let the compiler automatically partition arrays and decide what communication to add to facilitate a given program. This lets you take a program that runs on a single device and automatically run it on thousands without changing anything.
2. **Just let me write what I mean, damnit!** While compilers are nice, they sometimes do the wrong thing and add communication you don't intend. Sometimes we want to be explicit about exactly what communication you intend to run.
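As a rough illustration of the two ends of this spectrum, here is a minimal sketch (not taken from the PR) of a matmul whose contracting dimension is sharded. The mesh axis name `x`, the array shapes, and the helper `local_matmul` are all assumptions made for the example. In the first version the compiler decides what communication to add; in the second the `psum` is written by hand inside `shard_map`. The import fallback is there because the location of `shard_map` differs between JAX releases.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

# A 1D mesh over whatever devices are available (works with a single CPU too).
devices = np.array(jax.devices())
mesh = Mesh(devices, ('x',))
n = devices.size

# Hypothetical shapes: the contracting dimension (8 * n) is split across 'x'.
a = jnp.arange(4 * 8 * n, dtype=jnp.float32).reshape(4, 8 * n)
b = jnp.ones((8 * n, 2), dtype=jnp.float32)
a_sharded = jax.device_put(a, NamedSharding(mesh, P(None, 'x')))
b_sharded = jax.device_put(b, NamedSharding(mesh, P('x', None)))

# "Compiler, take the wheel!": single-device-style code under jit. The compiler
# chooses the communication (on a multi-device mesh, typically an all-reduce).
auto_out = jax.jit(jnp.dot)(a_sharded, b_sharded)

# "Just let me write what I mean": each device sees only its local block and
# the cross-device sum is spelled out explicitly.
def local_matmul(a_block, b_block):
    partial = a_block @ b_block          # partial sums over the local slice
    return jax.lax.psum(partial, 'x')    # explicit all-reduce over axis 'x'

manual_out = jax.jit(
    shard_map(
        local_matmul,
        mesh=mesh,
        in_specs=(P(None, 'x'), P('x', None)),
        out_specs=P(),  # result is replicated after the psum
    )
)(a_sharded, b_sharded)

# Unsharded reference computed with NumPy.
expected = np.asarray(a) @ np.asarray(b)
```

Both `auto_out` and `manual_out` should match `expected`; the difference is only in who decides where the communication happens.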
I think we should only mention Explicit and Manual mode here if you only want 2 sections.
Mentioning Compiler and Manual doesn't seem right to me.
This section in JAX docs: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html#using-a-mixture-of-sharding-modes does a pretty good job at explaining all modes. Maybe we should copy it here as I originally did in the rewrite :)
I don't feel like Explicit vs. Manual is a very important distinction TBH (even though I love it). From the user's perspective, they write in either a global or a local model. That feels like the important thing.
I disagree. It is an important distinction since it is a different way of programming and it affects how the code will be written.
> That feels like the important thing.
That was the old way of thinking about it when there were only 2 modes. We have 3 modes right now, so the wording needs to evolve.
I still don't think it matters if the user is just starting to learn JAX. From their point of view, they want to know that (a) they can let the compiler add collectives or (b) they can write explicit collectives. The fact that they can see sharding in the type system is an interesting detail.
I don't think this chapter should be geared towards someone who is starting to use JAX. They should go read the JAX docs for that.
Given that this has a book-like structure, by the time they hit this chapter they will have learned a bunch of parallelism and sharding strategies. So telling them about everything that JAX offers is the right thing to do IMO.
- subsections:
- name: "Auto sharding mode"
- name: "Explicit sharding mode"
- name: "jax.jit: the automatic parallelism solution"
I don't really like how `jax.jit` is being painted as the automatic parallelism solution. Can we go the other way around and paint it as a more explicit parallelism solution?
Or maybe keep the language more neutral?
Can you say more? It seems like it's the thing that automatically puts stuff on multiple devices with little to no intervention.
That's not really the description of `jax.jit`. That's just how the compiler works. `jit` can work with explicit mode too.
So calling it "jax.jit: the automatic parallelism solution" is incorrect. I think these section names and headings need to be redone and called `Auto mode`, `Explicit mode`, and `Manual mode`. JAX APIs shouldn't be the focus in the title.
For example, you can imagine dropping into Manual mode via a context decorator without the use of shard_map at all (it doesn't work today but in principle, it should).
So overall, I think I would like to change the framing of JAX sharding APIs in this doc and align it more with the JAX docs. Given that this section is all about JAX, it should match the new way of thinking of JAX sharding APIs.
@yashk2810 ping
Per @mattjj's suggestion: the distinction between auto and manual mode may be too great. This change reduces that overhead by moving the distinction deeper into the section.