You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/jep/11859-type-annotations.md
+16-7
Original file line number
Diff line number
Diff line change
@@ -132,20 +132,29 @@ With this framing (Level 1/2/3) and JAX-specific challenges in mind, we can begi
132
132
133
133
For JAX type annotation, we have the following goals:
134
134
135
-
1. We would like to support full, *Level 1, 2, and 3* type annotation. In particular, this means that we should have restrictive type annotations on both inputs and outputs to public API functions.
135
+
1. We would like to support full, *Level 1, 2, and 3* type annotation as far as possible. In particular, this means that we should have restrictive type annotations on both inputs and outputs to public API functions.
136
136
137
-
2. When functions are decorated by jax transformations like `jit`, `vmap`, `grad`, etc. JAX will **strip all annotations**. The reason for this is that without the mechanisms of [PEP 612](https://peps.python.org/pep-0612/) there is no good way to do otherwise, and `ParamSpec` will not be available for use until Python 3.10.
137
+
2. In order to not add undue development friction (due to the internal/external CI differences), we would like to be conservative in the type annotation constructs we use: in particular, when it comes to recently-introduced mechanisms such as `ParamSpec` (PEP [PEP 612](https://peps.python.org/pep-0612/),), we would like to wait until support in mypy and other tools stabilizes before relying on them.
138
+
One impact of this is that for the time being, when functions are decorated by jax transformations like `jit`, `vmap`, `grad`, etc. JAX will **strip all annotations**. This is because `ParamSpec` is still only partially supported; the PEP is slated for Python 3.10 (though it can be used before that via [typing-extensions](https://github.com/python/typing_extensions) and at the time of this writing mypy has a laundry-list of incompatibilities with the `ParamSpec`-based annotations (see [`ParamSpec` mypy bug tracker](https://github.com/python/mypy/issues?q=is%3Aissue+is%3Aopen++label%3Atopic-paramspec+)).
139
+
We will revisit this question in the future once support for such features stabilizes.
138
140
139
-
3. We will design JAX type annotations to annotate the **intent** of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (a simple example is an arbitrary iterator passed in place of a shape). Inputs to JAX functions should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class `np.dtype`, but rather any dtype-convertible object. This might include strings, built-in scalar types, or dtype-adjacent classes such as `np.float64` and `jnp.float64`. In order to make this as uniform as possible across the package, we will add a {mod}`jax.typing` module with common type specifications, starting with broad categories such as:
141
+
3. JAX type annotations shoudl in general indicate the **intent** of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (a simple example is an arbitrary iterator passed in place of a shape in some function implementations).
142
+
143
+
4. Inputs to JAX functions and methods should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class `np.dtype`, but rather any dtype-convertible object. This might include strings, built-in scalar types, or dtype-adjacent classes such as `np.float64` and `jnp.float64`. In order to make this as uniform as possible across the package, we will add a {mod}`jax.typing` module with common type specifications, starting with broad categories such as:
140
144
141
-
-`NDArray`
142
145
-`ArrayLike`
143
146
-`DtypeLike`
144
147
-`ShapeLike`
145
148
- etc.
146
149
147
-
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation.
150
+
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in most places, so the type definition will be simpler than the numpy analog.
151
+
152
+
5. Conversely, outputs of functions and methods should be typed as strictly as possible: for example, for a JAX function that returns an array, the output should be annotated with `jnp.ndarray` rather than `ArrayLike`. Functions returning a dtype should always be annotated `np.dtype`, and functions returning a shape should always be `Tuple[int]` or a strictly-typed NamedShape equivalent. For this purpose, we will implement in {mod}`jax.typing` several strictly-typed analogs of the permissive types mentioned above, namely:
148
153
149
-
4. Function outputs should be typed as strictly as possible: for example, for a function that returns an array, the output should be annotated with `jnp.ndarray` rather than `ArrayLike`. Functions returning a dtype should always be annotated `np.dtype`, and functions returning a shape should always be `Tuple[int]` or a strictly-typed NamedShape equivalent.
154
+
-`NDArray` (perhaps this could be equivalent to `jnp.ndarray`?)
155
+
-`DType` (perhaps this could be simply `np.dtype`?)
156
+
-`Shape`
157
+
-`NamedShape`
158
+
- etc.
150
159
151
-
5. Aside from common typing protocols gathered in `jax.typing`, we should avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as `Union[simple_type, Any]` in the case that the full type specification cannot be simply specified. This is a comprimise that achieces the goals of Level 1 and 2 annotations, while punting on Level 3 for complicated APIs.
160
+
6. Aside from common typing protocols gathered in `jax.typing`, we should err on the side of simplicity, and avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as `Union[simple_type, Any]` in the case that the full type specification of the API cannot be succinctly specified. This is a comprimise that achieves the goals of Level 1 and 2 annotations, while punting on Level 3 in favor of simplicity.
0 commit comments