Skip to content

Ensure __jax_array__ support in binary ops #28630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,13 +573,15 @@ def _notimplemented_flat(self):
_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array)
_rejected_binop_types = (list, tuple, set, dict)

def _defer_to_unrecognized_arg(opchar, binary_op, swap=False):
def _defer_to_unrecognized_arg(
opchar: str,
binary_op: Callable[[ArrayLike, ArrayLike], Array],
swap: bool = False
) -> Callable[[Array, ArrayLike], Array]:
# Ensure that other array types have the chance to override arithmetic.
def deferring_binary_op(self, other):
if hasattr(other, '__jax_array__'):
other = other.__jax_array__()
args = (other, self) if swap else (self, other)
if isinstance(other, _accepted_binop_types):
if hasattr(other, "__jax_array__") or isinstance(other, _accepted_binop_types):
return binary_op(*args)
# Note: don't use isinstance here, because we don't want to raise for
# subclasses, e.g. NamedTuple objects that may override operators.
Expand All @@ -589,7 +591,7 @@ def deferring_binary_op(self, other):
return NotImplemented
return deferring_binary_op

def _unimplemented_setitem(self, i, x):
def _unimplemented_setitem(self, i: Any, x: ArrayLike):
msg = ("JAX arrays are immutable and do not support in-place item assignment."
" Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method:"
" https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html")
Expand Down
31 changes: 31 additions & 0 deletions tests/array_extensibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import functools
import operator
from typing import Any, Callable, NamedTuple

from absl.testing import absltest
Expand Down Expand Up @@ -512,6 +513,22 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct:
NumPyAPI.sig(jnp.zeros_like, Float[5]),
]

BINARY_OPERATORS = [
NumPyAPI.sig(operator.add, Float[5], Float[5]),
NumPyAPI.sig(operator.sub, Float[5], Float[5]),
NumPyAPI.sig(operator.mul, Float[5], Float[5]),
NumPyAPI.sig(operator.truediv, Float[5], Float[5]),
NumPyAPI.sig(operator.floordiv, Float[5], Float[5]),
NumPyAPI.sig(operator.mod, Float[5], Float[5]),
NumPyAPI.sig(operator.pow, Float[5], Float[5]),
NumPyAPI.sig(operator.matmul, Float[5], Float[5]),
NumPyAPI.sig(operator.and_, Int[5], Int[5]),
NumPyAPI.sig(operator.or_, Int[5], Int[5]),
NumPyAPI.sig(operator.xor, Int[5], Int[5]),
NumPyAPI.sig(operator.lshift, Int[5], Int[5]),
NumPyAPI.sig(operator.rshift, Int[5], Int[5]),
]


class JaxArrayTests(jtu.JaxTestCase):
@parameterized.named_parameters(
Expand All @@ -529,6 +546,20 @@ def test_numpy_api_supports_jax_array(self, api):

self.assertAllClose(wrapped, expected, atol=0, rtol=0)

@parameterized.named_parameters(
{'testcase_name': api.name(), 'api': api} for api in BINARY_OPERATORS)
def test_binary_operator_supports_jax_array(self, api):
if api.skip_on_devices and jtu.test_device_matches(api.skip_on_devices):
self.skipTest(f'{api.name()} not supported on {api.skip_on_devices}')
lhs, rhs = map(jnp.asarray, api.make_args(self.rng()))

expected = api.fun(lhs, rhs)
result_left = api.fun(JaxArrayWrapper(lhs), rhs)
result_right = api.fun(lhs, JaxArrayWrapper(rhs))

self.assertAllClose(result_left, expected, atol=0, rtol=0)
self.assertAllClose(result_right, expected, atol=0, rtol=0)

@parameterized.named_parameters(
{'testcase_name': func.__name__, 'func': func}
for func in [jnp.zeros_like, jnp.ones_like, jnp.empty_like, jnp.full_like]
Expand Down
Loading