@@ -96,7 +96,9 @@ def create_jit_model(
9696 The jitted model.
9797 """
9898 state = nnx .state (model )
99- nnx .update (model , state )
99+ pspecs = nnx .get_partition_spec (state )
100+ sharded_state = jax .lax .with_sharding_constraint (state , pspecs )
101+ nnx .update (model , sharded_state )
100102 if not use_qwix_on_abstract_model :
101103 # NOTE: if Qwix is not configured, this will be a no-op
102104 model = apply_qwix_quantization (vllm_config ,
@@ -142,7 +144,7 @@ def create_sharded_model():
142144 # NOTE: we don't support quantization for the old Qwen2ForCausalLM implementation
143145 return model
144146
145- with mesh :
147+ with jax . set_mesh ( mesh ) :
146148 jit_model = create_sharded_model ()
147149 # In this case, we are applying Qwix quantization to the true, concrete model
148150 jit_model = apply_qwix_quantization (vllm_config ,
@@ -179,7 +181,7 @@ def create_sharded_model():
179181 # Although the created model can already work, we still need to jit
180182 # the model creation again, otherwise the model forward will have
181183 # non-trivial overhead in PjitFunction.
182- with mesh :
184+ with jax . set_mesh ( mesh ) :
183185 loader = get_model_loader (vllm_config .load_config )
184186 if isinstance (loader , RunaiModelStreamerLoader ):
185187 model_weights = vllm_config .model_config .model
0 commit comments