6 changes: 3 additions & 3 deletions tpu_inference/runner/compilation_manager.py
@@ -849,11 +849,11 @@ def _precompile_structured_decoding(self) -> None:
             num_reqs]
         dummy_grammar_bitmask = self.runner.grammar_bitmask_cpu[:num_reqs]
 
-        (dummy_logits, dummy_require_struct_decoding,
+        (dummy_require_struct_decoding,
          dummy_grammar_bitmask, arange) = device_array(
              self.runner.mesh,
-             (dummy_logits, dummy_require_struct_decoding,
-              dummy_grammar_bitmask, self.runner.structured_decode_arange))
+             (dummy_require_struct_decoding, dummy_grammar_bitmask,
+              self.runner.structured_decode_arange))
 
         self._run_compilation(
             "structured_decode",
2 changes: 1 addition & 1 deletion tpu_inference/utils.py
@@ -270,7 +270,7 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
     """
     if sharding is None:
         sharding = NamedSharding(mesh, PartitionSpec(None))
-    return jax.device_put(*args, device=sharding, **kwargs)
+    return jax.make_array_from_process_local_data(sharding, *args)
@py4 (Collaborator) commented on Dec 10, 2025:
Why are we doing this here, and which problem does it solve? device_put also creates a global array. make_array_from_process_local_data is meant for the case where the local (CPU) data is sharded across hosts because of its size.

Imagine you want to load a 100GB Dataset onto your 8 TPUs (sharded).
Global Size: 100GB
Host 0 RAM: 64GB
Host 1 RAM: 64GB

In this case the local data should be sharded (before transferring to TPU) because it won't fit in an individual host's RAM. This is when make_array_from_process_local_data is useful. But in our codebase, even in a multi-host setup, our local data on CPU is not sharded across hosts, to my understanding.
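
For reference, the device_put pattern being replaced looks roughly like the sketch below. The mesh construction, axis name, and data are illustrative assumptions rather than code from this repository; the point it illustrates is the one above: every host passes the same unsharded CPU copy, and device_put produces the globally sharded jax.Array itself.

# A hedged sketch of the device_put path; mesh setup and data are illustrative.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.asarray(jax.devices()), ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data'))

# Every host holds the same unsharded CPU copy of the inputs ...
host_data = np.arange(8)

# ... and device_put slices it onto the addressable devices automatically,
# returning a global jax.Array sharded across the mesh.
arr = jax.device_put(host_data, sharding)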

The code will break in this setup (two hosts, 4 chips each):

# CODE ON HOST 0
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('data',))  # 8 devices across the two hosts

global_data = jnp.arange(8)  # Shape is (8,)

sharding = NamedSharding(mesh, P('data'))
global_shape = (8,)

# FAILURE:
# JAX expects you to pass ONLY the local shard for Host 0 (size 4).
# You passed the global array (size 8).
arr = jax.make_array_from_process_local_data(
    sharding,
    global_data,  # <--- TRAP! You passed the whole array.
    global_shape
)

That requires us to handle the slicing manually in our code, whereas device_put does the slicing automatically. To use make_array_from_process_local_data, the data must be sliced on each host first (unless it is fully replicated), which means every prepare_inputs in the codebase would have to be modified to slice per host.
I would suggest debugging deeper with the JAX team to understand what is wrong with device_put (if that is the reason) in the original xprof; a 4ms difference is unexpected.
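
For comparison, a minimal sketch of the per-host slicing that make_array_from_process_local_data would require is shown below. The slicing arithmetic and names are illustrative assumptions: they presume an evenly divisible first axis and a process ordering that matches the mesh layout, neither of which is taken from this repository.

# Hypothetical per-host slicing callers would need before calling
# make_array_from_process_local_data; layout assumptions are illustrative.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.asarray(jax.devices()), ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data'))

global_shape = (8,)
global_data = np.arange(8)  # same unsharded copy on every host

# Each host hands over ONLY the rows that live on its own chips.
rows_per_host = global_shape[0] // jax.process_count()
start = jax.process_index() * rows_per_host  # assumes this ordering matches the mesh
local_shard = global_data[start:start + rows_per_host]

arr = jax.make_array_from_process_local_data(sharding, local_shard, global_shape)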



def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: