diff --git a/tunix/models/qwen3/model.py b/tunix/models/qwen3/model.py
index b9eeba9fd..d2c4d796c 100644
--- a/tunix/models/qwen3/model.py
+++ b/tunix/models/qwen3/model.py
@@ -172,6 +172,23 @@ def qwen3_30b(cls): # qwen3-30B
         num_experts_per_tok=8,
     )
 
+  @classmethod
+  def qwen3_235b_a22b(cls):  # qwen3-235B-A22B (MoE)
+    """Returns `cls` built with the Qwen3-235B-A22B (MoE) preset values."""
+    return cls(
+        num_layers=94,
+        vocab_size=151936,
+        embed_dim=4096,
+        hidden_dim=1536,
+        num_heads=64,
+        head_dim=128,
+        num_kv_heads=4,
+        norm_eps=1e-06,
+        rope_theta=1_000_000,
+        num_experts=128,
+        num_experts_per_tok=8,
+    )
+
 
 def shard(x: jnp.ndarray, s: Tuple[str, ...]):
   mesh = pxla.thread_resources.env.physical_mesh