You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/jep/12049-type-annotations.md
+35-1
Original file line number
Diff line number
Diff line change
@@ -197,6 +197,40 @@ We will revisit this question in the future once support for such features stabi
197
197
198
198
Similarly, regarding [jaxtyping](http://github.com/google/jaxtyping) project, despite its promise as a project, at this point we consider its goals beyond the scope of the simpler level of type annotations we would like for the jax core project. This is a decision we could revisit at a future date.
199
199
200
+
### Design Considerations
201
+
202
+
#### Following NumPy's Lead
203
+
204
+
One possible model for JAX's type annotations is those provided by the [`numpy.typing`](https://numpy.org/devdocs/reference/typing.html) module.
205
+
An advantage of this approach is that it is that the API is already familiar to many; most notably `numpy.typing` provides:
206
+
207
+
-`NDArray`, which is a generic array type used only for type annotations
208
+
-`np.ndarray`, which is the array type used for instance checks, and also works for type annotations. In Python 3.9, can use constructs of the form `np.ndarray[shape, dtype]`, but this is currently poorly documented and it's unclear how well supported that is at the moment.
209
+
210
+
JAX currently implements {class}`jax.numpy.ndarray` for use with runtime `isinstance` checks. It uses a metaclass override to ensure that tracers also will pass `isinstance(tracer, jnp.ndarray)`, but it does not currently have any mechanism to ensure that tracers will be valid for `jnp.ndarray` annotations. If we were to follow numpy's approach to type annotation, we could do something like the following:
211
+
- add `jax.typing.NDArray` for use with type annotations
212
+
- add `TYPE_CHECKING` logic to ensure that `jnp.ndarray` can be used as an annotation for tracers; and possibly add class-level `__getitem__` to match numpy's shape and dtype specification.
213
+
214
+
A potential point of confusion with this is that JAX arrays are not actually of type `ndarray`, but rather `DeviceArray` or `ShardedDeviceArray` (soon to be unified under a single `jax.Array` type).
215
+
Following numpy's lead could result in confusion: we'd have `jax.Array`, `jax.numpy.ndarray`, and `jax.typing.NDArray`, each of which is useful in a particular subset of cases but not others (type identity, isinstance checks, type annotations, and tracer-compatible versions of all the above).
216
+
Despite the familiarity of numpy's API choices, the `ndarray` / `NDArray` / `Array` trichotomy may cause too much confusion
217
+
218
+
#### Choosing our own path: Unification
219
+
220
+
Python itself is slowly moving to a world of unifying instance and annotation types; for example, with class-level `getitem` support in Python 3.9, classes like `typing.Dict`, `typing.List`, `typing.Tuple`, etc. are now [deprecated](https://peps.python.org/pep-0585/#implementation) in favor of their true-typed counterparts, `dict`, `list`, and `tuple`.
221
+
222
+
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers, tracers could be supported using a construct like the following:
223
+
```python
224
+
ifTYPE_CHECKING:
225
+
Array = Union[jax.Array, jax.Tracer]
226
+
else:
227
+
Array = jax.Array
228
+
```
229
+
For instance checks, we could use the same metaclass override we currently do in the case of `np.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, we could do so via Python 3.9 class-level `__getitem__`, similar to how `list[int]` and `dict[str, int]` work for Python 3.9 built-in types.
230
+
231
+
Another advantage of this unification route (using `jax.Array` as a generic array annotation) is that it fits well with the approach used by the `jaxtyping` library, which uses `jaxtyping.Array[...]` as its core annotation type.
232
+
233
+
200
234
### Implementation Plan
201
235
202
236
To move forward with type annotations, we will do the following:
@@ -205,7 +239,7 @@ To move forward with type annotations, we will do the following:
205
239
206
240
2. Create a `jax.typing` module (a thin wrapper around `jax._src.typing` as is our usual pattern) and put in it the first level of simple types mentioned above:
207
241
208
-
-`NDArray` / `ArrayLike`
242
+
-`Array` / `ArrayLike` (For now we can use `Array = Union[ndarray, Tracer]`, and eventually plan to move to a simpler `Array = jax.Array` with enhancements discussed above)
209
243
-`Dtype` / `DtypeLike` (Check on capitalization of the `t`: what do other projects use?)
210
244
-`Shape` / `NamedShape` / `ShapeLike`
211
245
- Move the `jnp.ndarray` definition into `typing.ndarray`; continue aliasing it at its current location.
0 commit comments