Commit 64b4a4f

Fix sharding issue
Signed-off-by: Jacob Platin <[email protected]>
1 parent 08c2915 · commit 64b4a4f

1 file changed: +6 -4 lines changed

tpu_inference/models/common/model_loader.py

Lines changed: 6 additions & 4 deletions
@@ -96,7 +96,9 @@ def create_jit_model(
         The jitted model.
     """
     state = nnx.state(model)
-    nnx.update(model, state)
+    pspecs = nnx.get_partition_spec(state)
+    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
+    nnx.update(model, sharded_state)
     if not use_qwix_on_abstract_model:
         # NOTE: if Qwix is not configured, this will be a no-op
         model = apply_qwix_quantization(vllm_config,
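
For context, the hunk above follows the standard Flax NNX sharding recipe: read the model state, derive PartitionSpecs from the parameters' partitioning annotations, constrain the arrays to those specs, and write them back into the module. Below is a minimal, self-contained sketch of that pattern, not the repository's code; the TinyLinear module, the single "model" mesh axis, and all sizes are illustrative assumptions. The remaining hunks of the diff follow after it.

# Illustrative sketch of the nnx.get_partition_spec / with_sharding_constraint
# pattern used in the hunk above; module, mesh layout, and sizes are assumptions.
import jax
import numpy as np
from flax import nnx
from jax.sharding import Mesh

DOUT = 4 * jax.device_count()  # keep the sharded dimension divisible by the device count


class TinyLinear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        # Annotate the kernel with a sharding spec so that
        # nnx.get_partition_spec() can later recover a PartitionSpec for it.
        self.w = nnx.Param(
            nnx.with_partitioning(nnx.initializers.lecun_normal(),
                                  (None, "model"))(rngs.params(), (din, dout)))

    def __call__(self, x):
        return x @ self.w


@nnx.jit
def create_sharded_tiny_model() -> TinyLinear:
    model = TinyLinear(8, DOUT, rngs=nnx.Rngs(0))      # parameters unsharded here
    state = nnx.state(model)                            # pytree of the model's arrays
    pspecs = nnx.get_partition_spec(state)              # specs from the annotations
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)                    # write sharded arrays back
    return model


mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
with jax.set_mesh(mesh):                                # mesh context, as in this commit
    model = create_sharded_tiny_model()
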
@@ -127,7 +129,7 @@ def create_jit_model(
                 "quantization_config") else {}
         load_random_weights_into_qwix_abstract_model(
             rng, model, mesh, quantization_config)
-        with mesh:
+        with jax.set_mesh(mesh):
             jit_model = create_jit_model(model,
                                          use_qwix_on_abstract_model=True)
         return jit_model
@@ -142,7 +144,7 @@ def create_sharded_model():
         # NOTE: we don't support quantization for the old Qwen2ForCausalLM implementation
         return model

-    with mesh:
+    with jax.set_mesh(mesh):
         jit_model = create_sharded_model()
         # In this case, we are applying Qwix quantization to the true, concrete model
         jit_model = apply_qwix_quantization(vllm_config,
@@ -179,7 +181,7 @@ def create_sharded_model():
     # Although the created model can already work, we still need to jit
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
-    with mesh:
+    with jax.set_mesh(mesh):
         loader = get_model_loader(vllm_config.load_config)
         if isinstance(loader, RunaiModelStreamerLoader):
            model_weights = vllm_config.model_config.model
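
The remaining hunks swap `with mesh:` for `with jax.set_mesh(mesh):`, installing the mesh as JAX's ambient (context) mesh instead of relying on the older `Mesh` context manager. A rough sketch of how that ambient mesh is consumed, again not the repository's code; the "model" axis name, the `double` function, and the array shapes are assumptions:

# Rough sketch: under jax.set_mesh(...), bare PartitionSpecs inside jit
# resolve against the ambient mesh. Axis name and shapes are assumptions.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))


@jax.jit
def double(x):
    # Constrain the result to be sharded along the "model" mesh axis.
    return jax.lax.with_sharding_constraint(x * 2, P("model", None))


with jax.set_mesh(mesh):                  # ambient mesh, as in the hunks above
    x = jnp.ones((jax.device_count() * 2, 4))
    y = double(x)
    print(y.sharding)                     # NamedSharding over the "model" axis
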
