@@ -544,6 +544,7 @@ def setdiff1d(
544
544
/ ,
545
545
* ,
546
546
assume_unique : bool = False ,
547
+ size : int | None = None ,
547
548
fill_value : object | None = None ,
548
549
xp : ModuleType | None = None ,
549
550
) -> Array :
@@ -561,11 +562,16 @@ def setdiff1d(
561
562
assume_unique : bool
562
563
If ``True``, the input arrays are both assumed to be unique, which
563
564
can speed up the calculation. Default is ``False``.
564
- fill_value : object, optional
565
- Pad the output array with this value.
565
+ size : int, optional
566
+ The size of the output array. This is exclusively used inside the JAX JIT, and
567
+ only for as long as JAX does not support arrays of unknown size inside it. In
568
+ all other cases, it is disregarded.
569
+ Returned elements will be clipped if they are more than size, and padded with
570
+ `fill_value` if they are less. Default: raise if inside ``jax.jit``.
566
571
567
- This is exclusively used for JAX arrays when running inside ``jax.jit``,
568
- where all array shapes need to be known in advance.
572
+ fill_value : object, optional
573
+ Pad the output array with this value. This is exclusively used for JAX arrays
574
+ when running inside ``jax.jit``. Default: 0.
569
575
xp : array_namespace, optional
570
576
The standard-compatible namespace for `x1` and `x2`. Default: infer.
571
577
@@ -630,7 +636,7 @@ def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
630
636
return x1 if assume_unique else xp .unique_values (x1 )
631
637
632
638
def _jax_jit_impl (
633
- x1 : Array , x2 : Array , fill_value : object | None
639
+ x1 : Array , x2 : Array , size : int | None , fill_value : object | None
634
640
) -> Array : # numpydoc ignore=PR01,RT01
635
641
"""
636
642
JAX implementation inside jax.jit.
@@ -639,9 +645,9 @@ def _jax_jit_impl(
639
645
and not being able to filter by a boolean mask.
640
646
Returns array the same size as x1, padded with fill_value.
641
647
"""
642
- # unique_values inside jax.jit is not supported unless it's got a fixed size
643
- mask = _x1_not_in_x2 ( x1 , x2 )
644
-
648
+ if size is None :
649
+ msg = "`size` is mandatory when running inside `jax.jit`."
650
+ raise ValueError ( msg )
645
651
if fill_value is None :
646
652
fill_value = xp .zeros ((), dtype = x1 .dtype )
647
653
else :
@@ -650,9 +656,12 @@ def _jax_jit_impl(
650
656
msg = "`fill_value` must be a scalar."
651
657
raise ValueError (msg )
652
658
659
+ # unique_values inside jax.jit is not supported unless it's got a fixed size
660
+ mask = _x1_not_in_x2 (x1 , x2 )
653
661
x1 = xp .where (mask , x1 , fill_value )
654
- # Note: jnp.unique_values sorts
655
- return xp .unique_values (x1 , size = x1 .size , fill_value = fill_value )
662
+ # Move fill_value to the right
663
+ x1 = xp .take (x1 , xp .argsort (~ mask , stable = True ))
664
+ x1 = xp .unique_values (x1 , size = x1 .size , fill_value = fill_value )
656
665
657
666
if is_dask_namespace (xp ):
658
667
return _dask_impl (x1 , x2 )
@@ -666,7 +675,7 @@ def _jax_jit_impl(
666
675
jax .errors .ConcretizationTypeError ,
667
676
jax .errors .NonConcreteBooleanIndexError ,
668
677
):
669
- return _jax_jit_impl (x1 , x2 , fill_value ) # inside jax.jit
678
+ return _jax_jit_impl (x1 , x2 , size , fill_value ) # inside jax.jit
670
679
671
680
return _generic_impl (x1 , x2 )
672
681
0 commit comments