You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/jep/12049-type-annotations.md
+7-9
Original file line number
Diff line number
Diff line change
@@ -207,12 +207,13 @@ An advantage of this approach is that it is that the API is already familiar to
207
207
-`NDArray`, which is a generic array type used only for type annotations
208
208
-`np.ndarray`, which is the array type used for instance checks, and also works for type annotations. In Python 3.9, can use constructs of the form `np.ndarray[shape, dtype]`, but this is currently poorly documented and it's unclear how well supported that is at the moment.
209
209
210
-
JAX currently implements {class}`jax.numpy.ndarray` for use with runtime `isinstance` checks. It uses a metaclass override to ensure that tracers also will pass `isinstance(tracer, jnp.ndarray)`, but it does not currently have any mechanism to ensure that tracers will be valid for `jnp.ndarray` annotations. If we were to follow numpy's approach to type annotation, we could do something like the following:
210
+
JAX currently implements {class}`jax.numpy.ndarray` for use with runtime `isinstance` checks. It uses a metaclass override to ensure that tracers also will pass `isinstance(tracer, jnp.ndarray)`, but it does not currently have any mechanism to ensure that tracers will be valid for `jnp.ndarray` annotations. If we were to follow numpy's approach to type annotation, changes might look something like this:
211
+
211
212
- add `jax.typing.NDArray` for use with type annotations
212
-
- add `TYPE_CHECKING` logic to ensure that `jnp.ndarray` can be used as an annotation for tracers; and possibly add class-level `__getitem__` to match numpy's shape and dtype specification.
213
+
- add `TYPE_CHECKING` logic to ensure that the `jnp.ndarray`object can also be used as an annotation for tracers; and possibly add class-level `__getitem__` to match numpy's shape and dtype specification.
213
214
214
-
A potential point of confusion with this is that JAX arrays are not actually of type `ndarray`, but rather `DeviceArray` or `ShardedDeviceArray` (soon to be unified under a single `jax.Array` type).
215
-
Following numpy's lead could result in confusion: we'd have `jax.Array`, `jax.numpy.ndarray`, and `jax.typing.NDArray`, each of which is useful in a particular subset of cases but not others (type identity, isinstance checks, type annotations, and tracer-compatible versions of all the above).
215
+
A potential point of confusion with this is that JAX arrays are not actually of type `jnp.ndarray`, but rather `DeviceArray` or `ShardedDeviceArray` (soon to be unified under a single `jax.Array` type; see {jax-issue}`#12016`).
216
+
Following numpy's lead could result in confusion: we'd have `jax.Array`, `jax.numpy.ndarray`, and `jax.typing.NDArray`, each of which is useful in a particular subset of cases but not others (type identity, isinstance checks, type annotations, and tracer-compatible versions of the above).
216
217
Despite the familiarity of numpy's API choices, the `ndarray` / `NDArray` / `Array` trichotomy may cause too much confusion
217
218
218
219
#### Choosing our own path: Unification
@@ -222,18 +223,15 @@ Python itself is slowly moving to a world of unifying instance and annotation ty
222
223
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers within type annotations, we could use a construct like the following:
223
224
```python
224
225
ifTYPE_CHECKING:
225
-
Array = Union[jax.Array, jax.Tracer]
226
+
Array = Union[jax._src.array.Array, jax.Tracer]
226
227
else:
227
228
Array = jax._src.array.Array
228
229
```
229
230
230
-
For handling instance checks, we could use the same metaclass override for `jax.Array` that we currently do in the case of `jnp.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, we could support constructions like `jax.Array[(3, 4), int]` via Python 3.9 class-level `__getitem__`, similar to how `list[int]` and `dict[str, int]` work for Python 3.9 built-in types.
231
-
232
-
Another advantage of this unification route (using `jax.Array` as a generic array annotation) is that it fits well with the approach used by the `jaxtyping` library, which uses `jaxtyping.Array[...]` as its core annotation type.
231
+
For handling instance checks, we could use the same metaclass override for `jax.Array` that we currently do in the case of `jnp.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, this would set us up to follow the conventions being developed in the [`jaxtyping`](https://github.com/google/jaxtyping/) project. Because `jaxtyping` uses `jaxtyping.Array[...]` as its core annotation type, unifying under `jax.Array` makes any potential future integration with that project more natural.
233
232
234
233
Given these advantages, it seems like the unified `jax.Array` is the better option compared to splitting annotation and instance logic between `jax.Array`, `jax.typing.NDArray`, and `jnp.ndarray`.
235
234
236
-
237
235
### Implementation Plan
238
236
239
237
To move forward with type annotations, we will do the following:
0 commit comments