Skip to content

Commit 38285f7

Browse files
committed
Removed private jax._src.api_util import and copied them into flax
1 parent b9ac1bc commit 38285f7

File tree

1 file changed

+114
-11
lines changed

1 file changed

+114
-11
lines changed

flax/nnx/transforms/compilation.py

Lines changed: 114 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
import dataclasses
1818
import functools
19+
import inspect
20+
import operator
1921
import typing as tp
2022

2123
import jax
2224
from jax.sharding import AbstractMesh, Mesh, PartitionSpec
23-
from jax._src import api_util # We use fun_signature and resolve_argnums
2425

2526
from flax.nnx import (
2627
extract,
@@ -390,17 +391,9 @@ def __init__(
390391
out_shardings,
391392
)
392393

393-
if isinstance(in_shardings, (list, tuple)):
394+
if isinstance(in_shardings, (tuple, list)) and (static_argnums or static_argnames):
394395
# We should reintroduce None values into in_shardings corresponding to static arguments
395-
fun_signature = api_util.fun_signature(fun)
396-
_, _, static_argnums, _ = api_util.resolve_argnums(
397-
fun,
398-
fun_signature,
399-
None,
400-
None,
401-
static_argnums,
402-
static_argnames,
403-
)
396+
static_argnums = resolve_argnums(fun, static_argnums, static_argnames)
404397
in_shardings = list(in_shardings)
405398
for static_arg_index in sorted(static_argnums):
406399
in_shardings.insert(static_arg_index, None)
@@ -1051,3 +1044,113 @@ def shard_map_wrapper(*args, **kwargs):
10511044
shard_map_wrapper.inner = shard_map_fn # type: ignore
10521045

10531046
return shard_map_wrapper # type: ignore
1047+
1048+
1049+
# We can't use private methods from jax._src.api_util
1050+
# We copy the function: api_util.fun_signature
1051+
def fun_signature(fun: tp.Callable) -> inspect.Signature | None:
1052+
try:
1053+
return inspect.signature(fun)
1054+
except (ValueError, TypeError):
1055+
return None
1056+
1057+
# Adapted copy of private jax function from api_util: fun_signature
1058+
def resolve_argnums(
1059+
fun: tp.Callable,
1060+
static_argnums: int | tp.Sequence[int] | None,
1061+
static_argnames: str | tp.Iterable[str] | None,
1062+
) -> tuple[tuple[int, ...], tuple[str, ...], tuple[int, ...], tuple[str, ...]]:
1063+
def _ensure_index_tuple(x: tp.Any) -> tuple[int, ...]:
1064+
"""Convert x to a tuple of indices."""
1065+
try:
1066+
return (operator.index(x),)
1067+
except TypeError:
1068+
return tuple(map(operator.index, x))
1069+
1070+
def _ensure_str(x: str) -> str:
1071+
if not isinstance(x, str):
1072+
raise TypeError(f"argument is not a string: {x}")
1073+
return x
1074+
1075+
def _ensure_str_tuple(x: str | tp.Iterable[str]) -> tuple[str, ...]:
1076+
"""Convert x to a tuple of strings."""
1077+
if isinstance(x, str):
1078+
return (x,)
1079+
else:
1080+
return tuple(map(_ensure_str, x))
1081+
1082+
signature = fun_signature(fun)
1083+
if signature is None:
1084+
# Some built-in functions don't support signature.
1085+
# See: https://github.com/python/cpython/issues/73485
1086+
# In this case no validation is done
1087+
static_argnums = () if static_argnums is None else _ensure_index_tuple(
1088+
static_argnums)
1089+
static_argnames = () if static_argnames is None else _ensure_str_tuple(
1090+
static_argnames)
1091+
else:
1092+
# Infer argnums and argnames according to docstring
1093+
# If nums is None and names is not None, then nums are inferred from the
1094+
# names and vice-versa.
1095+
_POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
1096+
_POSITIONAL_ARGUMENTS = (
1097+
inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD
1098+
)
1099+
_KEYWORD_ARGUMENTS = (
1100+
inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY
1101+
)
1102+
_INVALID_KEYWORD_ARGUMENTS = (
1103+
inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL
1104+
)
1105+
1106+
def infer_argnums_and_argnames(
1107+
sig: inspect.Signature,
1108+
argnums: int | tp.Iterable[int] | None,
1109+
argnames: str | tp.Iterable[str] | None,
1110+
) -> tuple[tuple[int, ...], tuple[str, ...]]:
1111+
"""Infer missing argnums and argnames for a function with inspect."""
1112+
if argnums is None and argnames is None:
1113+
return (), ()
1114+
1115+
if argnums is not None and argnames is not None:
1116+
argnums = _ensure_index_tuple(argnums)
1117+
argnames = _ensure_str_tuple(argnames)
1118+
return argnums, argnames
1119+
1120+
parameters = sig.parameters
1121+
if argnums is None:
1122+
assert argnames is not None
1123+
argnames = _ensure_str_tuple(argnames)
1124+
argnums = tuple(
1125+
i for i, (k, param) in enumerate(parameters.items())
1126+
if param.kind == _POSITIONAL_OR_KEYWORD and k in argnames
1127+
)
1128+
else:
1129+
argnums = _ensure_index_tuple(argnums)
1130+
argnames = tuple(
1131+
k for i, (k, param) in enumerate(parameters.items())
1132+
if param.kind == _POSITIONAL_OR_KEYWORD and i in argnums
1133+
)
1134+
return argnums, argnames
1135+
1136+
def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None:
1137+
n_pos_args = 0
1138+
for param in sig.parameters.values():
1139+
if param.kind in _POSITIONAL_ARGUMENTS:
1140+
n_pos_args += 1
1141+
1142+
elif param.kind is inspect.Parameter.VAR_POSITIONAL:
1143+
# We can have any number of positional arguments
1144+
return
1145+
1146+
if argnums and (-min(argnums) > n_pos_args or max(argnums) >= n_pos_args):
1147+
raise ValueError(f"Jitted function has {argnums_name}={argnums}, "
1148+
f"but only accepts {n_pos_args} positional arguments.")
1149+
1150+
static_argnums, static_argnames = infer_argnums_and_argnames(
1151+
signature, static_argnums, static_argnames)
1152+
1153+
# Validation
1154+
_validate_argnums(signature, static_argnums, "static_argnums")
1155+
1156+
return static_argnums

0 commit comments

Comments
 (0)