Skip to content

Conversation

jacobaustin123
Copy link
Collaborator

At @mattjj's suggestion, the distinction between auto and manual mode may be too great. This change reduces this overhead by pulling this distinction deeper in the section.

@google-cla
Copy link

google-cla bot commented Sep 19, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@jacobaustin123
Copy link
Collaborator Author

@yashk2810 for a review

2. **JAX, take the wheel!** Automatic parallelism is great, but sometimes the compiler does something crazy. Explicit sharding lets you write single-device code like usual, but have JAX handle sharding propagation (not the compiler). This means JAX can ask you for clarification when it's unclear what you want.
3. **Just let me write what I mean, damnit!** While compilers are nice, they sometimes do the wrong thing and add communication you don't intend. Sometimes we want to be explicit about exactly what communication you intend to run.
1. **Compiler, take the wheel!** Let the compiler automatically partition arrays and decide what communication to add to facilitate a given program. This lets you take a program that runs on a single device and automatically run it on thousands without changing anything.
2. **Just let me write what I mean, damnit!** While compilers are nice, they sometimes do the wrong thing and add communication you don't intend. Sometimes we want to be explicit about exactly what communication you intend to run.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should only mention Explicit and Manual mode here if you only want 2 sections.

Mentioning Compiler and Manual doesn't seem right to me.

This section in JAX docs: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html#using-a-mixture-of-sharding-modes does a pretty good job at explaining all modes. Maybe we should copy it here as I originally did in the rewrite :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't feel like Explicit vs. Manual is a very important distinction TBH (even though I love shit). From the user's perspective, either they write in global vs. local model. That feels like the important thing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree. It is an important distinction since it is a different way of programming and it affects how the code will be written.

That feels like the important thing.

That was the old way of thinking about it when there were only 2 modes. We have 3 modes right now. So the working needs to evolve.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't think it matters if the user is just starting to learn JAX. From their point of view, they want to know that (a) they can let the compiler add collectives or (b) they can write explicit collectives. The fact that they can see sharding in the type system is an interesting detail.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this chapter should be geared towards someone who is starting to use JAX. They should go read the JAX docs for that.

Given that this is a "book" like structure, by the time they hit this chapter they would have learned a bunch of parallelism and sharding strategies. So telling them about everything that JAX offers is the right thing to do IMO.

- subsections:
- name: "Auto sharding mode"
- name: “Explicit sharding mode
- name: "jax.jit: the automatic parallelism solution"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like how jax.jit is being painted as the automatic parallelism solution. Can we go the other way around and paint it as a more explicit parallelism solution?

Or maybe keep the language more neutral?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you say more? It seems like it's the thing that automatically puts stuff on multiple devices with little to no intervention.

Copy link
Contributor

@yashk2810 yashk2810 Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not really the description of jax.jit. That's just how the compiler works. jit can work wtih explicit mode too.

So calling jax.jit: The automatic parallelsim solution" is incorrect. I think these section names and headings need to be redone and called as Auto mode, Explicit mode and Manual mod. JAX APIs shouldn't be the focus in the title.

For example, you can imagine dropping into Manual mode via a context decorator without the use of shard_map at all (it doesn't work today but in principle, it should).

So overall, I think I would like to change the framing of JAX sharding APIs in this doc and align it more with the JAX docs. Given that this section is all about JAX, it should match the new way of thinking of JAX sharding APIs.

@jacobaustin123
Copy link
Collaborator Author

@yashk2810 ping

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants