Skip to content

jax._src.typing: add basic types and use them in lax.py #12300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 13, 2022

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Sep 8, 2022

Part of #12049; this implements some of the simpler types discussed in #11859

The question of what to do with Array is yet to be resolved, but this PR creates the skeleton into which we can insert the eventual solution.

For now, I'm keeping these as private APIs, in jax._src.typing. Once we're certain we're happy with them, we can plan to export them in jax.typing.

@jakevdp jakevdp force-pushed the typing-simple branch 2 times, most recently from b0c712d to 35a3e3e Compare September 8, 2022 23:43
@jakevdp jakevdp requested a review from froystig September 9, 2022 00:09
@jakevdp jakevdp self-assigned this Sep 9, 2022
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Sep 9, 2022
@jakevdp jakevdp changed the title jax._src.typing: add the easy types and use them in lax.py jax._src.typing: add basic types and use them in lax.py Sep 9, 2022
Copy link
Member

@froystig froystig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to add unit tests? For example, can/should we test that various types that we expect to check out as HasDtypeAttribute indeed do?

Otherwise, I think we still need to figure out how opaque-dtype arrays (like key arrays, soon BInt arrays too) fit in to the picture – see inline comment.

Array = Any

# ArrayLike is a Union of all objects that can be implicitly converted to a JAX array.
ArrayLike = Union[
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does jax.random.KeyArray go in here too? If so, then several of the lax.py annotations in this PR are unsound, e.g. you cannot lax.sin a key array. But if not, then several of the lax.py annotations are incomplete, e.g. you can lax.expand_dims a key array.

Copy link
Collaborator Author

@jakevdp jakevdp Sep 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we expect KeyArray objects to eventually be valid in all lax routines? If not, then I think we should use Union annotations on a case-by-case basis for the functions that support KeyArray.

Array here is meant to be a lowest-common-denominator type for all JAX public APIs that accept array-like objects – it's always possible to add KeyArray in a Union where relevant, but it's not easy to take it away.

Copy link
Member

@froystig froystig Sep 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I wrote above is both the current and eventual intended situation: KeyArray is valid for some lax (and lax-numpy) operations, but not others. One cannot sin or + key arrays, but it is possible to expand_dims, squeeze, etc. A union sounds good, at least for correcting the annotations in this PR. Such a union will likely expand at some point to also account for BInt arrays.

Array here is meant to be a lowest-common-denominator type for all JAX public APIs that accept array-like objects

Even if a union resolves this in code, it seems that this statement could be refined in light of the above (which maybe matters if it appears in documentation). When you say "[...] APIs that accept array-like objects," it seems that you do not also mean key arrays (and later BInt arrays), but a reader might not realize this caveat.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the comments to add more clarity on this question - PTAL!

@jakevdp jakevdp force-pushed the typing-simple branch 2 times, most recently from 253890e to 6f7a1d5 Compare September 12, 2022 16:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants