125 changes: 120 additions & 5 deletions flax/nnx/transforms/compilation.py
@@ -16,6 +16,8 @@

import dataclasses
import functools
import inspect
import operator
import typing as tp

import jax
@@ -389,14 +391,25 @@ def __init__(
out_shardings,
)

if isinstance(in_shardings, (tuple, list)) and (static_argnums or static_argnames):
# Reinsert None placeholders into in_shardings at the static argument
# positions so that it realigns with the function's full positional signature.
static_argnums = _resolve_argnums(fun, static_argnums, static_argnames)
in_shardings = list(in_shardings)
for static_arg_index in sorted(static_argnums):
in_shardings.insert(static_arg_index, None)
in_shardings = tuple(in_shardings)
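# Illustration (hypothetical values): for fn(x, scale, use_relu) with
# static_argnums=(2,) and user-supplied in_shardings=(x_sharding, None), the
# loop above yields (x_sharding, None, None), restoring positional alignment
# with fn's full signature.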

jax_out_in_shardings = jax.tree.map(
lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x)
if isinstance(x, StateSharding)
else x,
in_shardings,
)

self.jitted_fn = jax.jit(
JitFn(fun, in_shardings, out_shardings, kwarg_shardings, self),
in_shardings=self.jax_in_shardings,
out_shardings=(jax_out_in_shardings, kwarg_shardings, self.jax_out_shardings),
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
@@ -1031,3 +1044,105 @@ def shard_map_wrapper(*args, **kwargs):
shard_map_wrapper.inner = shard_map_fn # type: ignore

return shard_map_wrapper # type: ignore


# We can't rely on private helpers from jax._src.api_util, so the following
# is a copy of api_util.fun_signature.
def _fun_signature(fun: tp.Callable) -> inspect.Signature | None:
try:
return inspect.signature(fun)
except (ValueError, TypeError):
return None
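
# Illustrative behavior (an assumption about CPython, where some C builtins
# such as `max` expose no text signature): _fun_signature(max) returns None,
# while _fun_signature(lambda x: x) returns an inspect.Signature.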

# Adapted from the private JAX helper api_util.resolve_argnums.
def _resolve_argnums(
fun: tp.Callable,
static_argnums: int | tp.Sequence[int] | None,
static_argnames: str | tp.Iterable[str] | None,
) -> tuple[int, ...]:
def _ensure_index_tuple(x: tp.Any) -> tuple[int, ...]:
"""Convert x to a tuple of indices."""
try:
return (operator.index(x),)
except TypeError:
return tuple(map(operator.index, x))

def _ensure_str(x: str) -> str:
if not isinstance(x, str):
raise TypeError(f"argument is not a string: {x}")
return x

def _ensure_str_tuple(x: str | tp.Iterable[str]) -> tuple[str, ...]:
"""Convert x to a tuple of strings."""
if isinstance(x, str):
return (x,)
else:
return tuple(map(_ensure_str, x))

signature = _fun_signature(fun)
if signature is None:
# Some built-in functions don't support inspect.signature.
# See: https://github.com/python/cpython/issues/73485
# In that case, no inference or validation is performed.
static_argnums = () if static_argnums is None else _ensure_index_tuple(
static_argnums)
else:
# Infer argnums and argnames from the function's signature.
# If nums is None and names is not None, the nums are inferred from the
# names, and vice versa.
_POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
_POSITIONAL_ARGUMENTS = (
inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD
)

def infer_argnums_and_argnames(
sig: inspect.Signature,
argnums: int | tp.Iterable[int] | None,
argnames: str | tp.Iterable[str] | None,
) -> tuple[tuple[int, ...], tuple[str, ...]]:
"""Infer missing argnums and argnames for a function with inspect."""
if argnums is None and argnames is None:
return (), ()

if argnums is not None and argnames is not None:
argnums = _ensure_index_tuple(argnums)
argnames = _ensure_str_tuple(argnames)
return argnums, argnames

parameters = sig.parameters
if argnums is None:
assert argnames is not None
argnames = _ensure_str_tuple(argnames)
argnums = tuple(
i for i, (k, param) in enumerate(parameters.items())
if param.kind == _POSITIONAL_OR_KEYWORD and k in argnames
)
else:
argnums = _ensure_index_tuple(argnums)
argnames = tuple(
k for i, (k, param) in enumerate(parameters.items())
if param.kind == _POSITIONAL_OR_KEYWORD and i in argnums
)
return argnums, argnames

def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None:
n_pos_args = 0
for param in sig.parameters.values():
if param.kind in _POSITIONAL_ARGUMENTS:
n_pos_args += 1

elif param.kind is inspect.Parameter.VAR_POSITIONAL:
# We can have any number of positional arguments
return

if argnums and (-min(argnums) > n_pos_args or max(argnums) >= n_pos_args):
raise ValueError(f"Jitted function has {argnums_name}={argnums}, "
f"but only accepts {n_pos_args} positional arguments.")

static_argnums, static_argnames = infer_argnums_and_argnames(
signature, static_argnums, static_argnames)

# Validation
_validate_argnums(signature, static_argnums, "static_argnums")

return static_argnums
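
# Worked example (illustrative): given
#   def fn(x, scale, use_relu): ...
# both _resolve_argnums(fn, (2,), None) and
# _resolve_argnums(fn, None, ('use_relu',)) return (2,), since use_relu is
# the positional-or-keyword parameter at index 2, while
# _resolve_argnums(fn, (3,), None) raises ValueError because fn only accepts
# 3 positional arguments.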
68 changes: 67 additions & 1 deletion tests/nnx/transforms_test.py
@@ -31,7 +31,7 @@



class TestJIT(parameterized.TestCase):
def test_jit(self):
m = nnx.Dict(a=nnx.Param(1))

@@ -436,6 +436,72 @@ def f(m: nnx.Linear, x):
y = compiled(m, x)
self.assertEqual(m.count[...], 2)

@parameterized.parameters(
{'static_argnums': (2,), 'static_argnames': None},
{'static_argnums': None, 'static_argnames': ('use_relu',)},
)
def test_jit_static_args_with_shardings(self, static_argnums, static_argnames):
"""Test static arguments work correctly with in_shardings."""
n_devices = jax.local_device_count()
devices = mesh_utils.create_device_mesh((n_devices,))
mesh = jax.sharding.Mesh(devices, ('data',))

def fn(x, scale, use_relu):
y = x * scale
if use_relu:
y = jnp.maximum(y, 0)
return y.sum()

x = jnp.linspace(-1.0, 1.0, 16, dtype=jnp.float32).reshape(4, 4)
x_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('data'))

f = nnx.jit(fn, in_shardings=(x_sharding, None),
static_argnums=static_argnums, static_argnames=static_argnames)
y_relu = f(x, 0.5, True)
y_no_relu = f(x, 0.5, False)
self.assertNotEqual(y_relu, y_no_relu)
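# Note (standard jax.jit semantics): use_relu is baked into the compiled
# executable, so the two calls above compile fn once per distinct static value.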

@parameterized.parameters(
{
'static_args': {'static_argnums': (2, 3)},
},
{
'static_args': {'static_argnames': ('static_arg1', 'static_arg2')},
},
)
def test_with_sharding_and_static_args(self, static_args):
n_devices = max(jax.local_device_count() // 2, 1)
devices = mesh_utils.create_device_mesh(
(n_devices, jax.local_device_count() // n_devices)
)
mesh = jax.sharding.Mesh(devices, ('a', 'b'))

def sharding(*args):
return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*args))

state_sharding = nnx.StateSharding(
{
nnx.PathContains('kernel'): sharding('a', 'b'),
nnx.PathContains('bias'): sharding('b'),
}
)

m = nnx.Linear(16, 32, rngs=nnx.Rngs(0))
self.assertNotIsInstance(m.kernel.sharding, jax.sharding.NamedSharding)

@nnx.jit(
in_shardings=(state_sharding, None),
**static_args,
)
def constrain_object(m, scale: float, static_arg1: bool, static_arg2: bool):
new_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('b', 'a'))
m.kernel = jax.lax.with_sharding_constraint(m.kernel, new_sharding)
return None

constrain_object(m, 0.5, True, True)
self.assertEqual(m.kernel.sharding.spec, jax.sharding.PartitionSpec('a', 'b'))


class TestEvalShape(absltest.TestCase):
def test_eval_shape(self):
abs_model = nnx.eval_shape(lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0)))