bug: can't type flax.struct.dataclass with vmapped functions #177
That's odd: I've just tried running your code (with the same versions of each library) and don't see the same issue. Could you double-check in a fresh environment?
I see, my minimal reproduction was ambiguous, sorry about that. I figured out that it depends on the order of the `@jaxtyped` decorator. This fails:

```python
import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped


@jaxtyped(typechecker=beartype.beartype)
@flax.struct.dataclass
class Data:
    a: Array


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(10, dtype=int))
jax.vmap(f)(data)
```
This doesn't fail:

```python
import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped


@flax.struct.dataclass
@jaxtyped(typechecker=beartype.beartype)
class Data:
    a: Array


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(10, dtype=int))
jax.vmap(f)(data)
```

I think the error occurred for me because I used the pytest hook, which according to the docs adds the `jaxtyped` decorator on top. Tell me if you can reproduce this 😄 (I have
Ah, thank you! It looks like this is a bug in Flax itself. Here's an MWE that doesn't use jaxtyping:

```python
import flax
import jax.tree_util as jtu


@flax.struct.dataclass
class A:
    x: int

    def __init__(self):
        pass


leaves, treedef = jtu.tree_flatten(A())
jtu.tree_unflatten(treedef, leaves)
```

It looks like the reason for this is that their tree-unflattening rule is using the

Unsurprisingly, I'd recommend using Equinox instead :)
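The JAX documentation on custom pytrees warns that transformations such as `jax.vmap` may rebuild a pytree from values that are not the original arrays (tracers, placeholder objects), so any validation performed in `__init__` can see unexpected inputs. Assuming an unflattening rule that goes through the class constructor, a type-checked `__init__` would be exposed to exactly those values. The pure-JAX sketch below illustrates the effect; `ViaInit`, `Bypass` and `_unflatten_bypass` are hypothetical names, and skipping `__init__` in the unflatten rule is one of the options described in the JAX pytree docs, not something Flax is confirmed to do.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

# Values passed to each constructor, recorded for inspection.
calls_via_init = []
calls_bypass = []


class ViaInit:
    """Hypothetical pytree whose unflatten rule calls __init__."""

    def __init__(self, a):
        calls_via_init.append(a)  # stand-in for a runtime type check
        self.a = a


class Bypass:
    """Same class, but its unflatten rule skips __init__."""

    def __init__(self, a):
        calls_bypass.append(a)  # stand-in for a runtime type check
        self.a = a


def _unflatten_bypass(_aux, leaves):
    obj = object.__new__(Bypass)  # allocate without running __init__
    obj.a = leaves[0]
    return obj


register_pytree_node(ViaInit, lambda t: ((t.a,), None), lambda _aux, leaves: ViaInit(*leaves))
register_pytree_node(Bypass, lambda t: ((t.a,), None), _unflatten_bypass)

x = jnp.ones(10)
via_init, bypass = ViaInit(x), Bypass(x)
calls_via_init.clear()  # ignore the two user-level constructions above
calls_bypass.clear()

jax.vmap(lambda t: t.a.sum())(via_init)
jax.vmap(lambda t: t.a.sum())(bypass)

# __init__ ran again inside vmap, with values other than the original array,
# so a type check placed there would see unexpected inputs.
print([type(v).__name__ for v in calls_via_init])
# The bypassing variant never re-entered __init__.
print(len(calls_bypass))  # 0
```

The trade-off of bypassing `__init__` is that validation then only runs when user code constructs the object directly, not when JAX rebuilds it.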
Hey, runtime type-checking seems to fail when passing a Flax dataclass to a vmapped function. I wasn't able to find any related resources. Here is a minimal reproduction with the associated error.
It raises the following error (with beartype):
Here are the versions I'm using:
I tested it: it works with chex.dataclass and equinox.Module, but in my case I have no choice but to use Flax dataclasses. I'd love to find a workaround. Thanks!!