@@ -48,6 +48,7 @@ def is_numpy_array(x):
48
48
is_array_api_obj
49
49
is_cupy_array
50
50
is_torch_array
51
+ is_ndonnx_array
51
52
is_dask_array
52
53
is_jax_array
53
54
is_pydata_sparse_array
@@ -78,11 +79,12 @@ def is_cupy_array(x):
78
79
is_array_api_obj
79
80
is_numpy_array
80
81
is_torch_array
82
+ is_ndonnx_array
81
83
is_dask_array
82
84
is_jax_array
83
85
is_pydata_sparse_array
84
86
"""
85
- # Avoid importing NumPy if it isn't already
87
+ # Avoid importing CuPy if it isn't already
86
88
if 'cupy' not in sys .modules :
87
89
return False
88
90
@@ -118,6 +120,33 @@ def is_torch_array(x):
118
120
# TODO: Should we reject ndarray subclasses?
119
121
return isinstance (x , torch .Tensor )
120
122
123
+ def is_ndonnx_array (x ):
124
+ """
125
+ Return True if `x` is a ndonnx Array.
126
+
127
+ This function does not import ndonnx if it has not already been imported
128
+ and is therefore cheap to use.
129
+
130
+ See Also
131
+ --------
132
+
133
+ array_namespace
134
+ is_array_api_obj
135
+ is_numpy_array
136
+ is_cupy_array
137
+ is_ndonnx_array
138
+ is_dask_array
139
+ is_jax_array
140
+ is_pydata_sparse_array
141
+ """
142
+ # Avoid importing torch if it isn't already
143
+ if 'ndonnx' not in sys .modules :
144
+ return False
145
+
146
+ import ndonnx as ndx
147
+
148
+ return isinstance (x , ndx .Array )
149
+
121
150
def is_dask_array (x ):
122
151
"""
123
152
Return True if `x` is a dask.array Array.
@@ -133,6 +162,7 @@ def is_dask_array(x):
133
162
is_numpy_array
134
163
is_cupy_array
135
164
is_torch_array
165
+ is_ndonnx_array
136
166
is_jax_array
137
167
is_pydata_sparse_array
138
168
"""
@@ -160,6 +190,7 @@ def is_jax_array(x):
160
190
is_numpy_array
161
191
is_cupy_array
162
192
is_torch_array
193
+ is_ndonnx_array
163
194
is_dask_array
164
195
is_pydata_sparse_array
165
196
"""
@@ -188,6 +219,7 @@ def is_pydata_sparse_array(x) -> bool:
188
219
is_numpy_array
189
220
is_cupy_array
190
221
is_torch_array
222
+ is_ndonnx_array
191
223
is_dask_array
192
224
is_jax_array
193
225
"""
@@ -211,6 +243,7 @@ def is_array_api_obj(x):
211
243
is_numpy_array
212
244
is_cupy_array
213
245
is_torch_array
246
+ is_ndonnx_array
214
247
is_dask_array
215
248
is_jax_array
216
249
"""
@@ -613,6 +646,7 @@ def size(x):
613
646
"is_jax_array" ,
614
647
"is_numpy_array" ,
615
648
"is_torch_array" ,
649
+ "is_ndonnx_array" ,
616
650
"is_pydata_sparse_array" ,
617
651
"size" ,
618
652
"to_device" ,
0 commit comments