Skip to content

Commit 35f9e76

Browse files
committed
clarify implementation regarding jax.Array
1 parent 21f9602 commit 35f9e76

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

docs/jep/12049-type-annotations.md

+14-6
Original file line numberDiff line numberDiff line change
@@ -219,33 +219,41 @@ Despite the familiarity of numpy's API choices, the `ndarray` / `NDArray` / `Arr
219219

220220
Python itself is slowly moving to a world of unifying instance and annotation types; for example, with class-level `getitem` support in Python 3.9, classes like `typing.Dict`, `typing.List`, `typing.Tuple`, etc. are now [deprecated](https://peps.python.org/pep-0585/#implementation) in favor of their true-typed counterparts, `dict`, `list`, and `tuple`.
221221

222-
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers, tracers could be supported using a construct like the following:
222+
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers within type annotations, we could use a construct like the following:
223223
```python
224224
if TYPE_CHECKING:
225225
Array = Union[jax.Array, jax.Tracer]
226226
else:
227-
Array = jax.Array
227+
Array = jax._src.array.Array
228228
```
229-
For instance checks, we could use the same metaclass override we currently do in the case of `np.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, we could do so via Python 3.9 class-level `__getitem__`, similar to how `list[int]` and `dict[str, int]` work for Python 3.9 built-in types.
229+
230+
For handling instance checks, we could use the same metaclass override for `jax.Array` that we currently do in the case of `jnp.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, we could support constructions like `jax.Array[(3, 4), int]` via Python 3.9 class-level `__getitem__`, similar to how `list[int]` and `dict[str, int]` work for Python 3.9 built-in types.
230231

231232
Another advantage of this unification route (using `jax.Array` as a generic array annotation) is that it fits well with the approach used by the `jaxtyping` library, which uses `jaxtyping.Array[...]` as its core annotation type.
232233

234+
Given these advantages, it seems like the unified `jax.Array` is the better option compared to splitting annotation and instance logic between `jax.Array`, `jax.typing.NDArray`, and `jnp.ndarray`.
235+
233236

234237
### Implementation Plan
235238

236239
To move forward with type annotations, we will do the following:
237240

238241
1. Iterate on this JEP doc until developers and stakeholders are bought-in.
239242

240-
2. Create a `jax.typing` module (a thin wrapper around `jax._src.typing` as is our usual pattern) and put in it the first level of simple types mentioned above:
243+
2. Create a private `jax._src.typing` (not providing any public APIs for now) and put in it the first level of simple types mentioned above:
241244

242-
- `Array` / `ArrayLike` (For now we can use `Array = Union[ndarray, Tracer]`, and eventually plan to move to a simpler `Array = jax.Array` with enhancements discussed above)
245+
- `Array`: because `jax.Array` is not yet out of experimental, we'll define `jax._src.typing.Array` with the type annotation features that will eventually move to `jax.Array` when it has landed.
246+
- Move the `jnp.ndarray` definition into `typing.ndarray`; continue aliasing it at its current location.
247+
- `ArrayLike`: a Union of types valid as inputs to normal `jax.numpy` functions
243248
- `Dtype` / `DtypeLike` (Check on capitalization of the `t`: what do other projects use?)
244249
- `Shape` / `NamedShape` / `ShapeLike`
245-
- Move the `jnp.ndarray` definition into `typing.ndarray`; continue aliasing it at its current location.
246250

247251
3. Once this is implemented, use these new typing definitions to comprehensively annotate functions within `jax.lax` according to the guidelines above.
248252

249253
4. Continue adding additional annotations one module at a time, focusing on public API functions.
250254

255+
5. Once `jax.Array` has fully landed, migrate the features of `jax._src.typing.Array` to this class, and let `jax.numpy.ndarray = Array`.
256+
257+
6. When all is finalized, create a public `jax.typing` module that makes the above types available to users, along with documentation of annotation best practices for code using JAX.
258+
251259
We will track this work in {jax-issue}`#12049`, which this JEP is named for.

0 commit comments

Comments
 (0)