Skip to content

Commit c4ae178

Browse files
committed
Reword some things
1 parent a118c6f commit c4ae178

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

docs/jep/11859-type-annotations.md

+16-7
Original file line numberDiff line numberDiff line change
@@ -132,20 +132,29 @@ With this framing (Level 1/2/3) and JAX-specific challenges in mind, we can begi
132132

133133
For JAX type annotation, we have the following goals:
134134

135-
1. We would like to support full, *Level 1, 2, and 3* type annotation. In particular, this means that we should have restrictive type annotations on both inputs and outputs to public API functions.
135+
1. We would like to support full, *Level 1, 2, and 3* type annotation as far as possible. In particular, this means that we should have restrictive type annotations on both inputs and outputs to public API functions.
136136

137-
2. When functions are decorated by jax transformations like `jit`, `vmap`, `grad`, etc. JAX will **strip all annotations**. The reason for this is that without the mechanisms of [PEP 612](https://peps.python.org/pep-0612/) there is no good way to do otherwise, and `ParamSpec` will not be available for use until Python 3.10.
137+
2. In order to not add undue development friction (due to the internal/external CI differences), we would like to be conservative in the type annotation constructs we use: in particular, when it comes to recently-introduced mechanisms such as `ParamSpec` (PEP [PEP 612](https://peps.python.org/pep-0612/),), we would like to wait until support in mypy and other tools stabilizes before relying on them.
138+
One impact of this is that for the time being, when functions are decorated by jax transformations like `jit`, `vmap`, `grad`, etc. JAX will **strip all annotations**. This is because `ParamSpec` is still only partially supported; the PEP is slated for Python 3.10 (though it can be used before that via [typing-extensions](https://github.com/python/typing_extensions) and at the time of this writing mypy has a laundry-list of incompatibilities with the `ParamSpec`-based annotations (see [`ParamSpec` mypy bug tracker](https://github.com/python/mypy/issues?q=is%3Aissue+is%3Aopen++label%3Atopic-paramspec+)).
139+
We will revisit this question in the future once support for such features stabilizes.
138140

139-
3. We will design JAX type annotations to annotate the **intent** of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (a simple example is an arbitrary iterator passed in place of a shape). Inputs to JAX functions should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class `np.dtype`, but rather any dtype-convertible object. This might include strings, built-in scalar types, or dtype-adjacent classes such as `np.float64` and `jnp.float64`. In order to make this as uniform as possible across the package, we will add a {mod}`jax.typing` module with common type specifications, starting with broad categories such as:
141+
3. JAX type annotations shoudl in general indicate the **intent** of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (a simple example is an arbitrary iterator passed in place of a shape in some function implementations).
142+
143+
4. Inputs to JAX functions and methods should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class `np.dtype`, but rather any dtype-convertible object. This might include strings, built-in scalar types, or dtype-adjacent classes such as `np.float64` and `jnp.float64`. In order to make this as uniform as possible across the package, we will add a {mod}`jax.typing` module with common type specifications, starting with broad categories such as:
140144

141-
- `NDArray`
142145
- `ArrayLike`
143146
- `DtypeLike`
144147
- `ShapeLike`
145148
- etc.
146149

147-
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation.
150+
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in most places, so the type definition will be simpler than the numpy analog.
151+
152+
5. Conversely, outputs of functions and methods should be typed as strictly as possible: for example, for a JAX function that returns an array, the output should be annotated with `jnp.ndarray` rather than `ArrayLike`. Functions returning a dtype should always be annotated `np.dtype`, and functions returning a shape should always be `Tuple[int]` or a strictly-typed NamedShape equivalent. For this purpose, we will implement in {mod}`jax.typing` several strictly-typed analogs of the permissive types mentioned above, namely:
148153

149-
4. Function outputs should be typed as strictly as possible: for example, for a function that returns an array, the output should be annotated with `jnp.ndarray` rather than `ArrayLike`. Functions returning a dtype should always be annotated `np.dtype`, and functions returning a shape should always be `Tuple[int]` or a strictly-typed NamedShape equivalent.
154+
- `NDArray` (perhaps this could be equivalent to `jnp.ndarray`?)
155+
- `DType` (perhaps this could be simply `np.dtype`?)
156+
- `Shape`
157+
- `NamedShape`
158+
- etc.
150159

151-
5. Aside from common typing protocols gathered in `jax.typing`, we should avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as `Union[simple_type, Any]` in the case that the full type specification cannot be simply specified. This is a comprimise that achieces the goals of Level 1 and 2 annotations, while punting on Level 3 for complicated APIs.
160+
6. Aside from common typing protocols gathered in `jax.typing`, we should err on the side of simplicity, and avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as `Union[simple_type, Any]` in the case that the full type specification of the API cannot be succinctly specified. This is a comprimise that achieves the goals of Level 1 and 2 annotations, while punting on Level 3 in favor of simplicity.

0 commit comments

Comments
 (0)