diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index aa70353ab..a66277329 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -849,11 +849,11 @@ def _precompile_structured_decoding(self) -> None: num_reqs] dummy_grammar_bitmask = self.runner.grammar_bitmask_cpu[:num_reqs] - (dummy_logits, dummy_require_struct_decoding, + (dummy_require_struct_decoding, dummy_grammar_bitmask, arange) = device_array( self.runner.mesh, - (dummy_logits, dummy_require_struct_decoding, - dummy_grammar_bitmask, self.runner.structured_decode_arange)) + (dummy_require_struct_decoding, dummy_grammar_bitmask, + self.runner.structured_decode_arange)) self._run_compilation( "structured_decode", diff --git a/tpu_inference/utils.py b/tpu_inference/utils.py index b147bbeea..288fe98fb 100644 --- a/tpu_inference/utils.py +++ b/tpu_inference/utils.py @@ -270,7 +270,7 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array: """ if sharding is None: sharding = NamedSharding(mesh, PartitionSpec(None)) - return jax.device_put(*args, device=sharding, **kwargs) + return jax.make_array_from_process_local_data(sharding, *args) def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: