Updates JAX automatic parallelism section to reduce the complexity of explicit mode sharding. #79
````diff
@@ -53,8 +53,8 @@ authors:
 toc:
 - name: "How Does Parallelism Work in JAX?"
 - subsections:
-  - name: "Auto sharding mode"
-  - name: "Explicit sharding mode"
+  - name: "jax.jit: the automatic parallelism solution"
+  - name: "jax.jit + explicit sharding mode"
   - name: "Manual sharding mode via shard_map"
   - name: "Worked Problems"
````
````diff
@@ -80,27 +80,20 @@ _styles: >
 ## How Does Parallelism Work in JAX?

-JAX supports three schools of thought for multi-device programming:
+JAX supports two schools of thought for multi-device programming:

-1. **Compiler, take the wheel!** Let the XLA compiler automatically partition arrays and decide what communication to add to facilitate a given program. This lets you take a program that runs on a single device and automatically run it on thousands without changing anything.
-2. **JAX, take the wheel!** Automatic parallelism is great, but sometimes the compiler does something crazy. Explicit sharding lets you write single-device code like usual, but have JAX handle sharding propagation (not the compiler). This means JAX can ask you for clarification when it's unclear what you want.
-3. **Just let me write what I mean, damnit!** While compilers are nice, they sometimes do the wrong thing and add communication you don't intend. Sometimes we want to be explicit about exactly what communication you intend to run.
+1. **Compiler, take the wheel!** Let the compiler automatically partition arrays and decide what communication to add to facilitate a given program. This lets you take a program that runs on a single device and automatically run it on thousands without changing anything.
+2. **Just let me write what I mean, damnit!** While compilers are nice, they sometimes do the wrong thing and add communication you don't intend. Sometimes we want to be explicit about exactly what communication you intend to run.
````
**Inline review thread on this hunk:**

> I think we should only mention Explicit and Manual mode here if you only want 2 sections. Mentioning Compiler and Manual doesn't seem right to me. This section in the JAX docs: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html#using-a-mixture-of-sharding-modes does a pretty good job at explaining all modes. Maybe we should copy it here as I originally did in the rewrite :)

> I don't feel like Explicit vs. Manual is a very important distinction TBH (even though I love shit). From the user's perspective, they either write in the global or the local model. That feels like the important thing.

> I disagree. It is an important distinction, since it is a different way of programming and it affects how the code will be written.

> That was the old way of thinking about it when there were only 2 modes. We have 3 modes right now, so the wording needs to evolve.

> I still don't think it matters if the user is just starting to learn JAX. From their point of view, they want to know that (a) they can let the compiler add collectives, or (b) they can write explicit collectives. The fact that they can see sharding in the type system is an interesting detail.

> I don't think this chapter should be geared towards someone who is starting to use JAX. They should go read the JAX docs for that. Given that this is a "book"-like structure, by the time they hit this chapter they would have learned a bunch of parallelism and sharding strategies. So telling them about everything that JAX offers is the right thing to do IMO.
````diff
-| Mode | View? | Explicit sharding? | Explicit Collectives? |
-|:---:|:---:|:---:|:---:|
-| Auto | Global | ❌ | ❌ |
-| Explicit | Global | ✅ | ❌ |
-| Manual | Per-device | ✅ | ✅ |
-
-Correspondingly, JAX provides APIs for each of these modes:
-1. `jax.jit` (with `Auto` mesh axes) lets you take any existing JAX function and call it with sharded inputs. JAX then uses XLA's [Shardy](https://openxla.org/shardy) compiler which automatically parallelizes the program. XLA will add communication for you (AllGathers, ReduceScatters, AllReduces, etc.) when needed to facilitate existing operations. While it isn't perfect, it usually does a decent job at automatically scaling your program to any number of chips without code changes.
-2. `jax.jit` with `Explicit` mesh axes looks similar to (1), but lets JAX handle the sharding propagation instead of XLA. That means the sharding of an array is actually part of the JAX type system, and JAX can error out when it detects ambiguous communication and lets the user resolve it.
-3. `jax.shard_map` is the more manual counterpart. You get a device-local view of the program and have to write any communication you want explicitly. Have a sharded array and want the whole thing on each device? Add a `jax.lax.all_gather`. Want to sum an array across your devices? Add a `jax.lax.psum` (an AllReduce). Programming is harder but far less likely to do something you don't want.
+Correspondingly, JAX provides two APIs for each of these schools: **jit** (`jax.jit`) and **shard\_map** (`jax.shard_map`).
+
+1. `jax.jit` lets you take any existing JAX function and call it with sharded inputs, leaving it up to JAX to partition the rest of the program automatically (either using XLA's [Shardy](https://openxla.org/shardy/getting_started_jax) compiler or [JAX's native "sharding in types"](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) system). While it isn't perfect, it usually does a decent job at automatically scaling your program to any number of chips.
+2. `jax.shard_map` is the more explicit counterpart. You get a device-local view of the program and have to write any communication you want explicitly. Have a sharded array and want the whole thing on each device? Add a `jax.lax.all_gather`. Want to sum an array across your devices? Add a `jax.lax.psum` (an AllReduce). Programming is harder but far less likely to do something you don't want.

-<h3 id="auto-sharding-mode">Auto sharding mode</h3>
+<h3 id="jax-jit-the-automatic-parallelism-solution">jax.jit: the automatic parallelism solution</h3>

-jax.jit plays two roles inside JAX. As the name suggests, it "just-in-time" compiles a function from Python into bytecode (via XLA/HLO/LLO) so it runs faster. But if the input is sharded or the user specifies an `in_sharding` or `out_sharding`, it also lets XLA distribute the computation across multiple devices and add communication as needed. For example, here's how you could write a sharded matmul using jax.jit:
+`jax.jit` plays two roles inside JAX. As the name suggests, it "just-in-time" compiles a function from Python into bytecode (via XLA/HLO/LLO) so it runs faster. But if the input is sharded or the user specifies an `in_sharding` or `out_sharding`, it also lets XLA distribute the computation across multiple devices and add communication as needed. For example, here's how you could write a sharded matmul using `jax.jit`:

 ```py
 import jax
````
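The diff only shows the opening lines of the sharded-matmul example. For context, here is a minimal self-contained sketch of the idea; the mesh shape, array sizes, and the `XLA_FLAGS` CPU-emulation trick are illustrative assumptions, not part of the diff:

```py
import os
# Emulate 8 devices on a single CPU host so the sketch runs anywhere.
# (Demonstration only; on real TPU/GPU hardware you would skip this flag.)
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D mesh named "X" spanning all (emulated) devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("X",))

# Activations sharded along the batch axis; weights replicated.
x = jax.device_put(jnp.ones((16, 512)), NamedSharding(mesh, P("X", None)))
w = jax.device_put(jnp.ones((512, 128)), NamedSharding(mesh, P(None, None)))

@jax.jit
def matmul(x, w):
    # XLA propagates the input shardings through the program and inserts
    # whatever collectives are needed (none are required here).
    return x @ w

out = matmul(x, w)
```

Because each device already holds the rows of `x` it needs plus a full copy of `w`, the compiled program runs with no communication at all; the output inherits the `P("X", None)` sharding.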
````diff
@@ -162,11 +155,11 @@ def matmul(x, Win, Wout):
   return jnp.einsum('bf,df->bd', hidden, Wout)
 ```

-This makes up like 60% of JAX parallel programming in the automatic partitioning world where you control the intermediate shardings via `jax.lax.with_sharding_constraint`. But "compiler tickling" is famously not a fun programming model. You could annotate every intermediate variable and still not know if you'll get the right outcome. Instead, what if JAX itself could handle and control sharding propagation?
+This makes up like 60% of JAX parallel programming in the automatic partitioning (`Auto`) world where you control the intermediate shardings via `jax.lax.with_sharding_constraint`. But "compiler tickling" is famously not a fun programming model. You could annotate every intermediate variable and still not know if you'll get the right outcome. Instead, what if JAX itself could handle and control sharding propagation?

-<h3 id="explicit-sharding-mode">Explicit sharding mode</h3>
+<h3 id="jax-jit-explicit-sharding-mode">jax.jit + explicit sharding mode</h3>

-Explicit sharding (or "sharding in types") looks a lot like automatic sharding, but sharding propagation happens at the JAX level! Each JAX operation has a sharding rule that takes the shardings of the op's arguments and produces a sharding for the op's result. You can see the resulting sharding using `jax.typeof`:
+Explicit sharding (or "sharding in types") looks a lot like automatic sharding, but sharding propagation happens at the JAX level! **This means the user can inspect the sharding at each point in the program.** You can see these shardings using `jax.typeof`:

 ```py
 import jax
````
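The hunk above mentions controlling intermediate shardings via `jax.lax.with_sharding_constraint`. As a concrete illustration (not part of the diff; the mesh layout, sizes, and the CPU device-emulation flag are assumptions), here is how a constraint pins an intermediate inside a jitted MLP:

```py
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # demo only

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 4x2 mesh with a data axis "X" and a model axis "Y".
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), axis_names=("X", "Y"))

x = jax.device_put(jnp.ones((8, 1024)), NamedSharding(mesh, P("X", None)))
w_in = jax.device_put(jnp.ones((1024, 512)), NamedSharding(mesh, P(None, "Y")))
w_out = jax.device_put(jnp.ones((512, 256)), NamedSharding(mesh, P("Y", None)))

@jax.jit
def mlp(x, w_in, w_out):
    hidden = x @ w_in
    # Pin the intermediate's sharding instead of trusting the compiler's
    # propagation: keep `hidden` sharded over both mesh axes.
    hidden = jax.lax.with_sharding_constraint(
        hidden, NamedSharding(mesh, P("X", "Y")))
    return hidden @ w_out

out = mlp(x, w_in, w_out)
```

Without the constraint, XLA is free to choose any intermediate layout; with it, the compiler must materialize `hidden` as `P("X", "Y")` and insert the collectives that layout implies.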
````diff
@@ -192,7 +185,9 @@ def f(x):
 f(x)
 ```

-As you can see, JAX propagated the sharding from input (`x`) to output (`x`) which are inspectable at trace-time via `jax.typeof`. For most operations these rules are simple and obvious because there's only one reasonable choice (e.g. elementwise ops retain the same sharding). But for some operations it's ambiguous how to shard the result in which case JAX throws a trace-time error and we ask the programmer to provide an `out_sharding` argument explicitly (e.g. jnp.einsum, jnp.reshape, etc). Let's see another example where you have conflicts:
+As you can see, JAX propagated the sharding from input (`x`) to output (`x`) which are inspectable at trace-time via `jax.typeof`. Each JAX operation has a sharding rule that takes the shardings of the op's arguments and produces a sharding for the op's result.
+
+For most operations these rules are simple and obvious because there's only one reasonable choice (e.g. elementwise ops retain the same sharding). But for some operations it's ambiguous how to shard the result in which case JAX throws a trace-time error! **This is different from "Auto" mode above, which cannot error and will always produce a sharding, even if it's very bad!** When there's an ambiguity, we ask the programmer to provide an `out_sharding` argument explicitly (e.g. jnp.einsum, jnp.reshape, etc). Here's another example where you have conflicts:

 ```py
 # We create a matrix W and input activations In sharded across our devices.
````
````diff
@@ -211,10 +206,10 @@ matmul_square(In, W) # This will error
 This code errors with `Contracting dimensions are sharded and it is ambiguous how the output should be sharded. Please specify the output sharding via the out_sharding parameter. Got lhs_contracting_spec=('Y',) and rhs_contracting_spec=('Y',)`

 This is awesome because how the output of einsum should be sharded is ambiguous. The output sharding can be:
-* P('X', 'Y') which will induce a reduce-scatter or
-* P('X', None) which will induce an all-reduce
+* `P('X', 'Y')`, which will induce a ReduceScatter.
+* `P('X', None)`, which will induce an AllReduce.

-Unlike Auto mode, explicit mode errors out when it detects ambiguous communication and requires the users to resolve it. So here you can do:
+Unlike `Auto` mode, explicit mode errors out when it detects ambiguous communication and requires the users to resolve it. So here you can do:

 ```py
 @jax.jit
````
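The same ReduceScatter-vs-AllReduce choice can be demonstrated in plain `Auto` mode, where you pick the outcome with a sharding constraint rather than an `out_sharding` argument. A sketch (mesh layout, sizes, and the CPU device-emulation flag are assumptions for illustration):

```py
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # demo only

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), axis_names=("X", "Y"))

# Both operands shard the contracting dimension over Y, so each device
# computes a partial sum that must be combined across Y.
lhs = jax.device_put(jnp.ones((8, 512)), NamedSharding(mesh, P("X", "Y")))
rhs = jax.device_put(jnp.ones((512, 256)), NamedSharding(mesh, P("Y", None)))

@jax.jit
def matmul_rs(lhs, rhs):
    out = jnp.einsum("xy,yz->xz", lhs, rhs)
    # Keep the output sharded over both axes: partials combine via
    # a ReduceScatter.
    return jax.lax.with_sharding_constraint(
        out, NamedSharding(mesh, P("X", "Y")))

@jax.jit
def matmul_ar(lhs, rhs):
    out = jnp.einsum("xy,yz->xz", lhs, rhs)
    # Replicate the output along Y instead: partials combine via
    # an AllReduce.
    return jax.lax.with_sharding_constraint(
        out, NamedSharding(mesh, P("X", None)))

out_rs = matmul_rs(lhs, rhs)
out_ar = matmul_ar(lhs, rhs)
```

Both give the same numbers; only the output layout, and therefore the collective the compiler inserts, differs.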
````diff
@@ -225,7 +220,7 @@ out = matmul_square(In, W)
 print(jax.typeof(out)) # bfloat16[8@X,8192@Y]
 ```

-Auto mode and Explicit mode can be composed via `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` APIs. This is a [great doc to read](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) for more information.
+`Auto` mode and `Explicit` mode can be composed via `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` APIs. This is a [great doc to read](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) for more information.

 <h3 id="manual-sharding-mode-via-shard_map">shard_map: explicit parallelism control over a program</h3>
````
**Review thread on the section framing:**

> I don't really like how `jax.jit` is being painted as the automatic parallelism solution. Can we go the other way around and paint it as a more explicit parallelism solution? Or maybe keep the language more neutral?

> Can you say more? It seems like it's the thing that automatically puts stuff on multiple devices with little to no intervention.

> That's not really the description of `jax.jit`. That's just how the compiler works. jit can work with explicit mode too. So calling it "`jax.jit`: the automatic parallelism solution" is incorrect. I think these section names and headings need to be redone and called `Auto mode`, `Explicit mode`, and `Manual mode`. JAX APIs shouldn't be the focus in the title. For example, you can imagine dropping into Manual mode via a context decorator without the use of shard_map at all (it doesn't work today, but in principle it should).
>
> So overall, I think I would like to change the framing of JAX sharding APIs in this doc and align it more with the JAX docs. Given that this section is all about JAX, it should match the new way of thinking about JAX sharding APIs.