Skip to content

Commit 2900169

Browse files
committed
Design 2->4
1 parent 6e3c824 commit 2900169

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

src/array_api_extra/_lib/_funcs.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ def setdiff1d(
544544
/,
545545
*,
546546
assume_unique: bool = False,
547+
size: int | None = None,
547548
fill_value: object | None = None,
548549
xp: ModuleType | None = None,
549550
) -> Array:
@@ -561,11 +562,16 @@ def setdiff1d(
561562
assume_unique : bool
562563
If ``True``, the input arrays are both assumed to be unique, which
563564
can speed up the calculation. Default is ``False``.
564-
fill_value : object, optional
565-
Pad the output array with this value.
565+
size : int, optional
566+
The size of the output array. This is exclusively used inside the JAX JIT, and
567+
only for as long as JAX does not support arrays of unknown size inside it. In
568+
all other cases, it is disregarded.
569+
Returned elements will be clipped if they are more than size, and padded with
570+
`fill_value` if they are less. Default: raise if inside ``jax.jit``.
566571
567-
This is exclusively used for JAX arrays when running inside ``jax.jit``,
568-
where all array shapes need to be known in advance.
572+
fill_value : object, optional
573+
Pad the output array with this value. This is exclusively used for JAX arrays
574+
when running inside ``jax.jit``. Default: 0.
569575
xp : array_namespace, optional
570576
The standard-compatible namespace for `x1` and `x2`. Default: infer.
571577
@@ -630,7 +636,7 @@ def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
630636
return x1 if assume_unique else xp.unique_values(x1)
631637

632638
def _jax_jit_impl(
633-
x1: Array, x2: Array, fill_value: object | None
639+
x1: Array, x2: Array, size: int | None, fill_value: object | None
634640
) -> Array: # numpydoc ignore=PR01,RT01
635641
"""
636642
JAX implementation inside jax.jit.
@@ -639,9 +645,9 @@ def _jax_jit_impl(
639645
and not being able to filter by a boolean mask.
640646
Returns array the same size as x1, padded with fill_value.
641647
"""
642-
# unique_values inside jax.jit is not supported unless it's got a fixed size
643-
mask = _x1_not_in_x2(x1, x2)
644-
648+
if size is None:
649+
msg = "`size` is mandatory when running inside `jax.jit`."
650+
raise ValueError(msg)
645651
if fill_value is None:
646652
fill_value = xp.zeros((), dtype=x1.dtype)
647653
else:
@@ -650,9 +656,12 @@ def _jax_jit_impl(
650656
msg = "`fill_value` must be a scalar."
651657
raise ValueError(msg)
652658

659+
# unique_values inside jax.jit is not supported unless it's got a fixed size
660+
mask = _x1_not_in_x2(x1, x2)
653661
x1 = xp.where(mask, x1, fill_value)
654-
# Note: jnp.unique_values sorts
655-
return xp.unique_values(x1, size=x1.size, fill_value=fill_value)
662+
# Move fill_value to the right
663+
x1 = xp.take(x1, xp.argsort(~mask, stable=True))
664+
x1 = xp.unique_values(x1, size=x1.size, fill_value=fill_value)
656665

657666
if is_dask_namespace(xp):
658667
return _dask_impl(x1, x2)
@@ -666,7 +675,7 @@ def _jax_jit_impl(
666675
jax.errors.ConcretizationTypeError,
667676
jax.errors.NonConcreteBooleanIndexError,
668677
):
669-
return _jax_jit_impl(x1, x2, fill_value) # inside jax.jit
678+
return _jax_jit_impl(x1, x2, size, fill_value) # inside jax.jit
670679

671680
return _generic_impl(x1, x2)
672681

0 commit comments

Comments
 (0)