Skip to content

Commit dd03e5d

Browse files
committed
Respond to feedback
1 parent 35f9e76 commit dd03e5d

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

docs/jep/12049-type-annotations.md

+7-9
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,13 @@ An advantage of this approach is that it is that the API is already familiar to
207207
- `NDArray`, which is a generic array type used only for type annotations
208208
- `np.ndarray`, which is the array type used for instance checks, and also works for type annotations. In Python 3.9, can use constructs of the form `np.ndarray[shape, dtype]`, but this is currently poorly documented and it's unclear how well supported that is at the moment.
209209

210-
JAX currently implements {class}`jax.numpy.ndarray` for use with runtime `isinstance` checks. It uses a metaclass override to ensure that tracers also will pass `isinstance(tracer, jnp.ndarray)`, but it does not currently have any mechanism to ensure that tracers will be valid for `jnp.ndarray` annotations. If we were to follow numpy's approach to type annotation, we could do something like the following:
210+
JAX currently implements {class}`jax.numpy.ndarray` for use with runtime `isinstance` checks. It uses a metaclass override to ensure that tracers also will pass `isinstance(tracer, jnp.ndarray)`, but it does not currently have any mechanism to ensure that tracers will be valid for `jnp.ndarray` annotations. If we were to follow numpy's approach to type annotation, changes might look something like this:
211+
211212
- add `jax.typing.NDArray` for use with type annotations
212-
- add `TYPE_CHECKING` logic to ensure that `jnp.ndarray` can be used as an annotation for tracers; and possibly add class-level `__getitem__` to match numpy's shape and dtype specification.
213+
- add `TYPE_CHECKING` logic to ensure that the `jnp.ndarray` object can also be used as an annotation for tracers; and possibly add class-level `__getitem__` to match numpy's shape and dtype specification.
213214

214-
A potential point of confusion with this is that JAX arrays are not actually of type `ndarray`, but rather `DeviceArray` or `ShardedDeviceArray` (soon to be unified under a single `jax.Array` type).
215-
Following numpy's lead could result in confusion: we'd have `jax.Array`, `jax.numpy.ndarray`, and `jax.typing.NDArray`, each of which is useful in a particular subset of cases but not others (type identity, isinstance checks, type annotations, and tracer-compatible versions of all the above).
215+
A potential point of confusion with this is that JAX arrays are not actually of type `jnp.ndarray`, but rather `DeviceArray` or `ShardedDeviceArray` (soon to be unified under a single `jax.Array` type; see {jax-issue}`#12016`).
216+
Following numpy's lead could result in confusion: we'd have `jax.Array`, `jax.numpy.ndarray`, and `jax.typing.NDArray`, each of which is useful in a particular subset of cases but not others (type identity, isinstance checks, type annotations, and tracer-compatible versions of the above).
216217
Despite the familiarity of numpy's API choices, the `ndarray` / `NDArray` / `Array` trichotomy may cause too much confusion
217218

218219
#### Choosing our own path: Unification
@@ -222,18 +223,15 @@ Python itself is slowly moving to a world of unifying instance and annotation ty
222223
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers within type annotations, we could use a construct like the following:
223224
```python
224225
if TYPE_CHECKING:
225-
Array = Union[jax.Array, jax.Tracer]
226+
Array = Union[jax._src.array.Array, jax.Tracer]
226227
else:
227228
Array = jax._src.array.Array
228229
```
229230

230-
For handling instance checks, we could use the same metaclass override for `jax.Array` that we currently do in the case of `jnp.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, we could support constructions like `jax.Array[(3, 4), int]` via Python 3.9 class-level `__getitem__`, similar to how `list[int]` and `dict[str, int]` work for Python 3.9 built-in types.
231-
232-
Another advantage of this unification route (using `jax.Array` as a generic array annotation) is that it fits well with the approach used by the `jaxtyping` library, which uses `jaxtyping.Array[...]` as its core annotation type.
231+
For handling instance checks, we could use the same metaclass override for `jax.Array` that we currently do in the case of `jnp.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, this would set us up to follow the conventions being developed in the [`jaxtyping`](https://github.com/google/jaxtyping/) project. Because `jaxtyping` uses `jaxtyping.Array[...]` as its core annotation type, unifying under `jax.Array` makes any potential future integration with that project more natural.
233232

234233
Given these advantages, it seems like the unified `jax.Array` is the better option compared to splitting annotation and instance logic between `jax.Array`, `jax.typing.NDArray`, and `jnp.ndarray`.
235234

236-
237235
### Implementation Plan
238236

239237
To move forward with type annotations, we will do the following:

0 commit comments

Comments
 (0)