Compatibility with multiprocessing / joblib - AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic' #198
Looks like they're not getting de/serialised correctly, so the … If you can open a MWE, that'd be great. (Or a PR! The fix might just be to implement …)
I'm facing this same issue when trying to save an optax optimizer state using cloudpickle. Hope this issue gets fixed.
Do you have a MWE?
Yes, I'm training a model with JAX and Equinox, and I'm trying to save the optimizer state:

```python
lr_scheduler = optax.warmup_cosine_decay_schedule(
    ...
)
...
optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))
checkpoint_params = {
    ...
}
with open(checkpoint_params_file, "wb") as f:
    ...
```
I am also encountering this issue, but only with …:

```shell
pip install jaxtyping jax 'ray[default]'
```

```python
import jax
import ray
from jax import numpy as jnp
from jaxtyping import Int

ray.init()


@ray.remote(max_retries=0)
def f(x: Int[jax.Array, "one two"]):
    return x * 2


a = ray.put(jnp.arange(10))
ray.get(f.remote(a))
```

I tried implementing …:

```python
# jaxtyping/_array_types.py
@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
    class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
        # ...

        def __getstate__(cls):
            return cls._get_props()

        def __setstate__(cls, props):
            (
                cls.index_variadic,
                cls.dims,
                cls.array_type,
                cls.dtypes,
                cls.dim_str,
            ) = props

        # ...
```

But as best I can tell, neither one gets called at all.
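A minimal, stdlib-only sketch of one likely reason hooks like these never fire: when the standard `pickle` module saves a class object whose metaclass is a subclass of `type`, it pickles the class by reference (module plus qualified name) and does not consult `__getstate__`/`__setstate__` defined on the metaclass. `Meta`, `Foo`, and `calls` are toy names for illustration only. cloudpickle pickles dynamic classes by value through its own reduction path; whether that path consults these hooks is not demonstrated here.

```python
import pickle

calls = []  # records every invocation of the metaclass hooks


class Meta(type):
    # Hooks defined on the metaclass, analogous to the attempt above.
    def __getstate__(cls):
        calls.append("__getstate__")
        return {}

    def __setstate__(cls, state):
        calls.append("__setstate__")


class Foo(metaclass=Meta):
    pass


# Foo is a module-level class, so pickle stores only a reference to it.
data = pickle.dumps(Foo)
restored = pickle.loads(data)

print(restored is Foo)  # True: the same class object comes back
print(calls)            # []: neither hook was called
```

Because the standard pickler never reaches `__reduce_ex__` or `__getstate__` for such class objects, state that lives on the class (like `index_variadic`) has to be restored by whatever mechanism actually rebuilds the class.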
It looks like … and then immediately tries to hash it: … This fails, because the class does not yet have our attributes set.
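The failure mode described above can be reproduced without cloudpickle or jaxtyping. This sketch uses a hypothetical toy metaclass, `MetaArray`, whose equality and hash depend on class attributes (loosely modelled on the `_get_props()` tuple in the earlier snippet; it is not jaxtyping's actual code). Building an empty "skeleton" class first and filling in its attributes later means that anything which hashes the class in between hits the same kind of `AttributeError` as in this issue's title.

```python
class MetaArray(type):
    """Hypothetical metaclass whose equality and hash depend on class attributes."""

    def _get_props(cls):
        return (cls.index_variadic, cls.dims, cls.dim_str)

    def __eq__(cls, other):
        return type(cls) is type(other) and cls._get_props() == other._get_props()

    def __hash__(cls):
        # Raises AttributeError if the attributes have not been set yet.
        return hash(cls._get_props())


def make_array_type(dim_str):
    # Normal construction: attributes are present from the start.
    namespace = {"index_variadic": None, "dims": tuple(dim_str.split()), "dim_str": dim_str}
    return MetaArray("Array", (), namespace)


def try_hash(cls):
    """Return the hash, or the AttributeError message if hashing fails."""
    try:
        return hash(cls)
    except AttributeError as e:
        return str(e)


complete = make_array_type("batch_size num_classes")
print(try_hash(complete) == hash(complete._get_props()))  # True

# Unpickling-style construction: create an empty skeleton first...
skeleton = MetaArray("Array", (), {})
# ...and hash it before its attributes are filled in.
print(try_hash(skeleton))  # type object 'Array' has no attribute 'index_variadic'

# Once the state is restored, hashing works again.
skeleton.index_variadic, skeleton.dims, skeleton.dim_str = None, ("a",), "a"
print(try_hash(skeleton) == hash(skeleton._get_props()))  # True
```

Note that a naive workaround of falling back to a different hash while the attributes are missing would make the class's hash change after unpickling, which breaks any dict or weak-key mapping the skeleton was already inserted into.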
Thank you for the MWE; it was invaluable for figuring this one out! :)
You guys, gals, and nonbinary pals rock!!
I seem to have a related issue with the Grain dataloader, which also involves cloudpickle and `index_variadic`. The error only happens when I set `worker_count > 0`:

ERROR:absl:Error occurred in child process with worker_index: 7
Ah, this has already been fixed; I just haven't done a new release for it yet. I've now done a version bump and new release in #246.
I have to remove the type hints from type-checked functions that need to be called in `joblib.Parallel` or other multiprocessing pipelines; otherwise I get tracebacks like this: