128 changes: 118 additions & 10 deletions flax/core/spmd.py
@@ -12,24 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import contextlib
import dataclasses
import threading
import typing as tp

import jax
from jax.sharding import PartitionSpec, NamedSharding
from flax.core import meta
from flax.typing import (
  LogicalRules,
  Sharding,
)

def get_pspec(sharding_names, sharding_rules=None) -> PartitionSpec:
  """Given sharding names and optional rules, return the `PartitionSpec`."""
  if get_logical_axis_rules() or sharding_rules:
    context_rules = get_logical_axis_rules()
    rules = composite_rules(context_rules, sharding_rules)
    return PartitionSpec(*from_sharding_rules(sharding_names, rules))
  sharding_names = logical_to_mesh_axes(sharding_names, sharding_rules)
  return PartitionSpec(*sharding_names)


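To see the rewritten `get_pspec` in action, a minimal sketch follows; the logical names ('embed', 'mlp') and mesh axes ('data', 'model') are illustrative, not from this PR, and it assumes no rules are set in the dynamic context:

import jax

from flax.core.spmd import get_pspec

# Each logical name is resolved through the supplied rules before the
# PartitionSpec is built.
spec = get_pspec(
    sharding_names=('embed', 'mlp'),
    sharding_rules=(('embed', 'data'), ('mlp', 'model')),
)
assert spec == jax.sharding.PartitionSpec('data', 'model')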
@@ -105,10 +104,119 @@ def composite_rules(rule1, rule2):
  return tuple(rules.items())


def from_sharding_rules(
    sharding: Sharding, sharding_rules: LogicalRules
) -> Sharding:
  rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
  return tuple(
      rules[str(s)] if (s and str(s) in rules) else s for s in sharding
  )

class _UnassignedAxis:
  """Sentinel class for unassigned logical axis name."""

  def __repr__(self):
    return 'UnassignedAxis'

  def __bool__(self):
    return False


_unassigned_axis = _UnassignedAxis()


def _mesh_assignment_free(new_assignment, existing_assignments):
  """Determines if a given mesh axis has already been assigned."""
  new = set(jax.tree_util.tree_leaves(new_assignment))
  existing = set(jax.tree_util.tree_leaves(existing_assignments))
  if existing.intersection(new):
    return False
  return True


def _logical_to_mesh_axes(
    array_dim_names: tp.Sequence[str | None] | None,
    rules: LogicalRules | None = None,
) -> list[_UnassignedAxis | None | str | tuple[str, ...]] | None:
  """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis."""
  if array_dim_names is None:
    return None
  if rules is None:
    rules = get_logical_axis_rules()
  axis_name_counts = collections.Counter(array_dim_names)
  # None and special values such as PartitionSpec.UNCONSTRAINED can appear
  # more than once.
  dups = tuple(
      k for k, v in axis_name_counts.items() if v > 1 and isinstance(k, str)
  )
  if dups:
    raise ValueError(
        f'Unsupported: Dimensions {dups} occur more than once in array names.'
    )
  if not isinstance(rules, (tuple, list)):
    raise ValueError('Unknown axis rule specification type.')
  # We assign mesh axes using a priority-based ruleset over logical axis names.
  result: list[_UnassignedAxis | None | str | tuple[str, ...]]
  result = [
      (_unassigned_axis if isinstance(name, str) else name)
      for name in array_dim_names
  ]
  for rule_model_name, rule_mesh_names in rules:
    if rule_model_name in array_dim_names:
      pos = array_dim_names.index(rule_model_name)
      if (
          _mesh_assignment_free(rule_mesh_names, result)
          and result[pos] == _unassigned_axis
      ):
        result[pos] = rule_mesh_names
  return result

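As a quick hedged illustration of this helper's contract (axis names here are hypothetical): logical names that no rule matches stay as the sentinel, so callers can tell "no rule matched" apart from an explicit None:

raw = _logical_to_mesh_axes(('batch', 'hidden'), rules=(('batch', 'X'),))
# raw == ['X', UnassignedAxis]; the public wrapper below replaces the
# sentinel with None.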

def logical_to_mesh_axes(
    array_dim_names: tp.Sequence[str | None] | None,
    rules: LogicalRules | None = None,
) -> jax.sharding.PartitionSpec | None:
  """Compute layout for an array.

  The rules are in order of precedence, and consist of pairs:
  ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array
  dimension (if present and unused) should be sharded across the given
  mesh dimension (if present and unused).

  A layout of an array is expressed as a tuple with one element for each
  dimension in the array. The element is either None, or is the name of a
  mesh dimension, meaning that this dimension of the array is sharded across
  this dimension of the mesh.

  For example, given an array with::

    array_dim_names = ('batch', 'length', 'heads', 'features')

  and the layout rules::

    rules = (('batch', 'X'),
             ('features', 'X'),
             ('heads', 'Y'),
             ('batch', 'Z'))

  then this function will return::

    PartitionSpec('X', None, 'Y', None)

  Args:
    array_dim_names: Tuple of array dimension names or None.
    rules: Optional logical-to-mesh rules override. Defaults to the rules
      defined in the dynamic context set by the ``axis_rules`` function.

  Returns:
    PartitionSpec for the parameter.
  """
  result = _logical_to_mesh_axes(array_dim_names, rules)
  if result is None:
    return None
  # We default to None, i.e. unsharded along the dimension.
  result = [None if x is _unassigned_axis else x for x in result]
  return jax.sharding.PartitionSpec(*result)



# def from_sharding_rules(
#     sharding_names: Sharding, sharding_rules: LogicalRules
# ) -> Sharding:
#   rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
#   return tuple(
#       rules[str(s)] if (s and str(s) in rules) else s for s in sharding_names
#   )
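Working the docstring example above end to end as a sanity check (the axis and mesh names are the docstring's own):

import jax

from flax.core.spmd import logical_to_mesh_axes

rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z'))
spec = logical_to_mesh_axes(('batch', 'length', 'heads', 'features'), rules)
# 'batch' claims 'X' first, so ('features', 'X') is skipped because 'X' is
# taken, and ('batch', 'Z') is skipped because 'batch' is already assigned.
assert spec == jax.sharding.PartitionSpec('X', None, 'Y', None)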
110 changes: 4 additions & 106 deletions flax/linen/spmd.py
@@ -24,20 +24,22 @@
introducing logical axis metadata into a model's variables.
"""

import collections
import dataclasses
import enum
import functools
from typing import Any
from collections.abc import Callable, Sequence
from collections.abc import Callable

import jax
from jax import lax

from flax import struct
from flax.core import meta
from flax.core.spmd import (
  _logical_to_mesh_axes,
  _unassigned_axis,
  get_logical_axis_rules,
  logical_to_mesh_axes,
)
from flax.typing import (
  Array,
@@ -49,111 +51,7 @@
)


class _UnassignedAxis:
  """Sentinel class for unassigned logical axis name."""

  def __repr__(self):
    return 'UnassignedAxis'

  def __bool__(self):
    return False


_unassigned_axis = _UnassignedAxis()


def _mesh_assignment_free(new_assignment, existing_assignments):
  """Determines if a given mesh axis has already been assigned."""
  new = set(jax.tree_util.tree_leaves(new_assignment))
  existing = set(jax.tree_util.tree_leaves(existing_assignments))
  if existing.intersection(new):
    return False
  return True


def _logical_to_mesh_axes(
    array_dim_names: Sequence[str | None] | None,
    rules: LogicalRules | None = None,
) -> list[_UnassignedAxis | None | str | tuple[str, ...]] | None:
  """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis."""
  if array_dim_names is None:
    return None
  if rules is None:
    rules = get_logical_axis_rules()
  axis_name_counts = collections.Counter(array_dim_names)
  # None and special values such as PartitionSpec.UNCONSTRAINED can appear
  # more than once.
  dups = tuple(
      k for k, v in axis_name_counts.items() if v > 1 and isinstance(k, str)
  )
  if dups:
    raise ValueError(
        f'Unsupported: Dimensions {dups} occur more than once in array names.'
    )
  if not isinstance(rules, (tuple, list)):
    raise ValueError('Unknown axis rule specification type.')
  # We assign mesh axes using a priority-based ruleset over logical axis names.
  result: list[_UnassignedAxis | None | str | tuple[str, ...]]
  result = [
      (_unassigned_axis if isinstance(name, str) else name)
      for name in array_dim_names
  ]
  for rule_model_name, rule_mesh_names in rules:
    if rule_model_name in array_dim_names:
      pos = array_dim_names.index(rule_model_name)
      if (
          _mesh_assignment_free(rule_mesh_names, result)
          and result[pos] == _unassigned_axis
      ):
        result[pos] = rule_mesh_names
  return result


def logical_to_mesh_axes(
    array_dim_names: Sequence[str | None] | None,
    rules: LogicalRules | None = None,
) -> jax.sharding.PartitionSpec | None:
  """Compute layout for an array.

  The rules are in order of precedence, and consist of pairs:
  ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array
  dimension (if present and unused) should be sharded across the given
  mesh dimension (if present and unused).

  A layout of an array is expressed as a tuple with one element for each
  dimension in the array. The element is either None, or is the name of a
  mesh dimension, meaning that this dimension of the array is sharded across
  this dimension of the mesh.

  For example, given an array with::

    array_dim_names = ('batch', 'length', 'heads', 'features')

  and the layout rules::

    rules = (('batch', 'X'),
             ('features', 'X'),
             ('heads', 'Y'),
             ('batch', 'Z'))

  then this function will return::

    PartitionSpec('X', None, 'Y', None)

  Args:
    array_dim_names: Tuple of array dimension names or None.
    rules: Optional logical-to-mesh rules override. Defaults to the rules
      defined in the dynamic context set by the ``axis_rules`` function.

  Returns:
    PartitionSpec for the parameter.
  """
  result = _logical_to_mesh_axes(array_dim_names, rules)
  if result is None:
    return None
  # We default to None, i.e. unsharded along the dimension.
  result = [None if x is _unassigned_axis else x for x in result]
  return jax.sharding.PartitionSpec(*result)


def logical_to_mesh(tree: Any, rules: LogicalRules | None = None) -> Any:
10 changes: 4 additions & 6 deletions flax/nnx/spmd.py
@@ -129,13 +129,11 @@ def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None:
"""Given an `nnx.Variable`, return its `PartitionSpec`."""
metadata = v.get_metadata()
if 'sharding_names' in metadata and metadata['sharding_names']:
sharding = metadata['sharding_names']
sharding_names = metadata['sharding_names']
if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata:
context_rules = core_spmd.get_logical_axis_rules()
local_rules = metadata.get('sharding_rules', ())
rules = core_spmd.composite_rules(context_rules, local_rules)
return PartitionSpec(*core_spmd.from_sharding_rules(sharding, rules))
return PartitionSpec(*sharding)
sharding_names = core_spmd.logical_to_mesh_axes(
sharding_names, metadata.get('sharding_rules', None))
return PartitionSpec(*sharding_names)
elif hasattr(v, 'shape'):
return PartitionSpec()
return None
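A minimal sketch of the new resolution path in `get_var_pspec`, mirroring the test changes below; the alias and mesh names are illustrative, and it assumes the module layout in this diff (flax/nnx/spmd.py):

import jax.numpy as jnp

from flax import nnx
from flax.nnx import spmd as nnx_spmd
from jax.sharding import PartitionSpec

# A param whose metadata carries logical names plus local sharding_rules.
p = nnx.Param(
    nnx.with_partitioning(
        lambda: jnp.ones((8, 2)),
        sharding=('row-alias', 'col-alias'),
        sharding_rules=(('row-alias', 'row'), ('col-alias', 'col')),
    )()
)
# The local rules alone resolve both names; no context rules are needed.
assert nnx_spmd.get_var_pspec(p) == PartitionSpec('row', 'col')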
7 changes: 5 additions & 2 deletions tests/nnx/optimizer_test.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

from absl.testing import absltest
from absl.testing import parameterized
from flax import nnx
@@ -69,10 +72,10 @@ def loss_fn(model):
    optimizer.update(model, grads)

  def test_sharding_propagation(self):
    with jax.set_mesh(jax.make_mesh(((1, 1)), ('a', 'b'))):
    with jax.set_mesh(jax.make_mesh(((2, 2)), ('a', 'b'))):
      model = nnx.Linear(
        2,
        3,
        4,
        rngs=nnx.Rngs(0),
        kernel_init=nnx.with_partitioning(
          nnx.initializers.lecun_normal(),
7 changes: 4 additions & 3 deletions tests/nnx/spmd_test.py
@@ -185,20 +185,21 @@ def __init__(self):
          nnx.with_partitioning(
            lambda: jnp.ones((8, 2)),
            sharding=('row-alias', 'col-alias'),
            sharding_rules=(('row-alias', 'row'),),
          )()
        )
        self.b = nnx.Param(
          nnx.with_partitioning(
            lambda: jnp.zeros((2,)), sharding=('col-alias',)
            lambda: jnp.zeros((2,)), sharding=('col-alias2',),
            sharding_rules=(('col-alias2', 'col'),),
          )()
        )

      def __call__(self, x):
        return x @ self.w + self.b

    mesh = jax.make_mesh(((1, 2, 2)), ('layers', 'row', 'col'))
    with jax.set_mesh(mesh), nnx.logical_axis_rules((('col-alias', 'col'),)):
    global_rule = (('row-alias', 'row'), ('col-alias', 'col'))
    with jax.set_mesh(mesh), nnx.logical_axis_rules(global_rule):
      model = Foo()
      optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
