Skip to content

Commit ba4bb43

Browse files
committed
jax.Array: support fast path for lax.transpose & lax.squeeze
As part of this change, I created a helper function so that the logic of type checking is in a single location. Eventually we can replace this helper function with appropriate isinstance() checks using the APIs described in jax-ml#11859.
1 parent 4746a39 commit ba4bb43

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

jax/_src/lax/lax.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@
8585
map, unsafe_map = safe_map, map
8686
zip, unsafe_zip = safe_zip, zip
8787

88+
# TODO(jakevdp): replace this with an isinstance() check when JEP 12049 is complete.
89+
def _is_array_or_tracer(operand: Any) -> bool:
90+
if config.jax_array:
91+
from jax.experimental import array # pylint: disable=g-import-not-at-top
92+
return isinstance(operand, (core.Tracer, array.Array))
93+
else:
94+
return isinstance(operand, (core.Tracer, device_array.DeviceArray))
95+
8896
def _validate_shapes(shapes: Sequence[Shape]):
8997
def _check_static_shape(shape: Shape):
9098
checked = canonicalize_shape(shape)
@@ -549,8 +557,6 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:
549557

550558
def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
551559
weak_type: bool = False):
552-
from jax.experimental import array
553-
554560
# Don't canonicalize old_dtype because x64 context might cause
555561
# un-canonicalized operands to be passed in.
556562
old_dtype = dtypes.dtype(operand, canonicalize=False)
@@ -577,8 +583,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
577583
operand = np.asarray(operand, new_dtype)
578584
old_weak_type = False
579585

580-
if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
581-
and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
586+
if (old_dtype, old_weak_type) == (new_dtype, new_weak_type) and _is_array_or_tracer(operand):
582587
return operand
583588
else:
584589
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
@@ -629,13 +634,11 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
629634
Returns:
630635
An array containing the concatenation.
631636
"""
632-
from jax.experimental import array
633-
634637
if len(operands) == 0:
635638
raise ValueError("concatenate requires a non-empty sequences of arrays")
636639
if len(operands) == 1:
637640
op, = operands
638-
if isinstance(op, (core.Tracer, device_array.DeviceArray, array.Array)):
641+
if _is_array_or_tracer(op):
639642
return op
640643
return concatenate_p.bind(*operands, dimension=dimension)
641644

@@ -803,10 +806,7 @@ def broadcast_in_dim(operand: Array, shape: Shape,
803806
See Also:
804807
jax.lax.broadcast : simpler interface to add new leading dimensions.
805808
"""
806-
from jax.experimental import array
807-
808-
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
809-
and isinstance(operand, (device_array.DeviceArray, core.Tracer, array.Array))):
809+
if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and _is_array_or_tracer(operand):
810810
return operand
811811
if config.jax_dynamic_shapes:
812812
# We must gate this behavior under a flag because otherwise the errors
@@ -861,8 +861,6 @@ def reshape(operand: Array, new_sizes: Shape,
861861
>>> reshape(y, (6,), (1, 0))
862862
DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
863863
"""
864-
from jax.experimental import array
865-
866864
new_sizes = canonicalize_shape(new_sizes) # TODO
867865
new_sizes = tuple(new_sizes)
868866
same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)
@@ -872,8 +870,7 @@ def reshape(operand: Array, new_sizes: Shape,
872870
else:
873871
dims = api_util._ensure_index_tuple(dimensions)
874872
same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
875-
if (np.shape(operand) and same_shape and same_dims
876-
and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
873+
if np.shape(operand) and same_shape and same_dims and _is_array_or_tracer(operand):
877874
return operand
878875
else:
879876
dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
@@ -952,8 +949,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
952949
operator.
953950
"""
954951
permutation = tuple(operator.index(d) for d in permutation)
955-
if (permutation == tuple(range(np.ndim(operand)))
956-
and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
952+
if permutation == tuple(range(np.ndim(operand))) and _is_array_or_tracer(operand):
957953
return operand
958954
else:
959955
return transpose_p.bind(operand, permutation=permutation)
@@ -1283,7 +1279,7 @@ def squeeze(array: Array, dimensions: Sequence[int]) -> Array:
12831279
"""Squeeze any number of size 1 dimensions from an array."""
12841280
ndim = np.ndim(array)
12851281
dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions))
1286-
if not dimensions and isinstance(array, (core.Tracer, device_array.DeviceArray)):
1282+
if not dimensions and _is_array_or_tracer(array):
12871283
return array
12881284
return squeeze_p.bind(array, dimensions=dimensions)
12891285

0 commit comments

Comments
 (0)