Skip to content

Commit c0ea12d

Browse files
author
Flax Authors
committed
Merge pull request #5033 from IvyZX:shard-msg
PiperOrigin-RevId: 820437084
2 parents 303f8b2 + da8ef36 commit c0ea12d

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

docs_nnx/flip/4844-var-eager-sharding.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
Simplify the creation of sharded NNX models. When a sharding annotation is provided, all `nnx.Variable` creation will **require a mesh context** and automatically be sharded as annotated.
1010

11+
See [GSPMD Guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) for a comprehensive guide on how to make sharded NNX models.
12+
1113
# Motivation
1214

1315
To create a sharded model, user should only need to do this:
@@ -44,6 +46,19 @@ User can turn off this feature in two ways:
4446
* **Variable-specific flag**: Create a specific variable with metadata `eager_sharding=False`, such as: `nnx.Param(..., eager_sharding=False)`.
4547

4648

49+
# Flexibility options
50+
51+
For debugging in a CPU environment, make a dummy mesh to run the model:
52+
53+
```python
54+
mesh = jax.make_mesh(((1, 1, 1)), ('your', 'axes', 'names'))
55+
with jax.set_mesh(mesh):
56+
...
57+
```
58+
59+
For JAX explicit mode, remove the `sharding_names=` annotation on the `nnx.Variable`.
60+
61+
4762
# Implementation
4863
[implementation]: #implementation
4964

flax/core/spmd.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ def shard_value(value, sharding_names, sharding_rules, mesh):
4040
raise ValueError(
4141
'An auto mesh context or metadata is required if creating a variable'
4242
f' with annotation {sharding_names=}. '
43-
'If running this on CPU for debugging, make a'
44-
' dummy mesh like `jax.make_mesh(((1, 1)), (<your axis names>))`. '
45-
'If running on explicit mode, remove `sharding_names=` annotation.')
43+
'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.')
4644
pspec = get_pspec(sharding_names, sharding_rules)
4745
if mesh is not None:
4846
jax.lax.with_sharding_constraint(value, NamedSharding(mesh, pspec))

0 commit comments

Comments
 (0)