diff --git a/requirements.txt b/requirements.txt
index 55b554c07..4a85fc81b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,9 +8,9 @@ google-cloud-storage
 jax[tpu]==0.8.0
 jaxlib==0.8.0
 jaxtyping
-flax==0.11.1
+flax==0.12.0
 torchax==0.0.7
-qwix==0.1.1
+qwix==0.1.4
 torchvision==0.24.0
 pathwaysutils
 parameterized
diff --git a/tests/models/jax/test_weight_loading.py b/tests/models/jax/test_weight_loading.py
index aa4a8fe2e..db07ff36d 100644
--- a/tests/models/jax/test_weight_loading.py
+++ b/tests/models/jax/test_weight_loading.py
@@ -31,7 +31,7 @@ class SourceModel(nnx.Module):
 
     def __init__(self, rngs):
         self.src_lm_head = nnx.Param(jax.random.normal(rngs(), (2, 4)))
-        self.layers = {0: SourceLayer(rngs)}
+        self.layers = nnx.List({0: SourceLayer(rngs)})
 
 
 class TargetLinear(nnx.Module):
diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index 238b123b3..01004a1a4 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -96,7 +96,9 @@ def create_jit_model(
         The jitted model.
     """
     state = nnx.state(model)
-    nnx.update(model, state)
+    pspecs = nnx.get_partition_spec(state)
+    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
+    nnx.update(model, sharded_state)
     if not use_qwix_on_abstract_model:
         # NOTE: if Qwix is not configured, this will be a no-op
         model = apply_qwix_quantization(vllm_config,
@@ -127,7 +129,7 @@ def create_jit_model(
                 "quantization_config") else {}
         load_random_weights_into_qwix_abstract_model(
             rng, model, mesh, quantization_config)
-        with mesh:
+        with jax.set_mesh(mesh):
            jit_model = create_jit_model(model,
                                         use_qwix_on_abstract_model=True)
         return jit_model
@@ -142,7 +144,7 @@ def create_sharded_model():
             # NOTE: we don't support quantization for the old Qwen2ForCausalLM implementation
             return model
 
-        with mesh:
+        with jax.set_mesh(mesh):
             jit_model = create_sharded_model()
         # In this case, we are applying Qwix quantization to the true, concrete model
         jit_model = apply_qwix_quantization(vllm_config,
@@ -179,7 +181,7 @@ def create_sharded_model():
     # Although the created model can already work, we still need to jit
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
-    with mesh:
+    with jax.set_mesh(mesh):
         loader = get_model_loader(vllm_config.load_config)
         if isinstance(loader, RunaiModelStreamerLoader):
             model_weights = vllm_config.model_config.model
diff --git a/tpu_inference/models/jax/deepseek_v3.py b/tpu_inference/models/jax/deepseek_v3.py
index 15686e881..ec63a8663 100644
--- a/tpu_inference/models/jax/deepseek_v3.py
+++ b/tpu_inference/models/jax/deepseek_v3.py
@@ -307,6 +307,8 @@ def _create_mla() -> MLA:
                 vd_sharding=(('data', 'expert', 'model'), None),
                 dv_sharding=(None, ('data', 'expert', 'model')),
                 random_init=self.random_init)
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
 
     # For compatibility with flax.
     def apply(self, variables, *args, **kwargs):
diff --git a/tpu_inference/models/jax/gpt_oss.py b/tpu_inference/models/jax/gpt_oss.py
index 45fc66a6e..992a996e9 100644
--- a/tpu_inference/models/jax/gpt_oss.py
+++ b/tpu_inference/models/jax/gpt_oss.py
@@ -182,6 +182,9 @@ def __init__(self,
             random_init=self.random_init,
         )
 
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
+
     # For compatibility with flax.
     def apply(self, variables, *args, **kwargs):
         return self.__call__(*args, **kwargs)
diff --git a/tpu_inference/models/jax/llama3.py b/tpu_inference/models/jax/llama3.py
index d3f6ee3fc..17fc113fb 100644
--- a/tpu_inference/models/jax/llama3.py
+++ b/tpu_inference/models/jax/llama3.py
@@ -243,7 +243,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 init_fn, (ShardingAxisName.VOCAB, None)),
             rngs=rng,
         )
-        self.layers = [
+        self.layers = nnx.List(
             LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
@@ -251,8 +251,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-            for _ in range(hf_config.num_hidden_layers)
-        ]
+            for _ in range(hf_config.num_hidden_layers))
         self.norm = nnx.RMSNorm(
             hidden_size,
             epsilon=rms_norm_eps,
diff --git a/tpu_inference/models/jax/llama4.py b/tpu_inference/models/jax/llama4.py
index 0e65f5ccc..3b16533f8 100644
--- a/tpu_inference/models/jax/llama4.py
+++ b/tpu_inference/models/jax/llama4.py
@@ -216,6 +216,9 @@ def __init__(self,
                 use_attention_rope=use_attention_rope)
             self.layers.append(block)
 
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
+
         self.final_norm = RMSNorm(
             dims=self.hidden_size,
             epsilon=self.rms_norm_eps,
diff --git a/tpu_inference/models/jax/llama_eagle3.py b/tpu_inference/models/jax/llama_eagle3.py
index 3f4d86822..871bb53fc 100644
--- a/tpu_inference/models/jax/llama_eagle3.py
+++ b/tpu_inference/models/jax/llama_eagle3.py
@@ -141,6 +141,8 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs, mesh: Mesh):
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
         ]
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
 
         if hasattr(hf_config, "target_hidden_size"):
             input_size = hf_config.target_hidden_size * 3
diff --git a/tpu_inference/models/jax/llama_guard_4.py b/tpu_inference/models/jax/llama_guard_4.py
index 827f1e0d4..d7c03442f 100644
--- a/tpu_inference/models/jax/llama_guard_4.py
+++ b/tpu_inference/models/jax/llama_guard_4.py
@@ -82,7 +82,7 @@ def __init__(self,
 
         self.layers = []
 
-        for i in range(self.num_layers):
+        for _ in range(self.num_layers):
             use_attention_rope = True
 
             custom_module = DenseFFW(dtype=self.dtype,
@@ -161,6 +161,9 @@ def __init__(self,
                 use_attention_rope=use_attention_rope)
             self.layers.append(block)
 
+        # For Flax 0.12.0 compatibility.
+        self.layers = nnx.List(self.layers)
+
         self.final_norm = RMSNorm(
             dims=self.hidden_size,
             activation_ffw_td=P(),
diff --git a/tpu_inference/models/jax/qwen2.py b/tpu_inference/models/jax/qwen2.py
index eb5aa6897..ae513f558 100644
--- a/tpu_inference/models/jax/qwen2.py
+++ b/tpu_inference/models/jax/qwen2.py
@@ -242,7 +242,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
             rngs=rng,
         )
-        self.layers = [
+        self.layers = nnx.List([
             Qwen2DecoderLayer(
                 config=hf_config,
                 dtype=dtype,
@@ -251,7 +251,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
             for _ in range(hf_config.num_hidden_layers)
-        ]
+        ])
         self.norm = nnx.RMSNorm(
             hidden_size,
             epsilon=rms_norm_eps,
diff --git a/tpu_inference/models/jax/qwen3.py b/tpu_inference/models/jax/qwen3.py
index 55442b57e..3fa80789f 100644
--- a/tpu_inference/models/jax/qwen3.py
+++ b/tpu_inference/models/jax/qwen3.py
@@ -195,7 +195,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
             rngs=rng,
         )
-        self.layers = [
+        self.layers = nnx.List([
             Qwen3DecoderLayer(
                 config=hf_config,
                 dtype=dtype,
@@ -204,7 +204,7 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
             for _ in range(hf_config.num_hidden_layers)
-        ]
+        ])
         self.norm = nnx.RMSNorm(
             hidden_size,
             epsilon=rms_norm_eps,
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 256798093..96249a37d 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 
+import flax
 import jax
 import jax.numpy as jnp
 import jaxtyping
@@ -283,6 +284,10 @@ def __init__(
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
         self.execute_model_state: ExecuteModelState | None = None
 
+        # TODO (jacobplatin): eventually, we'll want to use eager sharding in Flax
+        # but as a hotfix to upgrade to Flax 0.12.0, we'll disable it
+        flax.config.update('flax_always_shard_variable', False)
+
     def _init_random(self):
         if self.model_config.seed is None:
             self.model_config.seed = 0
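Note on the recurring pattern in this diff: each model wraps its layer stack in `nnx.List` (instead of a plain Python list) and enters the mesh via `with jax.set_mesh(mesh):` rather than `with mesh:`. The snippet below is a minimal, self-contained sketch of that pattern for illustration only; the `Block`/`TinyModel` names and shapes are hypothetical and not part of this repository, and it assumes `flax==0.12.x` plus a JAX version that provides `jax.set_mesh` and `jax.make_mesh`.

```python
# Illustrative sketch only (hypothetical Block/TinyModel; assumes flax==0.12.x).
import jax
import jax.numpy as jnp
from flax import nnx


class Block(nnx.Module):

    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(8, 8, rngs=rngs)

    def __call__(self, x):
        return jax.nn.relu(self.linear(x))


class TinyModel(nnx.Module):

    def __init__(self, num_layers: int, rngs: nnx.Rngs):
        # Wrap the per-layer modules in nnx.List rather than a plain Python
        # list, mirroring the Flax 0.12.0 compatibility changes in this diff,
        # so the sub-modules stay registered as children in the nnx graph.
        self.layers = nnx.List([Block(rngs) for _ in range(num_layers)])

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


mesh = jax.make_mesh((jax.device_count(),), ("model",))
with jax.set_mesh(mesh):  # replaces the older `with mesh:` form
    model = TinyModel(num_layers=2, rngs=nnx.Rngs(0))
    state = nnx.state(model)  # parameters of every Block are visible here
    y = model(jnp.ones((1, 8)))
```

In the real loader the extracted state is additionally constrained with `nnx.get_partition_spec` and `jax.lax.with_sharding_constraint` before `nnx.update`, as shown in the `model_loader.py` hunk above.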