Add types to JAX source code #1555
Comments
I have not tried to run pytype outside of Google, but mypy is pretty fast for moderately sized projects like JAX. Unfortunately NumPy itself doesn't really have type annotations yet, so the value of these is somewhat limited. Also, it would be great to be able to type check errors like the usage of |
lax.scan is beyond what a simple type system could do. I think that even without numpy type annotations there would be benefit |
My understanding is that pytype and mypy serve slightly different purposes, in that mypy is intended to help construct and check a "static Python" codebase, where ~everything is concretely statically typed, and annotations are relatively pervasive, while pytype is intended more for Python codebases "as they are" and will happily infer unions or Any. If that's in fact the case, then I think JAX would benefit more from pytype (but I think even pytype throws a few false positives on our codebase; we'd want a couple type annotations to silence those). |
We are systematically adding more type annotations to JAX over time. Closing this issue because it isn't actionable; we can open specific issues or PRs if there are particular type annotations we want to add. |
Glad to hear JAX will eventually have annotations! I was following this issue for updates. I actually started trying to add annotations myself before I realized how much work it was going to be. One thing it would be nice to expose as soon as possible, though, is a set of base types for tensors and pytrees:
from typing import Dict, Hashable, List, Tuple, Union
import numpy as np
import jax.numpy as jnp

Tensor = Union[np.ndarray, jnp.ndarray]  # probably needs more things in the union like tracers and DeviceArray?
PyTree = Union[Tensor,
               'PyTreeLike',
               Tuple['PyTree', ...],
               List['PyTree'],
               Dict[Hashable, 'PyTree'],
               None]
There's recently been a flurry of activity in numpy-stubs. The plan is to move the annotations into NumPy proper soon. It's a little challenging due to some legacy choices in NumPy's API (e.g., the existence of both 0d arrays and scalars) but hopefully we can do better in JAX :). Type checking pytrees is likely beyond the capabilities of Python type checkers, at least without a custom plugin of some sort. |
Yeah, I just saw that! I'm really excited about numpy annotations. Regarding pytrees, I like to annotate them in my code so that I can keep track of what is a PyTree and what's a tensor. What's the problem that you see with type-checking pytrees though, I'm curious? |
I just don't think either mypy or pytype is currently capable of type checking user-registered pytrees. typing has Protocol, but we don't require implementing particular methods for new pytrees. |
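To make the Protocol point concrete, here is a minimal sketch of what structural typing would look like if pytree classes were required to implement a particular method. `SupportsFlatten`, `Point`, and `leaves` are hypothetical names made up for illustration; JAX does not define or require any of them.

```python
from typing import Any, Iterable, List, Protocol, Tuple, runtime_checkable


@runtime_checkable
class SupportsFlatten(Protocol):
    """Hypothetical protocol: anything with a tree_flatten() method."""

    def tree_flatten(self) -> Tuple[Iterable[Any], Any]:
        ...


class Point:
    """Example user type; it does not inherit from SupportsFlatten."""

    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y

    def tree_flatten(self) -> Tuple[Iterable[Any], Any]:
        # Children first, then auxiliary (static) data.
        return (self.x, self.y), None


def leaves(tree: SupportsFlatten) -> List[Any]:
    # A checker accepts Point here because it matches structurally.
    children, _aux = tree.tree_flatten()
    return list(children)
```

This only works when the flattening logic lives in a method on the class. Types registered through separate flatten and unflatten functions carry no such method, so a checker has nothing to match against.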
Right, that makes sense. You could just add an abstract base class like #2916, and then users who want to register their own types can use inheritance to do the registration, respond to isinstance checks, and allow type annotation. |
No need for a base class -- if we are happy requiring users to make use of |
Well, I don't want to force you to pick one solution or another, but it seems to me that when the user is trying to declare that Also, the problem with using Finally, the decorator has the very weird effect that if I guess I'm not really sure why you're opposed to inheritance. It seems so convoluted to avoid it, and I can't see why. I understand that when |
Let's discuss type checking for pytrees in a new issue: #3340 |
Running mypy on:
gives me
Is this expected behaviour or is there a problem with the wrapping/type annotations in JAX? |
Currently the As for why this is still untyped: it's complicated. Type annotation for arrays is not trivial (see #943 for some discussion); for this reason we're still using To further complicate things, mypy does not have good support for flexible decorators (see e.g. #5556), which JAX uses liberally to define things like the Those things put together make it difficult to specify appropriate types for JAX's full API, though we're slowly working toward that where possible - search JAX's issue tracker for |
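For context on the decorator problem, here is a minimal sketch of the common workaround for decorators that keep the wrapped function's signature. `traced` is a hypothetical stand-in, not a JAX function, and the approach does not cover transformations that change the signature or return type.

```python
import functools
from typing import Any, Callable, TypeVar, cast

# F stands for "whatever callable type was passed in".
F = TypeVar("F", bound=Callable[..., Any])


def traced(fun: F) -> F:
    """Hypothetical decorator that counts calls and keeps fun's type."""

    @functools.wraps(fun)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        wrapper.calls += 1  # type: ignore[attr-defined]
        return fun(*args, **kwargs)

    wrapper.calls = 0  # type: ignore[attr-defined]
    # cast() tells the checker the wrapper has the same type as fun.
    return cast(F, wrapper)


@traced
def add(x: float, y: float) -> float:
    return x + y
```

Without the `F -> F` annotation, a checker would typically see `add` as an untyped callable after decoration, and calls such as `add("a", "b")` would go unreported.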
I think that adding type hints to the JAX source code would improve code readability and might also protect against simple errors. The checks should also be added to the Travis builds, but should be easy to run locally.
I mentioned this to several people, and in a couple of cases I heard (e.g., @hawkinsp) that pytype can be slow. I have not had issues with pytype before, perhaps it depends on the actual code base. Maybe mypy is better?