map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

+# TODO(jakevdp): replace this with an isinstance() check when JEP 12049 is complete.
+def _is_array_or_tracer(operand: Any) -> bool:
+  if config.jax_array:
+    from jax.experimental import array  # pylint: disable=g-import-not-at-top
+    return isinstance(operand, (core.Tracer, array.Array))
+  else:
+    return isinstance(operand, (core.Tracer, device_array.DeviceArray))
+
def _validate_shapes(shapes: Sequence[Shape]):
  def _check_static_shape(shape: Shape):
    checked = canonicalize_shape(shape)
@@ -549,8 +557,6 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:
def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
                          weak_type: bool = False):
-  from jax.experimental import array
-
  # Don't canonicalize old_dtype because x64 context might cause
  # un-canonicalized operands to be passed in.
  old_dtype = dtypes.dtype(operand, canonicalize=False)
@@ -577,8 +583,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
    operand = np.asarray(operand, new_dtype)
    old_weak_type = False

-  if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
-      and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
+  if (old_dtype, old_weak_type) == (new_dtype, new_weak_type) and _is_array_or_tracer(operand):
    return operand
  else:
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
@@ -629,13 +634,11 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
  Returns:
    An array containing the concatenation.
  """
-  from jax.experimental import array
-
  if len(operands) == 0:
    raise ValueError("concatenate requires a non-empty sequences of arrays")
  if len(operands) == 1:
    op, = operands
-    if isinstance(op, (core.Tracer, device_array.DeviceArray, array.Array)):
+    if _is_array_or_tracer(op):
      return op
  return concatenate_p.bind(*operands, dimension=dimension)

@@ -803,10 +806,7 @@ def broadcast_in_dim(operand: Array, shape: Shape,
  See Also:
    jax.lax.broadcast : simpler interface to add new leading dimensions.
  """
-  from jax.experimental import array
-
-  if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
-      and isinstance(operand, (device_array.DeviceArray, core.Tracer, array.Array))):
+  if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and _is_array_or_tracer(operand):
    return operand
  if config.jax_dynamic_shapes:
    # We must gate this behavior under a flag because otherwise the errors
@@ -861,8 +861,6 @@ def reshape(operand: Array, new_sizes: Shape,
  >>> reshape(y, (6,), (1, 0))
  DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
  """
-  from jax.experimental import array
-
  new_sizes = canonicalize_shape(new_sizes)  # TODO
  new_sizes = tuple(new_sizes)
  same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)
@@ -872,8 +870,7 @@ def reshape(operand: Array, new_sizes: Shape,
  else:
    dims = api_util._ensure_index_tuple(dimensions)
    same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
-  if (np.shape(operand) and same_shape and same_dims
-      and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
+  if np.shape(operand) and same_shape and same_dims and _is_array_or_tracer(operand):
    return operand
  else:
    dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
@@ -952,8 +949,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
  operator.
  """
  permutation = tuple(operator.index(d) for d in permutation)
-  if (permutation == tuple(range(np.ndim(operand)))
-      and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
+  if permutation == tuple(range(np.ndim(operand))) and _is_array_or_tracer(operand):
    return operand
  else:
    return transpose_p.bind(operand, permutation=permutation)
@@ -1283,7 +1279,7 @@ def squeeze(array: Array, dimensions: Sequence[int]) -> Array:
  """Squeeze any number of size 1 dimensions from an array."""
  ndim = np.ndim(array)
  dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions))
-  if not dimensions and isinstance(array, (core.Tracer, device_array.DeviceArray)):
+  if not dimensions and _is_array_or_tracer(array):
    return array
  return squeeze_p.bind(array, dimensions=dimensions)

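Below the diff, a minimal usage sketch (not part of the change) of the no-op short-circuit these sites rely on, assuming only the public jax, jax.numpy, and jax.lax APIs of this JAX version; the identity (is) assertions and the jaxpr comment describe the intended behaviour rather than a documented guarantee.

# Hedged sketch: with _is_array_or_tracer in place, no-op calls on concrete
# arrays or tracers should hand back their input unchanged instead of binding
# a new primitive.
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)

# Identity permutation and same-shape reshape: the operand is returned as-is.
assert lax.transpose(x, (0, 1)) is x
assert lax.reshape(x, (2, 3)) is x

# Under tracing, the same short-circuit applies to tracers: squeezing zero
# dimensions returns the tracer itself, so the jaxpr should contain no
# squeeze equation.
def f(y):
  return lax.squeeze(y, ())

print(jax.make_jaxpr(f)(x))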