|
16 | 16 |
|
17 | 17 | import dataclasses |
18 | 18 | import functools |
| 19 | +import inspect |
| 20 | +import operator |
19 | 21 | import typing as tp |
20 | 22 |
|
21 | 23 | import jax |
22 | 24 | from jax.sharding import AbstractMesh, Mesh, PartitionSpec |
23 | | -from jax._src import api_util # We use fun_signature and resolve_argnums |
24 | 25 |
|
25 | 26 | from flax.nnx import ( |
26 | 27 | extract, |
@@ -390,17 +391,9 @@ def __init__( |
390 | 391 | out_shardings, |
391 | 392 | ) |
392 | 393 |
|
393 | | - if isinstance(in_shardings, (list, tuple)): |
| 394 | + if isinstance(in_shardings, (tuple, list)) and (static_argnums or static_argnames): |
394 | 395 | # We should reintroduce None values into in_shardings corresponding to static arguments |
395 | | - fun_signature = api_util.fun_signature(fun) |
396 | | - _, _, static_argnums, _ = api_util.resolve_argnums( |
397 | | - fun, |
398 | | - fun_signature, |
399 | | - None, |
400 | | - None, |
401 | | - static_argnums, |
402 | | - static_argnames, |
403 | | - ) |
| 396 | + static_argnums = resolve_argnums(fun, static_argnums, static_argnames) |
404 | 397 | in_shardings = list(in_shardings) |
405 | 398 | for static_arg_index in sorted(static_argnums): |
406 | 399 | in_shardings.insert(static_arg_index, None) |
@@ -1051,3 +1044,113 @@ def shard_map_wrapper(*args, **kwargs): |
1051 | 1044 | shard_map_wrapper.inner = shard_map_fn # type: ignore |
1052 | 1045 |
|
1053 | 1046 | return shard_map_wrapper # type: ignore |
| 1047 | + |
| 1048 | + |
| 1049 | +# We can't use private methods from jax._src.api_util |
| 1050 | +# We copy the function: api_util.fun_signature |
| 1051 | +def fun_signature(fun: tp.Callable) -> inspect.Signature | None: |
| 1052 | + try: |
| 1053 | + return inspect.signature(fun) |
| 1054 | + except (ValueError, TypeError): |
| 1055 | + return None |
| 1056 | + |
| 1057 | +# Adapted copy of private jax function from api_util: fun_signature |
| 1058 | +def resolve_argnums( |
| 1059 | + fun: tp.Callable, |
| 1060 | + static_argnums: int | tp.Sequence[int] | None, |
| 1061 | + static_argnames: str | tp.Iterable[str] | None, |
| 1062 | +) -> tuple[tuple[int, ...], tuple[str, ...], tuple[int, ...], tuple[str, ...]]: |
| 1063 | + def _ensure_index_tuple(x: tp.Any) -> tuple[int, ...]: |
| 1064 | + """Convert x to a tuple of indices.""" |
| 1065 | + try: |
| 1066 | + return (operator.index(x),) |
| 1067 | + except TypeError: |
| 1068 | + return tuple(map(operator.index, x)) |
| 1069 | + |
| 1070 | + def _ensure_str(x: str) -> str: |
| 1071 | + if not isinstance(x, str): |
| 1072 | + raise TypeError(f"argument is not a string: {x}") |
| 1073 | + return x |
| 1074 | + |
| 1075 | + def _ensure_str_tuple(x: str | tp.Iterable[str]) -> tuple[str, ...]: |
| 1076 | + """Convert x to a tuple of strings.""" |
| 1077 | + if isinstance(x, str): |
| 1078 | + return (x,) |
| 1079 | + else: |
| 1080 | + return tuple(map(_ensure_str, x)) |
| 1081 | + |
| 1082 | + signature = fun_signature(fun) |
| 1083 | + if signature is None: |
| 1084 | + # Some built-in functions don't support signature. |
| 1085 | + # See: https://github.com/python/cpython/issues/73485 |
| 1086 | + # In this case no validation is done |
| 1087 | + static_argnums = () if static_argnums is None else _ensure_index_tuple( |
| 1088 | + static_argnums) |
| 1089 | + static_argnames = () if static_argnames is None else _ensure_str_tuple( |
| 1090 | + static_argnames) |
| 1091 | + else: |
| 1092 | + # Infer argnums and argnames according to docstring |
| 1093 | + # If nums is None and names is not None, then nums are inferred from the |
| 1094 | + # names and vice-versa. |
| 1095 | + _POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD |
| 1096 | + _POSITIONAL_ARGUMENTS = ( |
| 1097 | + inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD |
| 1098 | + ) |
| 1099 | + _KEYWORD_ARGUMENTS = ( |
| 1100 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY |
| 1101 | + ) |
| 1102 | + _INVALID_KEYWORD_ARGUMENTS = ( |
| 1103 | + inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL |
| 1104 | + ) |
| 1105 | + |
| 1106 | + def infer_argnums_and_argnames( |
| 1107 | + sig: inspect.Signature, |
| 1108 | + argnums: int | tp.Iterable[int] | None, |
| 1109 | + argnames: str | tp.Iterable[str] | None, |
| 1110 | + ) -> tuple[tuple[int, ...], tuple[str, ...]]: |
| 1111 | + """Infer missing argnums and argnames for a function with inspect.""" |
| 1112 | + if argnums is None and argnames is None: |
| 1113 | + return (), () |
| 1114 | + |
| 1115 | + if argnums is not None and argnames is not None: |
| 1116 | + argnums = _ensure_index_tuple(argnums) |
| 1117 | + argnames = _ensure_str_tuple(argnames) |
| 1118 | + return argnums, argnames |
| 1119 | + |
| 1120 | + parameters = sig.parameters |
| 1121 | + if argnums is None: |
| 1122 | + assert argnames is not None |
| 1123 | + argnames = _ensure_str_tuple(argnames) |
| 1124 | + argnums = tuple( |
| 1125 | + i for i, (k, param) in enumerate(parameters.items()) |
| 1126 | + if param.kind == _POSITIONAL_OR_KEYWORD and k in argnames |
| 1127 | + ) |
| 1128 | + else: |
| 1129 | + argnums = _ensure_index_tuple(argnums) |
| 1130 | + argnames = tuple( |
| 1131 | + k for i, (k, param) in enumerate(parameters.items()) |
| 1132 | + if param.kind == _POSITIONAL_OR_KEYWORD and i in argnums |
| 1133 | + ) |
| 1134 | + return argnums, argnames |
| 1135 | + |
| 1136 | + def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None: |
| 1137 | + n_pos_args = 0 |
| 1138 | + for param in sig.parameters.values(): |
| 1139 | + if param.kind in _POSITIONAL_ARGUMENTS: |
| 1140 | + n_pos_args += 1 |
| 1141 | + |
| 1142 | + elif param.kind is inspect.Parameter.VAR_POSITIONAL: |
| 1143 | + # We can have any number of positional arguments |
| 1144 | + return |
| 1145 | + |
| 1146 | + if argnums and (-min(argnums) > n_pos_args or max(argnums) >= n_pos_args): |
| 1147 | + raise ValueError(f"Jitted function has {argnums_name}={argnums}, " |
| 1148 | + f"but only accepts {n_pos_args} positional arguments.") |
| 1149 | + |
| 1150 | + static_argnums, static_argnames = infer_argnums_and_argnames( |
| 1151 | + signature, static_argnums, static_argnames) |
| 1152 | + |
| 1153 | + # Validation |
| 1154 | + _validate_argnums(signature, static_argnums, "static_argnums") |
| 1155 | + |
| 1156 | + return static_argnums |
0 commit comments