Compatibility with multiprocessing / joblib - AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic' #198

Closed
jaanli opened this issue Apr 5, 2024 · 9 comments · Fixed by #237
Labels
bug Something isn't working

Comments

@jaanli

jaanli commented Apr 5, 2024

I have to remove the type hints from type-checked functions that are called inside joblib.Parallel or other multiprocessing pipelines; otherwise I get tracebacks like this:

joblib.externals.loky.process_executor._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 426, in _process_worker
    call_item = call_queue.get(block=True, timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
  File "/home/ray/anaconda3/lib/python3.10/weakref.py", line 429, in __setitem__
    self.data[ref(key, self._remove)] = value
  File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 310, in __hash__
    return hash(cls._get_props())
  File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 295, in _get_props
    cls.index_variadic,
AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ray/evaluate_bm25_pytorch.py", line 314, in <module>
    main(cfg)
  File "/home/ray/evaluate_bm25_pytorch.py", line 295, in main
    trainer.evaluate(
  File "/home/ray/trainer.py", line 427, in evaluate
    predictions, objective = model_and_objective(batch)
  File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ray/evaluate_bm25_pytorch.py", line 134, in forward
    predictions = Parallel(n_jobs=self.n_jobs)(
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
    return output if self.return_generator else list(output)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
    yield from self._retrieve()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
    self._raise_error_fast()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
    error_job.get_result(self.timeout)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
    return self._return_or_raise()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
    raise self._result
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
@patrick-kidger
Owner

Looks like they're not getting de/serialised correctly, so the index_variadic attribute doesn't make it across.

If you can open a MWE that'd be great. (Or a PR! The fix might just be to implement __setstate__ and __getstate__?)

@patrick-kidger patrick-kidger added the bug Something isn't working label Apr 5, 2024
@sachith-gunasekara

I'm facing this same issue when trying to save an optax optimizer state using cloudpickle. Hope this issue gets fixed.

  File "/root/optimizer.py", line 89, in run_train_on_modal
    optimizer_state = optax.tree_utils.tree_set(optimizer_state, inner_state=cloudpickle.load(f))
  File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
  File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
  File "/usr/lib/python3.11/weakref.py", line 428, in __setitem__
    self.data[ref(key, self._remove)] = value
  File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 321, in __hash__
    return hash(cls._get_props())
  File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 306, in _get_props
    cls.index_variadic,
AttributeError: type object 'Float[Array, '*shape']' has no attribute 'index_variadic'

@patrick-kidger
Owner

Do you have a MWE?

@sachith-gunasekara

sachith-gunasekara commented May 31, 2024

Yes, I'm training a model with JAX and Equinox, and I am trying to save the optimizer state.

import cloudpickle
import equinox as eqx
import optax

# learning_rate, warmup_iters, init_from, lr_decay_iters, iter_num, min_lr,
# beta1, beta2, model and checkpoint_params_file are defined elsewhere in the script.
lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == 'scratch' else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=lr_scheduler, b1=beta1, b2=beta2)

optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))

checkpoint_params = {
    "optimizer_state": optimizer_state
}

with open(checkpoint_params_file, "wb") as f:
    cloudpickle.dump(checkpoint_params, f)

@LoganWalls

> Looks like they're not getting de/serialised correctly, so the index_variadic attribute doesn't make it across.
>
> If you can open a MWE that'd be great. (Or a PR! The fix might just be to implement __setstate__ and __getstate__?)

I am also encountering this issue, but only with ray (just using cloudpickle on its own seems to work now). This MWE reproduces the issue:

pip install jaxtyping jax 'ray[default]'
import jax
import ray
from jax import numpy as jnp

from jaxtyping import Int


ray.init()


@ray.remote(max_retries=0)
def f(x: Int[jax.Array, "one two"]):
    return x * 2


a = ray.put(jnp.arange(10))
ray.get(f.remote(a))

I tried implementing __setstate__ and __getstate__ as follows:

# jaxtyping/_array_types.py

@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
    class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
        # ...

        def __getstate__(cls):
            return cls._get_props()

        def __setstate__(cls, props):
            (
                cls.index_variadic,
                cls.dims,
                cls.array_type,
                cls.dtypes,
                cls.dim_str,
            ) = props
        
        # ...

But as best I can tell, neither one gets called at all.
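A minimal sketch of why such hooks can stay silent, using only the standard-library pickle module and a hypothetical toy metaclass (Meta, Shaped and the module name toy_shapes are illustrative, not jaxtyping code). The standard pickler treats any instance of a metaclass as a class and saves it by reference (module plus qualified name), so a metaclass's __getstate__/__setstate__ are never consulted. Ray's bundled cloudpickle instead pickles dynamically created classes by value through its own reducer; as far as I can tell it gathers the class state with its own helpers rather than through these hooks, which would be consistent with the observation above, but treat that part as an assumption.

```python
import pickle
import sys
import types

calls = []  # records whether the metaclass hooks were ever invoked


class Meta(type):
    # Hypothetical metaclass carrying the hooks suggested in the thread.
    def __getstate__(cls):
        calls.append("getstate")
        return {}

    def __setstate__(cls, state):
        calls.append("setstate")


class Shaped(metaclass=Meta):
    pass


# Make Shaped importable under a fixed (hypothetical) module name so pickle
# can look it up by reference regardless of how this script is executed.
toy_module = types.ModuleType("toy_shapes")
toy_module.Shaped = Shaped
Shaped.__module__ = "toy_shapes"
sys.modules["toy_shapes"] = toy_module

restored = pickle.loads(pickle.dumps(Shaped))

print(restored is Shaped)  # True: the class was pickled by reference
print(calls)               # []: neither hook was called
```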

@patrick-kidger
Owner

It looks like ray.cloudpickle.cloudpickle internally synthesises a class via types.new_class:

https://github.com/ray-project/ray/blob/200c54859dc87f02f7b40e003917b53e68356a60/python/ray/cloudpickle/cloudpickle.py#L536

and then immediately tries to hash it:

https://github.com/ray-project/ray/blob/200c54859dc87f02f7b40e003917b53e68356a60/python/ray/cloudpickle/cloudpickle.py#L124

which fails, as this class does not yet have our attributes set.
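The create-then-hash ordering can be reproduced without Ray or jaxtyping. This is a toy sketch under stated assumptions: MetaShaped, Float_batch, dims and dtype are hypothetical names; types.new_class stands in for cloudpickle's skeleton-class step, and a weakref.WeakKeyDictionary stands in for its class tracker (the tracebacks above show that insert going through weakref.py's __setitem__).

```python
import types
import weakref


class MetaShaped(type):
    # Hypothetical stand-in for a metaclass whose hash is derived from
    # class attributes (compare _get_props in the tracebacks above).
    def __hash__(cls):
        return hash((cls.dims, cls.dtype))


# Step 1: build an empty "skeleton" class first.
Skeleton = types.new_class("Float_batch", (), {"metaclass": MetaShaped})

# Step 2: register it in a weak-keyed table right away. Inserting a key
# hashes it, but the attributes the hash depends on don't exist yet.
tracker = weakref.WeakKeyDictionary()
try:
    tracker[Skeleton] = "tracker-id"
    error = None
except AttributeError as exc:
    error = exc

print(error)  # AttributeError message naming the missing 'dims' attribute

# Step 3: only now would the class state be restored; hashing then works.
Skeleton.dims = ("batch", "num_classes")
Skeleton.dtype = "float32"
tracker[Skeleton] = "tracker-id"
```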

ray's approach seems a bit dodgy due to exactly the kind of failure we're seeing here! Anyway, I've worked around this in #237 by just always hashing to zero.
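A toy sketch of the "always hash to zero" idea, again with a hypothetical metaclass rather than jaxtyping's actual code from #237. Because the hash no longer reads any class attributes, a half-built class can be used as a dictionary key. A constant hash is always consistent with equality (equal objects trivially get equal hashes); the cost is that all such classes collide in hash-based containers and are told apart by equality comparisons instead.

```python
import types
import weakref


class MetaShapedConstHash(type):
    # Constant hash: safe to call before any class attributes exist.
    # Equality falls back to type's default (identity).
    def __hash__(cls):
        return 0


tracker = weakref.WeakKeyDictionary()

# Skeleton classes can now be registered before their attributes exist.
First = types.new_class("Float_batch", (), {"metaclass": MetaShapedConstHash})
Second = types.new_class("Int_one_two", (), {"metaclass": MetaShapedConstHash})
tracker[First] = "id-1"
tracker[Second] = "id-2"  # same hash; distinguished by equality (identity here)

# State can be filled in afterwards without changing the hash.
First.dims = ("batch", "num_classes")

print(len(tracker))    # 2
print(tracker[First])  # id-1
```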

Thank you for the MWE, that was invaluable to figure this one out! :)

@jaanli
Author

jaanli commented Aug 1, 2024

You guys, gals, and nonbinary pals rock!!

@danbnyn

danbnyn commented Sep 1, 2024

I seem to have a related issue with the Grain dataloader, which also involves cloudpickle and index_variadic. The error only happens when I set worker_count > 0:

ERROR:absl:Error occurred in child process with worker_index: 7
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 176, in _worker_loop
    element_producer = _get_element_producer_from_queue(
  File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 148, in _get_element_producer_from_queue
    element_producer_fn: GetElementProducerFn[Any] = cloudpickle.loads(
  File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 539, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
  File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 124, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
  File "/usr/local/lib/python3.10/weakref.py", line 429, in __setitem__
    self.data[ref(key, self._remove)] = value
  File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 339, in __hash__
    return hash(cls._get_props())
  File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 324, in _get_props
    cls.index_variadic,
AttributeError: type object 'Float[Array, 'N C H W']' has no attribute 'index_variadic'

The above exception was the direct cause of the following exception:

@patrick-kidger
Owner

Ah, this has already been fixed and I just haven't done a new release for it yet.

I've done a version bump + new release in #246


5 participants