Skip to content

Commit 21f9602

Browse files
committed
add design discussion
1 parent 655ae98 commit 21f9602

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

docs/jep/12049-type-annotations.md

+35-1
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,40 @@ We will revisit this question in the future once support for such features stabi
197197

198198
Similarly, regarding [jaxtyping](http://github.com/google/jaxtyping) project, despite its promise as a project, at this point we consider its goals beyond the scope of the simpler level of type annotations we would like for the jax core project. This is a decision we could revisit at a future date.
199199

200+
### Design Considerations
201+
202+
#### Following NumPy's Lead
203+
204+
One possible model for JAX's type annotations is those provided by the [`numpy.typing`](https://numpy.org/devdocs/reference/typing.html) module.
205+
An advantage of this approach is that it is that the API is already familiar to many; most notably `numpy.typing` provides:
206+
207+
- `NDArray`, which is a generic array type used only for type annotations
208+
- `np.ndarray`, which is the array type used for instance checks, and also works for type annotations. In Python 3.9, can use constructs of the form `np.ndarray[shape, dtype]`, but this is currently poorly documented and it's unclear how well supported that is at the moment.
209+
210+
JAX currently implements {class}`jax.numpy.ndarray` for use with runtime `isinstance` checks. It uses a metaclass override to ensure that tracers also will pass `isinstance(tracer, jnp.ndarray)`, but it does not currently have any mechanism to ensure that tracers will be valid for `jnp.ndarray` annotations. If we were to follow numpy's approach to type annotation, we could do something like the following:
211+
- add `jax.typing.NDArray` for use with type annotations
212+
- add `TYPE_CHECKING` logic to ensure that `jnp.ndarray` can be used as an annotation for tracers; and possibly add class-level `__getitem__` to match numpy's shape and dtype specification.
213+
214+
A potential point of confusion with this is that JAX arrays are not actually of type `ndarray`, but rather `DeviceArray` or `ShardedDeviceArray` (soon to be unified under a single `jax.Array` type).
215+
Following numpy's lead could result in confusion: we'd have `jax.Array`, `jax.numpy.ndarray`, and `jax.typing.NDArray`, each of which is useful in a particular subset of cases but not others (type identity, isinstance checks, type annotations, and tracer-compatible versions of all the above).
216+
Despite the familiarity of numpy's API choices, the `ndarray` / `NDArray` / `Array` trichotomy may cause too much confusion
217+
218+
#### Choosing our own path: Unification
219+
220+
Python itself is slowly moving to a world of unifying instance and annotation types; for example, with class-level `getitem` support in Python 3.9, classes like `typing.Dict`, `typing.List`, `typing.Tuple`, etc. are now [deprecated](https://peps.python.org/pep-0585/#implementation) in favor of their true-typed counterparts, `dict`, `list`, and `tuple`.
221+
222+
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers, tracers could be supported using a construct like the following:
223+
```python
224+
if TYPE_CHECKING:
225+
Array = Union[jax.Array, jax.Tracer]
226+
else:
227+
Array = jax.Array
228+
```
229+
For instance checks, we could use the same metaclass override we currently do in the case of `np.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, we could do so via Python 3.9 class-level `__getitem__`, similar to how `list[int]` and `dict[str, int]` work for Python 3.9 built-in types.
230+
231+
Another advantage of this unification route (using `jax.Array` as a generic array annotation) is that it fits well with the approach used by the `jaxtyping` library, which uses `jaxtyping.Array[...]` as its core annotation type.
232+
233+
200234
### Implementation Plan
201235

202236
To move forward with type annotations, we will do the following:
@@ -205,7 +239,7 @@ To move forward with type annotations, we will do the following:
205239

206240
2. Create a `jax.typing` module (a thin wrapper around `jax._src.typing` as is our usual pattern) and put in it the first level of simple types mentioned above:
207241

208-
- `NDArray` / `ArrayLike`
242+
- `Array` / `ArrayLike` (For now we can use `Array = Union[ndarray, Tracer]`, and eventually plan to move to a simpler `Array = jax.Array` with enhancements discussed above)
209243
- `Dtype` / `DtypeLike` (Check on capitalization of the `t`: what do other projects use?)
210244
- `Shape` / `NamedShape` / `ShapeLike`
211245
- Move the `jnp.ndarray` definition into `typing.ndarray`; continue aliasing it at its current location.

0 commit comments

Comments
 (0)