You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/jep/12049-type-annotations.md
+94-29
Original file line number
Diff line number
Diff line change
@@ -175,7 +175,7 @@ Note that these will in general be simpler than the equivalent protocols used in
175
175
176
176
Conversely, outputs of functions and methods should be typed as strictly as possible: for example, for a JAX function that returns an array, the output should be annotated with `jnp.ndarray` rather than `ArrayLike`. Functions returning a dtype should always be annotated `np.dtype`, and functions returning a shape should always be `Tuple[int]` or a strictly-typed NamedShape equivalent. For this purpose, we will implement in {mod}`jax.typing` several strictly-typed analogs of the permissive types mentioned above, namely:
177
177
178
-
-`NDArray` is effectively equivalent to `Union[Tracer, jnp.ndarray]` and should be used to annotate array outputs.
178
+
-`Array` or `NDArray` (see below) for type annotation purposes is effectively equivalent to `Union[Tracer, jnp.ndarray]` and should be used to annotate array outputs.
179
179
-`DType` is an alias of `np.dtype`, perhaps with the ability to also represent key types and other generalizations used within JAX.
180
180
-`Shape` is essentially `Tuple[int, ...]`, perhaps with some additional flexibilty to account for dynamic shapes.
181
181
-`NamedShape` is an extension of `Shape` that allows for named shapes as used internall in JAX.
@@ -197,40 +197,104 @@ We will revisit this question in the future once support for such features stabi
197
197
198
198
Similarly, regarding [jaxtyping](http://github.com/google/jaxtyping) project, despite its promise as a project, at this point we consider its goals beyond the scope of the simpler level of type annotations we would like for the jax core project. This is a decision we could revisit at a future date.
199
199
200
-
### Design Considerations
200
+
### `Array`/`NDArray`Design Considerations
201
201
202
-
#### Following NumPy's Lead
202
+
Type annotation of arrays in JAX poses a unique challenge because of JAX's extensive use of duck-typing, i.e. passing and returning `Tracer` objects in place actual arrays within jax transformations.
203
+
This becomes increasingly confusing because objects used for type annotation often overlap with objects used for runtime instance checking, and may or may not correspond to the actual type hierarchy of the objects in question.
204
+
For JAX, we need to provide duck-typed objects for use in two contexts: **static type annotations** and **runtime instance checks**.
203
205
204
-
One possible model for JAX's type annotations is those provided by the [`numpy.typing`](https://numpy.org/devdocs/reference/typing.html) module.
205
-
An advantage of this approach is that it is that the API is already familiar to many; most notably `numpy.typing` provides:
206
+
The following discussion will assume that `jax.Array` is the type used to represent on-device arrays, which is not yet the case but will be once the work in {jax-issue}`#12016` is complete.
206
207
207
-
-`NDArray`, which is a generic array type used only for type annotations
208
-
-`np.ndarray`, which is the array type used for instance checks, and also works for type annotations. In Python 3.9, can use constructs of the form `np.ndarray[shape, dtype]`, but this is currently poorly documented and it's unclear how well supported that is at the moment.
208
+
1.**Static type annotations.**
209
+
We need to provide an object that can be used for duck-typed type annotations.
210
+
Assuming for the moment that we call this object `ArrayAnnotation`, we need a solution which satisfies `mypy` and `pytype` for a case like the following:
211
+
```python
212
+
@jit
213
+
deff(x: ArrayAnnotation) -> ArrayAnnotation:
214
+
assertisinstance(x, core.Tracer)
215
+
return x
216
+
```
217
+
This could be accomplished via a number of approaches, for example:
218
+
- Use a type union: `ArrayAnnotation = Union[Array, Tracer]`
219
+
- Create an interface file that declares `Tracer` and `Array` should be treated as subclasses of `ArrayAnnotation`.
220
+
- Restructure `Array` and `Tracer` so that `ArrayAnnotation` is a true base class of both.
221
+
2.**Runtime instance checks.**
222
+
We also must provide an object that can be used for duck-typed runtime `isinstance` checks.
223
+
Assuming for the moment that we call this object `ArrayInstance`, we need a solution that passes the following runtime check:
224
+
```python
225
+
deff(x):
226
+
returnisinstance(x, ArrayInstance)
227
+
x = jnp.array([1, 2, 3])
228
+
assert f(x) # x will be an array
229
+
assert jit(f)(x) # x will be a tracer
230
+
```
231
+
Again, there are a couple mechanisms that could be used for this:
232
+
- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/google/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)).
233
+
- define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer`
234
+
- restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer`
209
235
210
-
JAX currently implements {class}`jax.numpy.ndarray` for use with runtime `isinstance` checks. It uses a metaclass override to ensure that tracers also will pass `isinstance(tracer, jnp.ndarray)`, but it does not currently have any mechanism to ensure that tracers will be valid for `jnp.ndarray` annotations. If we were to follow numpy's approach to type annotation, changes might look something like this:
236
+
A decision we need to make is whether `ArrayAnnotation` and `ArrayInstance` should be the same or different objects. There is some precedent here; for example in the core Python language spec, `typing.Dict` and `typing.List` exist for the sake of annotation, while the built-in `dict` and `list` serve the purposes of instance checks.
237
+
However, `Dict` and `List` are [deprecated](https://peps.python.org/pep-0585/#implementation) in newer Python versions in favor of using `dict` and `list` for both annotation and instance checks.
211
238
212
-
- add `jax.typing.NDArray` for use with type annotations
213
-
- add `TYPE_CHECKING` logic to ensure that the `jnp.ndarray` object can also be used as an annotation for tracers; and possibly add class-level `__getitem__` to match numpy's shape and dtype specification.
239
+
#### Following NumPy's lead
214
240
215
-
A potential point of confusion with this is that JAX arrays are not actually of type `jnp.ndarray`, but rather `DeviceArray` or `ShardedDeviceArray` (soon to be unified under a single `jax.Array` type; see {jax-issue}`#12016`).
216
-
Following numpy's lead could result in confusion: we'd have `jax.Array`, `jax.numpy.ndarray`, and `jax.typing.NDArray`, each of which is useful in a particular subset of cases but not others (type identity, isinstance checks, type annotations, and tracer-compatible versions of the above).
217
-
Despite the familiarity of numpy's API choices, the `ndarray` / `NDArray` / `Array` trichotomy may cause too much confusion
241
+
In NumPy's case, `np.typing.NDArray` serves the purpose of type annotations, while `np.ndarray` serves the purpose of instance checks (as well as array type identity).
242
+
Given this, it may be reasonable to conform to NumPy's precedent and implement the following:
218
243
219
-
#### Choosing our own path: Unification
244
+
-`jax.Array` is the actual type of on-device arrays.
245
+
-`jax.typing.NDArray` is the object used for duck-typed array annotations.
246
+
-`jax.numpy.ndarray` is the object used for duck-typed array instance checks.
220
247
221
-
Python itself is slowly moving to a world of unifying instance and annotation types; for example, with class-level `getitem` support in Python 3.9, classes like `typing.Dict`, `typing.List`, `typing.Tuple`, etc. are now [deprecated](https://peps.python.org/pep-0585/#implementation) in favor of their true-typed counterparts, `dict`, `list`, and `tuple`.
248
+
This might feel somewhat natural to NumPy power-users, however this trifurcation would likely be a source of confusion: the choice of which to use for instance checks and annotations is not immediately clear.
222
249
223
-
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers within type annotations, we could use a construct like the following:
224
-
```python
225
-
ifTYPE_CHECKING:
226
-
Array = Union[jax._src.array.Array, jax.Tracer]
227
-
else:
228
-
Array = jax._src.array.Array
229
-
```
250
+
#### Unification
251
+
252
+
Another approach would be to unify type checking and annotation via override mechanisms mentioned above.
253
+
254
+
##### Option 1: Partial unification
255
+
A partial unification might look like this:
256
+
257
+
-`jax.Array` is the actual type of on-device arrays.
258
+
-`jax.typing.Array` is the object used for duck-typed array annotations (via `.pyi` interfaces on `Array` and `Tracer`).
259
+
-`jax.typing.Array` is also the object used duck-typed instance checks (via an `__isinstance__` override in its metaclass)
260
+
261
+
In this approach, `jax.numpy.ndarray` would become a simple alias `jax.typing.Array` for backward compatibility.
262
+
263
+
264
+
##### Option 2: Full unification via overrides
265
+
Alternatively, we could opt for full unification via overrides:
266
+
267
+
-`jax.Array` is the actual type of on-device arrays.
268
+
-`jax.Array` is also the object used for duck-typed array annotations (via a `.pyi` interface on `Tracer`)
269
+
-`jax.Array` is also the object used for duck-typed instance checks (via an `__isinstance__` override in its metaclass)
230
270
231
-
For handling instance checks, we could use the same metaclass override for `jax.Array` that we currently do in the case of `jnp.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, this would set us up to follow the conventions being developed in the [`jaxtyping`](https://github.com/google/jaxtyping/) project. Because `jaxtyping` uses `jaxtyping.Array[...]` as its core annotation type, unifying under `jax.Array`makes any potential future integration with that project more natural.
271
+
Here, `jax.numpy.ndarray`would become a simple alias `jax.Array`for backward compatibility.
232
272
233
-
Given these advantages, it seems like the unified `jax.Array` is the better option compared to splitting annotation and instance logic between `jax.Array`, `jax.typing.NDArray`, and `jnp.ndarray`.
273
+
##### Option 3: Full unification via class hierarchy
274
+
Finally, we could opt for full unification via restructuring of the class hierarchy and replacing duck-typing with OOP object hierarchies:
275
+
276
+
-`jax.Array` is the actual type of on-device arrays
277
+
-`jax.Array` is also the object used for array annotations, by ensuring that `Tracer` inherits from `jax.Array`
278
+
-`jax.Array` is also the object used for instance checks, via the same mechanism
279
+
280
+
Here `jnp.ndarray` could be an alias for `jax.Array`.
281
+
This final approach is in some senses the most pure, but it may be challenging from an OOP design standpoint (`Tracer`*is a*`Array`?).
282
+
283
+
##### Option 4: Parial unification via class hierarchy
284
+
We could appease OOP pedants by instead making `Tracer` and `Array` derive from a common `ArrayBase` base class:
285
+
286
+
-`jax.Array` is the actual type of on-device arrays
287
+
-`ArrayBase` is the object used for array annotations
288
+
-`ArrayBase` is also the object used for instance checks
289
+
290
+
Here `jnp.ndarray` would be an alias for `ArrayBase`.
291
+
This may be purer from an OOP perspective, but it reintroduces a bifurcation and the distinction between `Array` and `ArrayBase` for annotation and instance checks may become confusing.
292
+
293
+
##### Evaluation
294
+
295
+
There is no perfect option here, but weighing the pros and cons of these solutions, I (@jakevdp) believe that Option 4 presents the best path forward.
296
+
It offers the least confusing API for users (`jax.Array` is the only object you need to worry about), and does not require any significant restructuring of our existing codepaths.
297
+
There is one minor technical hurdle involved; that is that `jax.Array` will be defined in C++ via pybind11, and pybind11 currently [does not support](https://github.com/pybind/pybind11/issues/2696) custom metaclasses required for overriding `__instancecheck__`; nevertheless we should be able to work around this.
234
298
235
299
### Implementation Plan
236
300
@@ -240,17 +304,18 @@ To move forward with type annotations, we will do the following:
240
304
241
305
2. Create a private `jax._src.typing` (not providing any public APIs for now) and put in it the first level of simple types mentioned above:
242
306
243
-
-`Array`: because `jax.Array` is not yet out of experimental, we'll define `jax._src.typing.Array` with the type annotation features that will eventually move to `jax.Array` when it has landed.
244
-
- Move the `jnp.ndarray` definition into `typing.ndarray`; continue aliasing it at its current location.
307
+
-`Array`: because `jax.Array` is not yet out of experimental, we'll define `jax._src.typing.Array` with the type annotation and instance-checking features that will eventually move to `jax.Array` when it has fully landed.
245
308
-`ArrayLike`: a Union of types valid as inputs to normal `jax.numpy` functions
246
309
-`Dtype` / `DtypeLike` (Check on capitalization of the `t`: what do other projects use?)
247
310
-`Shape` / `NamedShape` / `ShapeLike`
248
311
249
-
3. Once this is implemented, use these new typing definitions to comprehensively annotate functions within `jax.lax` according to the guidelines above.
312
+
3. When this is implemented, remove the existing definition of `jnp.ndarray` and set is as an alias of `jax._src.typing.Array`.
313
+
314
+
4. As a test, use these new typing definitions to comprehensively annotate functions within `jax.lax` according to the guidelines above.
250
315
251
-
4. Continue adding additional annotations one module at a time, focusing on public API functions.
316
+
5. Continue adding additional annotations one module at a time, focusing on public API functions.
252
317
253
-
5. Once `jax.Array` has fully landed, migrate the features of `jax._src.typing.Array` to this class, and let `jax.numpy.ndarray = Array`.
318
+
6. Once `jax.Array` has fully landed, migrate the features of `jax._src.typing.Array` to this class, and let `jax.numpy.ndarray = Array`.
254
319
255
320
6. When all is finalized, create a public `jax.typing` module that makes the above types available to users, along with documentation of annotation best practices for code using JAX.
0 commit comments