diff --git a/flax/core/spmd.py b/flax/core/spmd.py
index 6d3a447e0..beed9b445 100644
--- a/flax/core/spmd.py
+++ b/flax/core/spmd.py
@@ -12,24 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import contextlib
 import dataclasses
 import threading
+import typing as tp
 
 import jax
 from jax.sharding import PartitionSpec, NamedSharding
 
 from flax.core import meta
 from flax.typing import (
   LogicalRules,
-  Sharding,
 )
 
 
 def get_pspec(sharding_names, sharding_rules = None) -> PartitionSpec:
   """Given an `nnx.Variable`, return its `PartitionSpec`."""
   if get_logical_axis_rules() or sharding_rules:
-    context_rules = get_logical_axis_rules()
-    rules = composite_rules(context_rules, sharding_rules)
-    return PartitionSpec(*from_sharding_rules(sharding_names, rules))
+    sharding_names = logical_to_mesh_axes(sharding_names, sharding_rules)
   return PartitionSpec(*sharding_names)
 
 
@@ -105,10 +104,109 @@ def composite_rules(rule1, rule2):
   return tuple(rules.items())
 
 
-def from_sharding_rules(
-  sharding: Sharding, sharding_rules: LogicalRules
-) -> Sharding:
-  rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
-  return tuple(
-    rules[str(s)] if (s and str(s) in rules) else s for s in sharding
-  )
+
+class _UnassignedAxis:
+  """Sentinel class for unassigned logical axis name."""
+
+  def __repr__(self):
+    return 'UnassignedAxis'
+
+  def __bool__(self):
+    return False
+
+
+_unassigned_axis = _UnassignedAxis()
+
+
+def _mesh_assignment_free(new_assignment, existing_assignments):
+  """Determines if a given mesh axis has already been assigned."""
+  new = set(jax.tree_util.tree_leaves(new_assignment))
+  existing = set(jax.tree_util.tree_leaves(existing_assignments))
+  if existing.intersection(new):
+    return False
+  return True
+
+
+def _logical_to_mesh_axes(
+  array_dim_names: tp.Sequence[str | None] | None,
+  rules: LogicalRules | None = None,
+) -> list[_UnassignedAxis | None | str | tuple[str, ...]] | None:
+  """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis."""
+  if array_dim_names is None:
+    return None
+  if rules is None:
+    rules = get_logical_axis_rules()
+  axis_name_counts = collections.Counter(array_dim_names)
+  # None and special values such as PartitionSpec.UNCONSTRAINED can appear more
+  # than once.
+  dups = tuple(
+    k for k, v in axis_name_counts.items() if v > 1 and isinstance(k, str)
+  )
+  if dups:
+    raise ValueError(
+      f'Unsupported: Dimensions {dups} occur more than once in array names.'
+    )
+  if not isinstance(rules, (tuple, list)):
+    raise ValueError('Unknown axis rule specification type.')
+  # We assign mesh axes using a priority-based ruleset over logical axis names.
+  result: list[_UnassignedAxis | None | str | tuple[str, ...]]
+  result = [
+    (_unassigned_axis if isinstance(name, str) else name)
+    for name in array_dim_names
+  ]
+  for rule_model_name, rule_mesh_names in rules:
+    if rule_model_name in array_dim_names:
+      pos = array_dim_names.index(rule_model_name)
+      if (
+        _mesh_assignment_free(rule_mesh_names, result)
+        and result[pos] == _unassigned_axis
+      ):
+        result[pos] = rule_mesh_names
+  return result
+
+
+def logical_to_mesh_axes(
+  array_dim_names: tp.Sequence[str | None] | None,
+  rules: LogicalRules | None = None,
+) -> jax.sharding.PartitionSpec | None:
+  """Compute layout for an array.
+
+  The rules are in order of precedence, and consist of pairs:
+  ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array
+  dimension (if present and unused) should be sharded across the given
+  mesh dimension (if present and unused).
+
+  A Layout of an Array is expressed as a tuple with one element for each
+  dimension in the Array. The element is either None, or is the name of a
+  mesh-dimension, meaning that this dimension of the array is sharded across
+  this dimension of the mesh.
+
+  For example, given an array with::
+
+    array_dim_names = ('batch', 'length', 'heads', 'features')
+
+  and the layout rules are::
+
+    rules = (('batch', 'X'),
+             ('features', 'X'),
+             ('heads', 'Y'),
+             ('batch', 'Z'))
+
+  then this function will return::
+
+    PartitionSpec('X', None, 'Y', None)
+
+  Args:
+    array_dim_names: Tuple of array dimension names or None.
+    rules: Optional logical to mesh rules override. Defaults to using the
+      rules defined in the dynamic context set from the ``axis_rules`` function.
+
+  Returns:
+    PartitionSpec for the parameter.
+  """
+  result = _logical_to_mesh_axes(array_dim_names, rules)
+  if result is None:
+    return None
+  # We default to None - ie unsharded along the dimension.
+  result = [None if x is _unassigned_axis else x for x in result]
+  return jax.sharding.PartitionSpec(*result)
diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py
index b68487b47..715656b25 100644
--- a/flax/linen/spmd.py
+++ b/flax/linen/spmd.py
@@ -24,12 +24,11 @@
 introducing logical axis metadata into a model's variables.
 """
 
-import collections
 import dataclasses
 import enum
 import functools
 from typing import Any
-from collections.abc import Callable, Sequence
+from collections.abc import Callable
 
 import jax
 from jax import lax
@@ -37,7 +36,10 @@
 from flax import struct
 from flax.core import meta
 from flax.core.spmd import (
+  _logical_to_mesh_axes,
+  _unassigned_axis,
   get_logical_axis_rules,
+  logical_to_mesh_axes,
 )
 from flax.typing import (
   Array,
@@ -49,111 +51,7 @@
 )
 
 
-class _UnassignedAxis:
-  """Sentinel class for unassigned logical axis name."""
-
-  def __repr__(self):
-    return 'UnassignedAxis'
-
-  def __bool__(self):
-    return False
-
-
-_unassigned_axis = _UnassignedAxis()
-
-
-def _mesh_assignment_free(new_assignment, existing_assignments):
-  """Determines if a given mesh axis has already been assigned."""
-  new = set(jax.tree_util.tree_leaves(new_assignment))
-  existing = set(jax.tree_util.tree_leaves(existing_assignments))
-  if existing.intersection(new):
-    return False
-  return True
-
-
-def _logical_to_mesh_axes(
-  array_dim_names: Sequence[str | None] | None,
-  rules: LogicalRules | None = None,
-) -> list[_UnassignedAxis | None | str | tuple[str, ...]] | None:
-  """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis."""
-  if array_dim_names is None:
-    return None
-  if rules is None:
-    rules = get_logical_axis_rules()
-  axis_name_counts = collections.Counter(array_dim_names)
-  # None and special values such as PartitionSpec.UNCONSTRAINED can appear more
-  # then once.
-  dups = tuple(
-    k for k, v in axis_name_counts.items() if v > 1 and isinstance(k, str)
-  )
-  if dups:
-    raise ValueError(
-      f'Unsupported: Dimensions {dups} occur more than once in array names.'
- ) - if not isinstance(rules, (tuple, list)): - raise ValueError('Unknown axis rule specification type.') - # We assign mesh axes using a priority based ruleset over logical axis names. - result: list[_UnassignedAxis | None | str | tuple[str, ...]] - result = [ - (_unassigned_axis if isinstance(name, str) else name) - for name in array_dim_names - ] - for rule_model_name, rule_mesh_names in rules: - if rule_model_name in array_dim_names: - pos = array_dim_names.index(rule_model_name) - if ( - _mesh_assignment_free(rule_mesh_names, result) - and result[pos] == _unassigned_axis - ): - result[pos] = rule_mesh_names - return result - - -def logical_to_mesh_axes( - array_dim_names: Sequence[str | None] | None, - rules: LogicalRules | None = None, -) -> jax.sharding.PartitionSpec | None: - """Compute layout for an array. - - The rules are in order of precedence, and consist of pairs: - ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array - dimension (if present and unused) should be sharded across the given - mesh dimension (if present and unused). - - A Layout of an Array is expressed as a tuple with one element for each - dimension in the Array. The element is either None, or is the name of a - mesh-dimension, meaning that this dimension of the array is sharded across - this dimension of the mesh. - - For example, given an array with:: - - array_dim_names = ('batch', 'length', 'heads', 'features') - - and the layout rules are:: - - rules = (('batch', 'X'), - ('features', 'X'), - ('heads', 'Y'), - ('batch', 'Z')) - - then this function will return:: - - PartitionSpec('X', None, 'Y', None) - - Args: - array_dim_names: Tuple of array dimension names or None. - rules: Optional logical to mesh rules override. Defaults to using the - rules defined in the dynamic context set from the ``axis_rules`` function. - - Returns: - PartitionSpec for the parameter. - """ - result = _logical_to_mesh_axes(array_dim_names, rules) - if result is None: - return None - # We default to None - ie unsharded along the dimension. - result = [None if x is _unassigned_axis else x for x in result] - return jax.sharding.PartitionSpec(*result) def logical_to_mesh(tree: Any, rules: LogicalRules | None = None) -> Any: diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index 756165af9..364a0f93b 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -129,13 +129,11 @@ def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None: """Given an `nnx.Variable`, return its `PartitionSpec`.""" metadata = v.get_metadata() if 'sharding_names' in metadata and metadata['sharding_names']: - sharding = metadata['sharding_names'] + sharding_names = metadata['sharding_names'] if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata: - context_rules = core_spmd.get_logical_axis_rules() - local_rules = metadata.get('sharding_rules', ()) - rules = core_spmd.composite_rules(context_rules, local_rules) - return PartitionSpec(*core_spmd.from_sharding_rules(sharding, rules)) - return PartitionSpec(*sharding) + sharding_names = core_spmd.logical_to_mesh_axes( + sharding_names, metadata.get('sharding_rules', None)) + return PartitionSpec(*sharding_names) elif hasattr(v, 'shape'): return PartitionSpec() return None diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py index c05f9dd4d..f223968cc 100644 --- a/tests/nnx/optimizer_test.py +++ b/tests/nnx/optimizer_test.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os
+os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
+
 from absl.testing import absltest
 from absl.testing import parameterized
 from flax import nnx
@@ -69,10 +72,10 @@ def loss_fn(model):
     optimizer.update(model, grads)
 
   def test_sharding_propagation(self):
-    with jax.set_mesh(jax.make_mesh(((1, 1)), ('a', 'b'))):
+    with jax.set_mesh(jax.make_mesh(((2, 2)), ('a', 'b'))):
       model = nnx.Linear(
         2,
-        3,
+        4,
         rngs=nnx.Rngs(0),
         kernel_init=nnx.with_partitioning(
           nnx.initializers.lecun_normal(),
diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py
index 9e8abe400..4bdc7fd6e 100644
--- a/tests/nnx/spmd_test.py
+++ b/tests/nnx/spmd_test.py
@@ -185,12 +185,12 @@ def __init__(self):
         nnx.with_partitioning(
           lambda: jnp.ones((8, 2)),
           sharding=('row-alias', 'col-alias'),
-          sharding_rules=(('row-alias', 'row'),),
         )()
       )
       self.b = nnx.Param(
         nnx.with_partitioning(
-          lambda: jnp.zeros((2,)), sharding=('col-alias',)
+          lambda: jnp.zeros((2,)), sharding=('col-alias2',),
+          sharding_rules=(('col-alias2', 'col'),),
        )()
      )
 
@@ -198,7 +198,8 @@ def __call__(self, x):
       return x @ self.w + self.b
 
     mesh = jax.make_mesh(((1, 2, 2)), ('layers', 'row', 'col'))
-    with jax.set_mesh(mesh), nnx.logical_axis_rules((('col-alias', 'col'),)):
+    global_rule = (('row-alias', 'row'), ('col-alias', 'col'))
+    with jax.set_mesh(mesh), nnx.logical_axis_rules(global_rule):
       model = Foo()
       optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
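
Usage sketch: the snippet below is a minimal illustration of the moved
`logical_to_mesh_axes`, mirroring the docstring example above and the
rule-precedence behavior that `get_pspec` and `get_var_pspec` now rely on.
It assumes only the names this patch adds to `flax.core.spmd` plus standard
JAX, and is a sketch rather than part of the patch:

    import jax
    from flax.core.spmd import logical_to_mesh_axes

    # Logical names for each dimension of a 4-D array.
    array_dim_names = ('batch', 'length', 'heads', 'features')

    # Rules apply in order of precedence; each array dimension and each
    # mesh axis is consumed at most once.
    rules = (
        ('batch', 'X'),
        ('features', 'X'),  # skipped: mesh axis 'X' is already taken by 'batch'
        ('heads', 'Y'),
        ('batch', 'Z'),     # skipped: array dim 'batch' is already assigned
    )

    spec = logical_to_mesh_axes(array_dim_names, rules)
    assert spec == jax.sharding.PartitionSpec('X', None, 'Y', None)

Passing `rules=None` falls back to the ambient rules installed via
`nnx.logical_axis_rules` (or `flax.linen.logical_axis_rules`), which is the
path `get_pspec` and `get_var_pspec` take when a variable carries no local
`sharding_rules`. Note that a variable's local `sharding_rules`, when present,
now replace the context rules instead of being composited with them, which is
why the updated `spmd_test.py` no longer mixes local and global rules on the
same variable.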