Add types to JAX source code #1555


Closed
gnecula opened this issue Oct 22, 2019 · 14 comments
Labels
documentation enhancement New feature or request

Comments

@gnecula
Collaborator

gnecula commented Oct 22, 2019

I think that adding type hints to the JAX source code would improve code readability and perhaps also protect against simple errors. The checks should also be added to the Travis builds, and should be easy to run locally.

I mentioned this to several people, and in a couple of cases I heard (e.g., from @hawkinsp) that pytype can be slow. I have not had issues with pytype before; perhaps it depends on the actual code base. Maybe mypy is better?

@gnecula gnecula added enhancement New feature or request documentation labels Oct 22, 2019
@shoyer
Collaborator

shoyer commented Oct 22, 2019

I have not tried to run pytype outside of Google, but mypy is pretty fast for moderately sized projects like JAX.

Unfortunately NumPy itself doesn't really have type annotations yet, so the value of these is somewhat limited.

Also, it would be great to be able to type check errors like the usage of lax.scan in #1534, but unfortunately I'm pretty sure that can't be checked currently. We would need a Python typing equivalent of a pytree tree definition, which I'm not sure is possible.

@gnecula
Collaborator Author

gnecula commented Oct 22, 2019

lax.scan is beyond what a simple type system could do. I think that even without NumPy type annotations there would be a benefit.

@jekbradbury
Contributor

My understanding is that pytype and mypy serve slightly different purposes, in that mypy is intended to help construct and check a "static Python" codebase, where ~everything is concretely statically typed, and annotations are relatively pervasive, while pytype is intended more for Python codebases "as they are" and will happily infer unions or Any. If that's in fact the case, then I think JAX would benefit more from pytype (but I think even pytype throws a few false positives on our codebase; we'd want a couple type annotations to silence those).

@hawkinsp
Collaborator

hawkinsp commented Jun 5, 2020

We are systematically adding more type annotations to JAX over time. Closing this issue because it isn't actionable; we can open specific issues or PRs if there are particular type annotations we want to add.

@hawkinsp hawkinsp closed this as completed Jun 5, 2020
@NeilGirdhar
Contributor

NeilGirdhar commented Jun 5, 2020

Glad to hear JAX will eventually have annotations! I was following this issue for updates. I actually started trying to add annotations myself before I realized how much work it was going to be.

One thing that would be nice to expose as soon as possible, though, is a set of base types for tensors and pytrees:

from typing import Dict, Hashable, List, Tuple, Union

import jax.numpy as jnp
import numpy as np

Tensor = Union[np.ndarray, jnp.ndarray]  # probably needs more things in the union like tracers and DeviceArray?
PyTree = Union[Tensor,
               'PyTreeLike',
               Tuple['PyTree', ...],
               List['PyTree'],
               Dict[Hashable, 'PyTree'],
               None]

@shoyer
Collaborator

shoyer commented Jun 5, 2020

There's recently been a flurry of activity in numpy-stubs. The plan is to move the annotations into NumPy proper soon. It's a little challenging due to some legacy choices in NumPy's API (e.g., the existence of both 0d arrays and scalars) but hopefully we can do better in JAX :).

Type checking pytrees is likely beyond the capabilities of Python type checkers, at least without a custom plugin of some sort.

@NeilGirdhar
Contributor

Yeah, I just saw that! I'm really excited about numpy annotations.

Regarding pytrees, I like to annotate them in my code so that I can keep track of what is a PyTree and what's a tensor.

What's the problem that you see with type-checking pytrees though, I'm curious?

@shoyer
Collaborator

shoyer commented Jun 5, 2020

I just don't think either mypy or pytype is currently capable of type checking user-registered pytrees. The typing module has Protocol, but we don't require new pytrees to implement particular methods.

@NeilGirdhar
Contributor

Right, that makes sense. You could just add an abstract base class like #2916, and then users who want to register their own types can use inheritance to do the registration, respond to isinstance checks, and allow type annotation.

@shoyer
Collaborator

shoyer commented Jun 5, 2020

No need for a base class -- if we are happy requiring users to make use of register_pytree_node_class, then we could just define a typing protocol based on the presence of tree_flatten/tree_unflatten methods.
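
As a rough illustration of the protocol idea in this comment, here is a minimal sketch using only the standard library. The names PyTreeNode, Point, and roundtrip are hypothetical and are not part of JAX; the method signatures assume the (children, aux_data) convention used with register_pytree_node_class. Note that the isinstance check only confirms that the two methods are present, not that the class has actually been registered with JAX.

```python
from typing import Any, Iterable, Protocol, Tuple, runtime_checkable


@runtime_checkable
class PyTreeNode(Protocol):
    """Hypothetical protocol: any object exposing the two pytree methods."""

    def tree_flatten(self) -> Tuple[Iterable[Any], Any]:
        ...

    @classmethod
    def tree_unflatten(cls, aux_data: Any, children: Iterable[Any]) -> "PyTreeNode":
        ...


class Point:
    """Example container; it matches the protocol structurally, with no base class."""

    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y

    def tree_flatten(self) -> Tuple[Iterable[Any], Any]:
        # Returns (children, auxiliary data).
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data: Any, children: Iterable[Any]) -> "Point":
        return cls(*children)


def roundtrip(node: PyTreeNode) -> PyTreeNode:
    """Flatten a node and rebuild it from its children and auxiliary data."""
    children, aux_data = node.tree_flatten()
    return type(node).tree_unflatten(aux_data, children)
```

A static checker would accept roundtrip(Point(1.0, 2.0)) because Point has both methods, even though Point never inherits from PyTreeNode.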

@NeilGirdhar
Contributor

NeilGirdhar commented Jun 5, 2020

Well, I don't want to force you to pick one solution or another, but it seems to me that when the user is trying to declare that SomeType is-a PyTree, the most logical way to do that is through inheritance.

Also, the problem with using typing.Protocol is that you can't do runtime checks, unless you were also planning on adding a subtype of ABCMeta and registering there, or adding some alternative way of checking registration (seems ugly to me). The abstract base class seems much simpler than the combination of a typing.Protocol and an ABCMeta declaration.

Finally, the decorator has the very weird effect that if B was decorated with register_pytree_node_class, and A inherits from B, then paradoxically, A is not a pytree! It will obey your protocol, but JAX won't understand it. In other words, the transitivity of is-instance is violated by the decorator approach. The ABC solution doesn't have this problem.

I guess I'm not really sure why you're opposed to inheritance. It seems so convoluted to avoid it, and I can't see why. I understand that when register_pytree_node_class was added, there was no guarantee that you had __init_subclass__ available, but that's not true anymore.
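
The difference between the two registration styles described in this comment can be sketched with a toy registry. Everything here is hypothetical (the _REGISTRY dict, register_node_class, is_registered, NodeBase, and the example classes); it is not JAX's implementation and only assumes, as the comment states, that lookup is by exact class.

```python
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T")

# Toy stand-in for a pytree registry, keyed by exact class. Hypothetical; not JAX's code.
_REGISTRY: Dict[type, str] = {}


def register_node_class(cls: Type[T]) -> Type[T]:
    """Decorator style: registers only the class it is applied to."""
    _REGISTRY[cls] = cls.__name__
    return cls


def is_registered(obj: Any) -> bool:
    # Exact-type lookup: subclasses of a registered class are not found.
    return type(obj) in _REGISTRY


@register_node_class
class B:
    pass


class A(B):  # inherits from a decorated class, but is never registered itself
    pass


class NodeBase:
    """Base-class style: every subclass registers itself when it is defined."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        _REGISTRY[cls] = cls.__name__


class D(NodeBase):
    pass


class C(D):  # registered automatically through __init_subclass__
    pass
```

In this sketch, is_registered(A()) is False even though isinstance(A(), B) is True, while both D and C end up registered because __init_subclass__ runs for every descendant of NodeBase.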

@shoyer
Collaborator

shoyer commented Jun 5, 2020

Let's discuss type checking for pytrees in a new issue: #3340

@tetterl

tetterl commented Feb 1, 2021

Running mypy on:

import jax.numpy as jnp


def g(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    return (x + 2.0 * y).flatten()


b = g(jnp.array([1, 2]), jnp.array([3]))

gives me

mypy src/jax_types.py
src/jax_types.py:5: error: Returning Any from function declared to return "ndarray"
src/jax_types.py:5: error: "float" has no attribute "flatten"
Found 2 errors in 1 file (checked 1 source file)

Is this expected behaviour or is there a problem with the wrapping/type annotations in JAX?

@jakevdp
Collaborator

jakevdp commented Feb 1, 2021

Currently the flatten() method is untyped, so it's expected that mypy interprets the return type as Any.

As for why this is still untyped: it's complicated. Type annotation for arrays is not trivial (see #943 for some discussion); for this reason we're still using typing.Any as a stand-in for the array type in many cases; e.g. here:
https://github.com/google/jax/blob/26afe307018b04dac9857985702fa8c4a4d010ee/jax/_src/lax/lax.py#L67-L69

To further complicate things, mypy does not have good support for flexible decorators (see, e.g., #5556), which JAX uses liberally to define things like the flatten method.

Those things put together make it difficult to specify appropriate types for JAX's full API, though we're slowly working toward that where possible - search JAX's issue tracker for "pytype" or "mypy" to find various related issues.
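
As a generic illustration of how a decorator can discard or keep type information, the following sketch contrasts an unannotated decorator with one that declares it returns the same callable type it received. This is not how JAX defines flatten; the names untyped_wrap, typed_wrap, scale_untyped, and scale_typed are hypothetical, and the pattern shown is only the common TypeVar-plus-cast idiom.

```python
import functools
from typing import Any, Callable, TypeVar, cast

F = TypeVar("F", bound=Callable[..., Any])


def untyped_wrap(fn):
    # No annotations: a checker such as mypy treats the decorated function as untyped.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def typed_wrap(fn: F) -> F:
    # Declares that the decorator returns the same callable type it received.
    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)

    return cast(F, wrapper)


@untyped_wrap
def scale_untyped(x: float) -> float:
    return 2.0 * x


@typed_wrap
def scale_typed(x: float) -> float:
    return 2.0 * x
```

Both functions behave identically at runtime; the difference is only in what a static checker can infer about their signatures. Decorators that change the wrapped signature need more than this idiom, which is part of why the typing is hard in practice.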


7 participants