 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
 
+# TODO(jakevdp): replace this with an isinstance() check when JEP 12049 is complete.
+def _is_array_or_tracer(operand: Any) -> bool:
+  if config.jax_array:
+    from jax.experimental import array  # pylint: disable=g-import-not-at-top
+    return isinstance(operand, (core.Tracer, array.Array))
+  else:
+    return isinstance(operand, (core.Tracer, device_array.DeviceArray))
+
 def _validate_shapes(shapes: Sequence[Shape]):
   def _check_static_shape(shape: Shape):
     checked = canonicalize_shape(shape)
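
The helper above centralizes the "already a tracer or committed array?" test that the no-op fast paths in the hunks below rely on. As a hedged, self-contained sketch of that pattern (the `ToyTracer` and `toy_squeeze` names are made up for illustration and are not JAX internals):

```python
from typing import Any, Sequence

class ToyTracer:
  """Stand-in for core.Tracer / a device-resident JAX array."""

def _is_array_or_tracer(operand: Any) -> bool:
  # Simplified: the real helper accepts array.Array or device_array.DeviceArray
  # depending on the config.jax_array flag.
  return isinstance(operand, ToyTracer)

def toy_squeeze(operand: Any, dimensions: Sequence[int]):
  if not dimensions and _is_array_or_tracer(operand):
    return operand  # fast path: nothing to do, hand back the operand itself
  return ("squeeze_p.bind", operand, tuple(dimensions))  # otherwise bind the primitive

t = ToyTracer()
print(toy_squeeze(t, []) is t)     # True: returned unchanged
print(toy_squeeze([1, 2, 3], []))  # a plain list still goes through the primitive
```
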
@@ -548,8 +556,6 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:
 
 def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
                           weak_type: bool = False):
-  from jax.experimental import array
-
   # Don't canonicalize old_dtype because x64 context might cause
   # un-canonicalized operands to be passed in.
   old_dtype = dtypes.dtype(operand, canonicalize=False)
@@ -576,8 +582,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
     operand = np.asarray(operand, new_dtype)
     old_weak_type = False
 
-  if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
-      and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
+  if (old_dtype, old_weak_type) == (new_dtype, new_weak_type) and _is_array_or_tracer(operand):
     return operand
   else:
     return convert_element_type_p.bind(operand, new_dtype=new_dtype,
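
For reference, the fast path this hunk rewires is observable through the public API: converting to an identical dtype can hand back the very same array object, while a real dtype change binds the primitive. A hedged check (the identity result is typical but can depend on JAX version, weak types, and the jax_array config):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4, dtype=jnp.float32)
same = lax.convert_element_type(x, jnp.float32)   # dtype and weak_type unchanged
other = lax.convert_element_type(x, jnp.int32)    # dtype differs: primitive runs
print(same is x)   # typically True (no-op fast path)
print(other is x)  # False
```
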
@@ -628,13 +633,11 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
   Returns:
     An array containing the concatenation.
   """
-  from jax.experimental import array
-
   if len(operands) == 0:
     raise ValueError("concatenate requires a non-empty sequences of arrays")
   if len(operands) == 1:
     op, = operands
-    if isinstance(op, (core.Tracer, device_array.DeviceArray, array.Array)):
+    if _is_array_or_tracer(op):
       return op
   return concatenate_p.bind(*operands, dimension=dimension)
 
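
The concatenate fast path works the same way: a single already-committed operand can be returned as-is, while two or more operands always bind concatenate_p. A small hedged illustration (identity of the result may vary with version and config):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)
print(lax.concatenate([x], dimension=0) is x)  # typically True: single-operand fast path
print(lax.concatenate([x, x], dimension=0))    # [0 1 2 0 1 2]: primitive actually runs
```
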
@@ -802,10 +805,7 @@ def broadcast_in_dim(operand: Array, shape: Shape,
   See Also:
     jax.lax.broadcast : simpler interface to add new leading dimensions.
   """
-  from jax.experimental import array
-
-  if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
-      and isinstance(operand, (device_array.DeviceArray, core.Tracer, array.Array))):
+  if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and _is_array_or_tracer(operand):
     return operand
   if config.jax_dynamic_shapes:
     # We must gate this behavior under a flag because otherwise the errors
@@ -860,8 +860,6 @@ def reshape(operand: Array, new_sizes: Shape,
   >>> reshape(y, (6,), (1, 0))
   DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
   """
-  from jax.experimental import array
-
   new_sizes = canonicalize_shape(new_sizes)  # TODO
   new_sizes = tuple(new_sizes)
   same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)
@@ -871,8 +869,7 @@ def reshape(operand: Array, new_sizes: Shape,
   else:
     dims = api_util._ensure_index_tuple(dimensions)
     same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
-  if (np.shape(operand) and same_shape and same_dims
-      and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
+  if np.shape(operand) and same_shape and same_dims and _is_array_or_tracer(operand):
     return operand
   else:
     dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
@@ -951,8 +948,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
     operator.
   """
   permutation = tuple(operator.index(d) for d in permutation)
-  if (permutation == tuple(range(np.ndim(operand)))
-      and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
+  if permutation == tuple(range(np.ndim(operand))) and _is_array_or_tracer(operand):
     return operand
   else:
     return transpose_p.bind(operand, permutation=permutation)
@@ -1282,7 +1278,7 @@ def squeeze(array: Array, dimensions: Sequence[int]) -> Array:
   """Squeeze any number of size 1 dimensions from an array."""
   ndim = np.ndim(array)
   dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions))
-  if not dimensions and isinstance(array, (core.Tracer, device_array.DeviceArray)):
+  if not dimensions and _is_array_or_tracer(array):
     return array
   return squeeze_p.bind(array, dimensions=dimensions)
 
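
The reshape, transpose, and squeeze hunks above all share the same shape of fast path: an identity-shaped call may hand back its argument untouched, while any real change binds the corresponding primitive (broadcast_in_dim behaves likewise when no broadcast dimensions are requested). A hedged check against the public API; the `is` results are typical but depend on JAX version and configuration:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3))
print(lax.reshape(x, (2, 3)) is x)     # same shape, default dimension order
print(lax.transpose(x, (0, 1)) is x)   # identity permutation
print(lax.squeeze(x, ()) is x)         # no dimensions to squeeze
print(lax.transpose(x, (1, 0)).shape)  # (3, 2): the primitive actually runs
```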