You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/jep/12049-type-annotations.md
+14-6
Original file line number
Diff line number
Diff line change
@@ -219,33 +219,41 @@ Despite the familiarity of numpy's API choices, the `ndarray` / `NDArray` / `Arr
219
219
220
220
Python itself is slowly moving to a world of unifying instance and annotation types; for example, with class-level `getitem` support in Python 3.9, classes like `typing.Dict`, `typing.List`, `typing.Tuple`, etc. are now [deprecated](https://peps.python.org/pep-0585/#implementation) in favor of their true-typed counterparts, `dict`, `list`, and `tuple`.
221
221
222
-
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers, tracers could be supported using a construct like the following:
222
+
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers within type annotations, we could use a construct like the following:
223
223
```python
224
224
ifTYPE_CHECKING:
225
225
Array = Union[jax.Array, jax.Tracer]
226
226
else:
227
-
Array = jax.Array
227
+
Array = jax._src.array.Array
228
228
```
229
-
For instance checks, we could use the same metaclass override we currently do in the case of `np.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, we could do so via Python 3.9 class-level `__getitem__`, similar to how `list[int]` and `dict[str, int]` work for Python 3.9 built-in types.
229
+
230
+
For handling instance checks, we could use the same metaclass override for `jax.Array` that we currently do in the case of `jnp.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, we could support constructions like `jax.Array[(3, 4), int]` via Python 3.9 class-level `__getitem__`, similar to how `list[int]` and `dict[str, int]` work for Python 3.9 built-in types.
230
231
231
232
Another advantage of this unification route (using `jax.Array` as a generic array annotation) is that it fits well with the approach used by the `jaxtyping` library, which uses `jaxtyping.Array[...]` as its core annotation type.
232
233
234
+
Given these advantages, it seems like the unified `jax.Array` is the better option compared to splitting annotation and instance logic between `jax.Array`, `jax.typing.NDArray`, and `jnp.ndarray`.
235
+
233
236
234
237
### Implementation Plan
235
238
236
239
To move forward with type annotations, we will do the following:
237
240
238
241
1. Iterate on this JEP doc until developers and stakeholders are bought-in.
239
242
240
-
2. Create a `jax.typing` module (a thin wrapper around `jax._src.typing`as is our usual pattern) and put in it the first level of simple types mentioned above:
243
+
2. Create a private `jax._src.typing`(not providing any public APIs for now) and put in it the first level of simple types mentioned above:
241
244
242
-
-`Array` / `ArrayLike` (For now we can use `Array = Union[ndarray, Tracer]`, and eventually plan to move to a simpler `Array = jax.Array` with enhancements discussed above)
245
+
-`Array`: because `jax.Array` is not yet out of experimental, we'll define `jax._src.typing.Array` with the type annotation features that will eventually move to `jax.Array` when it has landed.
246
+
- Move the `jnp.ndarray` definition into `typing.ndarray`; continue aliasing it at its current location.
247
+
-`ArrayLike`: a Union of types valid as inputs to normal `jax.numpy` functions
243
248
-`Dtype` / `DtypeLike` (Check on capitalization of the `t`: what do other projects use?)
244
249
-`Shape` / `NamedShape` / `ShapeLike`
245
-
- Move the `jnp.ndarray` definition into `typing.ndarray`; continue aliasing it at its current location.
246
250
247
251
3. Once this is implemented, use these new typing definitions to comprehensively annotate functions within `jax.lax` according to the guidelines above.
248
252
249
253
4. Continue adding additional annotations one module at a time, focusing on public API functions.
250
254
255
+
5. Once `jax.Array` has fully landed, migrate the features of `jax._src.typing.Array` to this class, and let `jax.numpy.ndarray = Array`.
256
+
257
+
6. When all is finalized, create a public `jax.typing` module that makes the above types available to users, along with documentation of annotation best practices for code using JAX.
258
+
251
259
We will track this work in {jax-issue}`#12049`, which this JEP is named for.
0 commit comments