-
Notifications
You must be signed in to change notification settings - Fork 3k
jax._src.typing: add basic types and use them in lax.py #12300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
b0c712d
to
35a3e3e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it make sense to add unit tests? For example, can/should we test that various types that we expect to check out as HasDtypeAttribute
indeed do?
Otherwise, I think we still need to figure out how opaque-dtype arrays (like key arrays, soon BInt
arrays too) fit in to the picture – see inline comment.
Array = Any | ||
|
||
# ArrayLike is a Union of all objects that can be implicitly converted to a JAX array. | ||
ArrayLike = Union[ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does jax.random.KeyArray
go in here too? If so, then several of the lax.py
annotations in this PR are unsound, e.g. you cannot lax.sin
a key array. But if not, then several of the lax.py
annotations are incomplete, e.g. you can lax.expand_dims
a key array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we expect KeyArray
objects to eventually be valid in all lax routines? If not, then I think we should use Union
annotations on a case-by-case basis for the functions that support KeyArray
.
Array
here is meant to be a lowest-common-denominator type for all JAX public APIs that accept array-like objects – it's always possible to add KeyArray
in a Union
where relevant, but it's not easy to take it away.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I wrote above is both the current and eventual intended situation: KeyArray
is valid for some lax (and lax-numpy) operations, but not others. One cannot sin
or +
key arrays, but it is possible to expand_dims
, squeeze
, etc. A union sounds good, at least for correcting the annotations in this PR. Such a union will likely expand at some point to also account for BInt
arrays.
Array
here is meant to be a lowest-common-denominator type for all JAX public APIs that accept array-like objects
Even if a union resolves this in code, it seems that this statement could be refined in light of the above (which maybe matters if it appears in documentation). When you say "[...] APIs that accept array-like objects," it seems that you do not also mean key arrays (and later BInt
arrays), but a reader might not realize this caveat.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated the comments to add more clarity on this question - PTAL!
253890e
to
6f7a1d5
Compare
6f7a1d5
to
cc72a20
Compare
Part of #12049; this implements some of the simpler types discussed in #11859
The question of what to do with
Array
is yet to be resolved, but this PR creates the skeleton into which we can insert the eventual solution.For now, I'm keeping these as private APIs, in
jax._src.typing
. Once we're certain we're happy with them, we can plan to export them injax.typing
.