Conversation


@Lumosis Lumosis commented Dec 10, 2025

Description

When running multi-host, device_put introduces extra overhead by triggering tensor checks (github.com/jax-ml/jax/blob/main/jax/experimental/multihost_utils.py#L162-L166). This PR switches the transfer to make_array_from_process_local_data to avoid those checks.
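For context, a minimal sketch of the two transfer paths being compared (the mesh construction and shapes here are assumptions for illustration, not the PR's actual code):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical multi-host mesh: all devices on a single 'data' axis.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
replicated = NamedSharding(mesh, PartitionSpec(None))

x = jnp.ones((8, 128))  # process-local input, identical on every host

# Old path: per the description above, in multi-host mode this goes
# through the linked tensor checks on every call.
arr = jax.device_put(x, replicated)

# New path (per the diff below): builds the global array directly from
# each process's local buffer.
arr = jax.make_array_from_process_local_data(replicated, x)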

Tests

Tested on v6e-8 with Llama 8B and on v7x-16 with DeepSeek (20 layers).

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@Lumosis Lumosis requested a review from bzgoogle December 10, 2025 22:17
if sharding is None:
    sharding = NamedSharding(mesh, PartitionSpec(None))
- return jax.device_put(*args, device=sharding, **kwargs)
+ return jax.make_array_from_process_local_data(sharding, *args)
@py4 py4 Dec 10, 2025

Why are we doing this here, and which problem does it solve? device_put also creates a global array. make_array_from_process_local_data is intended for the case where the local data (on CPU) is sharded across hosts because of its size.

Imagine you want to load a 100 GB dataset onto your 8 TPUs (sharded).

Global size: 100 GB
Host 0 RAM: 64 GB
Host 1 RAM: 64 GB

In this case the local data must be sharded across hosts (before transferring to TPU) because it won't fit in either host's RAM. That is when make_array_from_process_local_data is useful. But in our codebase, even in the multi-host setup, our local data on CPU is not sharded across hosts, to my understanding.

The code below will break in this setup (two hosts, 4 chips each):

# CODE ON HOST 0
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('data',))  # 2 hosts x 4 chips
global_data = jnp.arange(8)  # Shape is (8,)

sharding = NamedSharding(mesh, P('data'))
global_shape = (8,)

# FAILURE:
# JAX expects you to pass ONLY the local shard for Host 0 (size 4).
# You passed the global array (size 8).
arr = jax.make_array_from_process_local_data(
    sharding,
    global_data,  # <--- TRAP! You passed the whole array.
    global_shape
)

That requires us to handle the slicing manually in the code, whereas device_put does the slicing automatically. To use make_array_from_process_local_data, the data must be sliced on each host first (unless it is fully replicated), which means every prepare_inputs in the codebase would have to be modified to slice per host; a sketch follows.
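A minimal sketch of what that per-host slicing would look like (the mesh construction, the 1D 'data' axis, and the contiguous per-process row blocks are assumptions for illustration):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical mesh: 2 hosts x 4 chips = 8 devices on one 'data' axis.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
sharding = NamedSharding(mesh, P('data'))

global_data = np.arange(8)          # logical global array, shape (8,)
global_shape = global_data.shape

# Each process passes only the rows it owns: with 2 processes sharding
# axis 0 (and devices ordered by process), host 0 owns [0:4] and
# host 1 owns [4:8].
rows_per_host = global_shape[0] // jax.process_count()
start = jax.process_index() * rows_per_host
local_slice = global_data[start:start + rows_per_host]

arr = jax.make_array_from_process_local_data(sharding, local_slice, global_shape)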
I would suggest debugging deeper with the JAX team to find out what is wrong with device_put (if that is indeed the cause) in the original xprof; a 4 ms difference is unexpected.
