
bug: can't type flax.struct.dataclass with vmapped functions #177

Open
Egiob opened this issue Feb 18, 2024 · 3 comments
Labels: question (User queries)


Egiob commented Feb 18, 2024

Hey, runtime type-checking seems to fail when a Flax dataclass is passed to a vmapped function. I wasn't able to find any related resources. Here is a minimal reproduction with the associated error.

```python
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array


@flax.struct.dataclass
class Data:
    a: Array


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(1, dtype=int))

jax.vmap(f)(data)
```

It raises the following error (with beartype as the typechecker):

```
E jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of Data.
E The problem arose whilst typechecking argument 'a'.
E Called with arguments: {'self': Data(...), 'a': <object object at 0x7fc7c87e8fc0>}
E Parameter annotations: (self: Any, a: jax.Array).
```

Here are the versions I'm using:

```
flax==0.8.1
jax==0.4.21
jaxtyping==0.2.25
```

I tested it, and it works with `chex.dataclass` and `equinox.Module`, but I have to use Flax dataclasses in my case. I'd love to find a workaround. Thanks!!

patrick-kidger (Owner) commented

That's odd -- I've just tried running your code (with the same versions of each library) and don't see the same issue. Can you perhaps double-check in a new environment?

Egiob (Author) commented Feb 18, 2024

I see, my minimal reproduction was ambiguous; sorry about that. I figured out that it depends on the order of the `@jaxtyped` decorator relative to `@flax.struct.dataclass`.
This fails:

```python
import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped


@jaxtyped(typechecker=beartype.beartype)
@flax.struct.dataclass
class Data:
    a: Array


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(10, dtype=int))

jax.vmap(f)(data)
```

This doesn't fail:

```python
import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped


@flax.struct.dataclass
@jaxtyped(typechecker=beartype.beartype)
class Data:
    a: Array


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(10, dtype=int))

jax.vmap(f)(data)
```

I think the error occurred for me because I used the pytest hook, which, according to the docs, adds the `jaxtyped` decorator on top, that is, as the outermost decorator.

Tell me if you can reproduce this 😄 (I'm using beartype==0.17.2, but I don't think it matters.)
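Python applies stacked decorators from the bottom up, so the failing snippet is equivalent to `Data = jaxtyped(typechecker=beartype.beartype)(flax.struct.dataclass(Data))`: `@jaxtyped` receives a class that is already a Flax dataclass, whereas in the working snippet it receives a plain class. The sketch below shows this ordering with a hypothetical `inspect_class` decorator, using the standard-library `dataclasses.dataclass(frozen=True)` as a stand-in for `flax.struct.dataclass`. It only illustrates decorator order and makes no claim about what `jaxtyped` does with the class it is given.

```python
import dataclasses

# Hypothetical class decorator (not jaxtyping's implementation): it records
# whether the class it receives is already a dataclass at the time it runs.
seen = {}


def inspect_class(cls):
    seen[cls.__name__] = dataclasses.is_dataclass(cls)
    return cls


# Applied bottom-up: Outer = inspect_class(dataclasses.dataclass(frozen=True)(Outer))
@inspect_class
@dataclasses.dataclass(frozen=True)
class Outer:
    a: int


# Applied bottom-up: Inner = dataclasses.dataclass(frozen=True)(inspect_class(Inner))
@dataclasses.dataclass(frozen=True)
@inspect_class
class Inner:
    a: int


print(seen)  # {'Outer': True, 'Inner': False}
```

Both classes end up as dataclasses; only what the outer decorator gets to see differs.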

patrick-kidger (Owner) commented

Ah, thank you!

It looks like this is a bug in Flax itself. Here's an MWE that doesn't use jaxtyping:

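```python
import flax
import jax.tree_util as jtu

@flax.struct.dataclass
class A:
    x: int

    # A custom __init__ that does not accept `x`. The dataclass is frozen,
    # so the field is set with object.__setattr__.
    def __init__(self):
        object.__setattr__(self, "x", 1)

leaves, treedef = jtu.tree_flatten(A())  # flattening works: leaves == [1]
jtu.tree_unflatten(treedef, leaves)  # fails: unflattening calls A(x=1)
```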

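It looks like the reason is that Flax's tree-unflattening rule rebuilds the object by calling the class's `__init__` method. This is a long-standing gotcha when using JAX: JAX may unflatten a pytree with leaves that are not arrays (such as the `<object object at ...>` in the original error), so any checks in `__init__` run on those placeholders.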

Unsurprisingly, I'd recommend using Equinox instead :)
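The unflattening gotcha can be shown without Flax or jaxtyping. In this sketch, two hypothetical pytree classes differ only in their unflatten rule: `ViaInit` rebuilds itself through `__init__`, which rejects anything that is not a `jax.Array` (a stand-in for a type-checked dataclass `__init__`), while `BypassInit` creates the instance with `object.__new__` and sets the field directly. The helper `roundtrip_with_placeholders` is also hypothetical; it assumes, based on the `<object object at ...>` argument in the original error raised under `jax.vmap`, that JAX may internally rebuild a pytree from non-array placeholder leaves.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


# Hypothetical pytree whose unflatten rule rebuilds the object through
# __init__, which validates its input (stand-in for a type-checked __init__).
@jtu.register_pytree_node_class
class ViaInit:
    def __init__(self, a):
        if not isinstance(a, jax.Array):
            raise TypeError(f"expected a jax.Array, got {a!r}")
        self.a = a

    def tree_flatten(self):
        return (self.a,), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)  # runs the check in __init__


# Same pytree, except that unflattening skips __init__ entirely.
@jtu.register_pytree_node_class
class BypassInit(ViaInit):
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)
        obj.a = children[0]
        return obj


def roundtrip_with_placeholders(tree):
    # Hypothetical helper: flatten, then rebuild the tree with non-array
    # placeholder leaves, as JAX may do internally.
    leaves, treedef = jtu.tree_flatten(tree)
    return jtu.tree_unflatten(treedef, [object()] * len(leaves))


x = jnp.ones(3)
try:
    roundtrip_with_placeholders(ViaInit(x))
    via_init_ok = True
except TypeError as err:
    via_init_ok = False
    print("ViaInit:", err)

rebuilt = roundtrip_with_placeholders(BypassInit(x))
print(type(rebuilt.a))  # <class 'object'>

# The bypassing version also works under jax.vmap.
sums = jax.vmap(lambda t: t.a.sum())(BypassInit(jnp.ones((4, 3))))
print(sums.shape)  # (4,)
```

Skipping `__init__` in the unflatten rule means placeholder leaves never reach the validation code; the trade-off is that fields are not re-checked whenever JAX rebuilds the tree.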
