Skip to content

Commit 635eebf

Browse files
author
jax authors
committed
Merge pull request #12311 from jakevdp:fix-lax-array
PiperOrigin-RevId: 473328023
2 parents b764aad + 07f55b3 commit 635eebf

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

jax/_src/lax/lax.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@
8585
map, unsafe_map = safe_map, map
8686
zip, unsafe_zip = safe_zip, zip
8787

88+
# TODO(jakevdp): replace this with an isinstance() check when JEP 12049 is complete.
89+
def _is_array_or_tracer(operand: Any) -> bool:
90+
if config.jax_array:
91+
from jax.experimental import array # pylint: disable=g-import-not-at-top
92+
return isinstance(operand, (core.Tracer, array.Array))
93+
else:
94+
return isinstance(operand, (core.Tracer, device_array.DeviceArray))
95+
8896
def _validate_shapes(shapes: Sequence[Shape]):
8997
def _check_static_shape(shape: Shape):
9098
checked = canonicalize_shape(shape)
@@ -548,8 +556,6 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:
548556

549557
def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
550558
weak_type: bool = False):
551-
from jax.experimental import array
552-
553559
# Don't canonicalize old_dtype because x64 context might cause
554560
# un-canonicalized operands to be passed in.
555561
old_dtype = dtypes.dtype(operand, canonicalize=False)
@@ -576,8 +582,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
576582
operand = np.asarray(operand, new_dtype)
577583
old_weak_type = False
578584

579-
if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
580-
and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
585+
if (old_dtype, old_weak_type) == (new_dtype, new_weak_type) and _is_array_or_tracer(operand):
581586
return operand
582587
else:
583588
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
@@ -628,13 +633,11 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
628633
Returns:
629634
An array containing the concatenation.
630635
"""
631-
from jax.experimental import array
632-
633636
if len(operands) == 0:
634637
raise ValueError("concatenate requires a non-empty sequences of arrays")
635638
if len(operands) == 1:
636639
op, = operands
637-
if isinstance(op, (core.Tracer, device_array.DeviceArray, array.Array)):
640+
if _is_array_or_tracer(op):
638641
return op
639642
return concatenate_p.bind(*operands, dimension=dimension)
640643

@@ -802,10 +805,7 @@ def broadcast_in_dim(operand: Array, shape: Shape,
802805
See Also:
803806
jax.lax.broadcast : simpler interface to add new leading dimensions.
804807
"""
805-
from jax.experimental import array
806-
807-
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
808-
and isinstance(operand, (device_array.DeviceArray, core.Tracer, array.Array))):
808+
if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and _is_array_or_tracer(operand):
809809
return operand
810810
if config.jax_dynamic_shapes:
811811
# We must gate this behavior under a flag because otherwise the errors
@@ -860,8 +860,6 @@ def reshape(operand: Array, new_sizes: Shape,
860860
>>> reshape(y, (6,), (1, 0))
861861
DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
862862
"""
863-
from jax.experimental import array
864-
865863
new_sizes = canonicalize_shape(new_sizes) # TODO
866864
new_sizes = tuple(new_sizes)
867865
same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)
@@ -871,8 +869,7 @@ def reshape(operand: Array, new_sizes: Shape,
871869
else:
872870
dims = api_util._ensure_index_tuple(dimensions)
873871
same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
874-
if (np.shape(operand) and same_shape and same_dims
875-
and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
872+
if np.shape(operand) and same_shape and same_dims and _is_array_or_tracer(operand):
876873
return operand
877874
else:
878875
dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
@@ -951,8 +948,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
951948
operator.
952949
"""
953950
permutation = tuple(operator.index(d) for d in permutation)
954-
if (permutation == tuple(range(np.ndim(operand)))
955-
and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
951+
if permutation == tuple(range(np.ndim(operand))) and _is_array_or_tracer(operand):
956952
return operand
957953
else:
958954
return transpose_p.bind(operand, permutation=permutation)
@@ -1282,7 +1278,7 @@ def squeeze(array: Array, dimensions: Sequence[int]) -> Array:
12821278
"""Squeeze any number of size 1 dimensions from an array."""
12831279
ndim = np.ndim(array)
12841280
dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions))
1285-
if not dimensions and isinstance(array, (core.Tracer, device_array.DeviceArray)):
1281+
if not dimensions and _is_array_or_tracer(array):
12861282
return array
12871283
return squeeze_p.bind(array, dimensions=dimensions)
12881284

0 commit comments

Comments
 (0)