4 changes: 2 additions & 2 deletions requirements.txt
@@ -8,9 +8,9 @@ google-cloud-storage
 jax[tpu]==0.8.0
 jaxlib==0.8.0
 jaxtyping
-flax==0.11.1
+flax==0.12.0
 torchax==0.0.7
-qwix==0.1.1
+qwix==0.1.4
 torchvision==0.24.0
 pathwaysutils
 parameterized
2 changes: 1 addition & 1 deletion tests/models/jax/test_weight_loading.py
@@ -31,7 +31,7 @@ class SourceModel(nnx.Module):
 
     def __init__(self, rngs):
         self.src_lm_head = nnx.Param(jax.random.normal(rngs(), (2, 4)))
-        self.layers = {0: SourceLayer(rngs)}
+        self.layers = nnx.List({0: SourceLayer(rngs)})
 
 
 class TargetLinear(nnx.Module):
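The recurring edit in this PR, here and in the model files below, is wrapping plain Python containers of sublayers in `nnx.List`; the in-code comments attribute this to Flax 0.12.0 compatibility. A minimal sketch of the pattern, assuming flax 0.12.x (the `TinyModel` class and layer sizes are invented for illustration, not from this repo):

```python
# Illustrative sketch of the nnx.List pattern applied throughout this PR.
import jax
from flax import nnx


class TinyModel(nnx.Module):

    def __init__(self, rngs: nnx.Rngs):
        # Wrapping the list in nnx.List keeps the sublayers and their params
        # visible to nnx.state / nnx.split, mirroring the changes in this PR.
        self.layers = nnx.List(
            [nnx.Linear(4, 4, rngs=rngs) for _ in range(2)])


model = TinyModel(nnx.Rngs(0))
state = nnx.state(model)  # sublayer params appear under layers[0], layers[1]
print(state)
```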
10 changes: 6 additions & 4 deletions tpu_inference/models/common/model_loader.py
@@ -96,7 +96,9 @@ def create_jit_model(
         The jitted model.
     """
     state = nnx.state(model)
-    nnx.update(model, state)
+    pspecs = nnx.get_partition_spec(state)
+    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
+    nnx.update(model, sharded_state)
     if not use_qwix_on_abstract_model:
         # NOTE: if Qwix is not configured, this will be a no-op
         model = apply_qwix_quantization(vllm_config,
@@ -127,7 +129,7 @@ def create_jit_model(
             "quantization_config") else {}
         load_random_weights_into_qwix_abstract_model(
             rng, model, mesh, quantization_config)
-        with mesh:
+        with jax.set_mesh(mesh):
             jit_model = create_jit_model(model,
                                          use_qwix_on_abstract_model=True)
             return jit_model
@@ -142,7 +144,7 @@ def create_sharded_model():
         # NOTE: we don't support quantization for the old Qwen2ForCausalLM implementation
         return model
 
-    with mesh:
+    with jax.set_mesh(mesh):
         jit_model = create_sharded_model()
         # In this case, we are applying Qwix quantization to the true, concrete model
         jit_model = apply_qwix_quantization(vllm_config,
@@ -179,7 +181,7 @@ def create_sharded_model():
     # Although the created model can already work, we still need to jit
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
-    with mesh:
+    with jax.set_mesh(mesh):
         loader = get_model_loader(vllm_config.load_config)
         if isinstance(loader, RunaiModelStreamerLoader):
             model_weights = vllm_config.model_config.model
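The `create_jit_model` hunk above threads the module state through `nnx.get_partition_spec` and `jax.lax.with_sharding_constraint` before `nnx.update`, and runs model creation under `jax.set_mesh(mesh)` instead of a bare `with mesh:`. A minimal, self-contained sketch of that pattern, assuming jax 0.8.x (which provides `jax.make_mesh` and `jax.set_mesh`) and flax 0.12.x; `TinyModel` and the `'model'` axis name are illustrative, not this repo's API:

```python
# Illustrative sketch of the sharded model-creation pattern used above.
import jax
from flax import nnx


class TinyModel(nnx.Module):

    def __init__(self, rngs: nnx.Rngs):
        # Attach a partition spec to the kernel so nnx.get_partition_spec
        # can later recover it from the module state.
        self.linear = nnx.Linear(
            8, 8,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.lecun_normal(), (None, 'model')),
            rngs=rngs)


@nnx.jit
def create_sharded_model() -> TinyModel:
    model = TinyModel(nnx.Rngs(0))            # params created under jit
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)     # pull specs out of the state
    sharded = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded)                 # write sharded arrays back
    return model


mesh = jax.make_mesh((jax.device_count(),), ('model',))
with jax.set_mesh(mesh):
    jit_model = create_sharded_model()
```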
2 changes: 2 additions & 0 deletions tpu_inference/models/jax/deepseek_v3.py
@@ -307,6 +307,8 @@ def _create_mla() -> MLA:
                 vd_sharding=(('data', 'expert', 'model'), None),
                 dv_sharding=(None, ('data', 'expert', 'model')),
                 random_init=self.random_init)
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
 
     # For compatibility with flax.
     def apply(self, variables, *args, **kwargs):
3 changes: 3 additions & 0 deletions tpu_inference/models/jax/gpt_oss.py
@@ -182,6 +182,9 @@ def __init__(self,
             random_init=self.random_init,
         )
 
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
+
     # For compatibility with flax.
     def apply(self, variables, *args, **kwargs):
         return self.__call__(*args, **kwargs)
5 changes: 2 additions & 3 deletions tpu_inference/models/jax/llama3.py
@@ -243,16 +243,15 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 init_fn, (ShardingAxisName.VOCAB, None)),
             rngs=rng,
         )
-        self.layers = [
+        self.layers = nnx.List(
             LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
                 rng=rng,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-            for _ in range(hf_config.num_hidden_layers)
-        ]
+            for _ in range(hf_config.num_hidden_layers))
         self.norm = nnx.RMSNorm(
             hidden_size,
             epsilon=rms_norm_eps,
3 changes: 3 additions & 0 deletions tpu_inference/models/jax/llama4.py
@@ -216,6 +216,9 @@ def __init__(self,
                 use_attention_rope=use_attention_rope)
             self.layers.append(block)
 
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
+
         self.final_norm = RMSNorm(
             dims=self.hidden_size,
             epsilon=self.rms_norm_eps,
2 changes: 2 additions & 0 deletions tpu_inference/models/jax/llama_eagle3.py
@@ -141,6 +141,8 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs, mesh: Mesh):
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
         ]
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
 
         if hasattr(hf_config, "target_hidden_size"):
             input_size = hf_config.target_hidden_size * 3
5 changes: 4 additions & 1 deletion tpu_inference/models/jax/llama_guard_4.py
@@ -82,7 +82,7 @@ def __init__(self,
 
         self.layers = []
 
-        for i in range(self.num_layers):
+        for _ in range(self.num_layers):
             use_attention_rope = True
 
             custom_module = DenseFFW(dtype=self.dtype,
@@ -161,6 +161,9 @@ def __init__(self,
                 use_attention_rope=use_attention_rope)
             self.layers.append(block)
 
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
+
         self.final_norm = RMSNorm(
             dims=self.hidden_size,
             activation_ffw_td=P(),
4 changes: 2 additions & 2 deletions tpu_inference/models/jax/qwen2.py
@@ -242,7 +242,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
             rngs=rng,
         )
-        self.layers = [
+        self.layers = nnx.List([
             Qwen2DecoderLayer(
                 config=hf_config,
                 dtype=dtype,
@@ -251,7 +251,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
             for _ in range(hf_config.num_hidden_layers)
-        ]
+        ])
         self.norm = nnx.RMSNorm(
             hidden_size,
             epsilon=rms_norm_eps,
4 changes: 2 additions & 2 deletions tpu_inference/models/jax/qwen3.py
@@ -195,7 +195,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
             rngs=rng,
         )
-        self.layers = [
+        self.layers = nnx.List([
             Qwen3DecoderLayer(
                 config=hf_config,
                 dtype=dtype,
@@ -204,7 +204,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
             for _ in range(hf_config.num_hidden_layers)
-        ]
+        ])
         self.norm = nnx.RMSNorm(
             hidden_size,
             epsilon=rms_norm_eps,
5 changes: 5 additions & 0 deletions tpu_inference/runner/tpu_runner.py
@@ -6,6 +6,7 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 
+import flax
 import jax
 import jax.numpy as jnp
 import jaxtyping
@@ -283,6 +284,10 @@ def __init__(
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
         self.execute_model_state: ExecuteModelState | None = None
 
+        # TODO (jacobplatin): eventually, we'll want to use eager sharding in Flax
+        # but as a hotfix to upgrade to Flax 0.12.0, we'll disable it
+        flax.config.update('flax_always_shard_variable', False)
+
     def _init_random(self):
         if self.model_config.seed is None:
             self.model_config.seed = 0
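The runner hunk above flips one Flax config flag at construction time as a hotfix for the 0.12.0 upgrade. A minimal sketch of that hotfix in isolation; the flag name comes from the diff, while the module below is purely illustrative:

```python
# Sketch of the eager-sharding hotfix applied in tpu_runner.py.
import flax
from flax import nnx

# Disable eager variable sharding before any modules are constructed,
# matching the change in TPUModelRunner.__init__ above.
flax.config.update('flax_always_shard_variable', False)

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))  # built with eager sharding off
```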