diff --git a/tunix/models/gemma/model.py b/tunix/models/gemma/model.py
index 2243d50f5..b1efc6577 100644
--- a/tunix/models/gemma/model.py
+++ b/tunix/models/gemma/model.py
@@ -61,7 +61,6 @@ class ShardingConfig:
   ffw_weight_fd: Tuple[str | None, ...]
   rms_norm_weight: Tuple[str | None, ...]
   act_btd: Tuple[str | None, ...]
-  act_btf: Tuple[str | None, ...]
   act_btnh: Tuple[str | None, ...]
   score_weight_d1: Tuple[str | None, ...]

@@ -79,7 +78,6 @@ def get_default_sharding(is_sampling: bool = False):
         ffw_weight_fd=('tp', fsdp),
         rms_norm_weight=('tp',),
         act_btd=('fsdp', None, None if is_sampling else 'tp'),
-        act_btf=('fsdp', None, None),
         act_btnh=('fsdp', None, 'tp', None),
         score_weight_d1=(fsdp, None),
     )
@@ -586,9 +584,9 @@ def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:

     ff1 = self.up_proj(x)
     activations = gate_value * ff1
-    activations = shard(activations, self.shd_config.act_btf)

     outputs = self.down_proj(activations)
+    outputs = shard(outputs, self.shd_config.act_btd)

     return outputs

diff --git a/tunix/models/gemma3/model.py b/tunix/models/gemma3/model.py
index cc42ea8fa..a7563f3e4 100644
--- a/tunix/models/gemma3/model.py
+++ b/tunix/models/gemma3/model.py
@@ -55,7 +55,6 @@ class ShardingConfig:
   ffw_weight_fd: Tuple[str | None, ...]
   rms_norm_weight: Tuple[str | None, ...]
   act_btd: Tuple[str | None, ...]
-  act_btf: Tuple[str | None, ...]
   act_btnh: Tuple[str | None, ...]

   @staticmethod
@@ -72,7 +71,6 @@ def get_default_sharding(is_sampling: bool = False):
         ffw_weight_fd=('tp', fsdp),
         rms_norm_weight=('tp',),
         act_btd=('fsdp', None, None if is_sampling else 'tp'),
-        act_btf=('fsdp', None, None),
         act_btnh=('fsdp', None, 'tp', None),
     )

@@ -715,10 +713,10 @@ def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
     with jax.named_scope('up_proj'):
       ff1 = self.up_proj(x)
       activations = gate_value * ff1
-      activations = shard(activations, self.shd_config.act_btf)

     with jax.named_scope('down_proj'):
       outputs = self.down_proj(activations)
+      outputs = shard(outputs, self.shd_config.act_btd)

     return outputs

diff --git a/tunix/models/llama3/model.py b/tunix/models/llama3/model.py
index 3e456503a..2598f5a22 100644
--- a/tunix/models/llama3/model.py
+++ b/tunix/models/llama3/model.py
@@ -56,7 +56,6 @@ class ShardingConfig:
   ffw_weight_fd: Tuple[str | None, ...]
   rms_norm_weight: Tuple[str | None, ...]
   act_btd: Tuple[str | None, ...]
-  act_btf: Tuple[str | None, ...]
   act_btnh: Tuple[str | None, ...]

   @staticmethod
@@ -464,8 +463,8 @@ def __init__(

   @jax.named_scope('feed_forward')
   def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
     activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x)
-    activations = shard(activations, self.shd_config.act_btf)
     outputs = self.down_proj(activations)
+    outputs = shard(outputs, self.shd_config.act_btd)
     return outputs

diff --git a/tunix/models/qwen2/model.py b/tunix/models/qwen2/model.py
index 893ef7371..c93f9f09b 100644
--- a/tunix/models/qwen2/model.py
+++ b/tunix/models/qwen2/model.py
@@ -55,7 +55,6 @@ class ShardingConfig:
   ffw_weight_fd: Tuple[str | None, ...]
   rms_norm_weight: Tuple[str | None, ...]
   act_btd: Tuple[str | None, ...]
-  act_btf: Tuple[str | None, ...]
   act_btnh: Tuple[str | None, ...]
   exp_weight_cdf: Tuple[str | None, ...]
   exp_weight_cfd: Tuple[str | None, ...]
@@ -75,7 +74,6 @@ def get_default_sharding(is_sampling: bool = False):
         ffw_weight_fd=('tp', fsdp),
         rms_norm_weight=('tp',),
         act_btd=('fsdp', None, None if is_sampling else 'tp'),
-        act_btf=('fsdp', None, None),
         act_btnh=('fsdp', None, 'tp', None),
         exp_weight_cdf=('fsdp', None, 'tp'),
         exp_weight_cfd=('fsdp', 'tp', None),
@@ -547,8 +545,8 @@ def __init__(

   @jax.named_scope('feed_forward')
   def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
     activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x)
-    activations = shard(activations, self.shd_config.act_btf)
     outputs = self.down_proj(activations)
+    outputs = shard(outputs, self.shd_config.act_btd)
     return outputs

diff --git a/tunix/models/qwen3/model.py b/tunix/models/qwen3/model.py
index ff055d6b3..fa522abcc 100644
--- a/tunix/models/qwen3/model.py
+++ b/tunix/models/qwen3/model.py
@@ -56,7 +56,6 @@ class ShardingConfig:
   ffw_weight_fd: Tuple[str | None, ...]
   rms_norm_weight: Tuple[str | None, ...]
   act_btd: Tuple[str | None, ...]
-  act_btf: Tuple[str | None, ...]
   act_btnh: Tuple[str | None, ...]
   exp_weight_cdf: Tuple[str | None, ...]
   exp_weight_cfd: Tuple[str | None, ...]
@@ -75,7 +74,6 @@ def get_default_sharding(is_sampling: bool = False):
         ffw_weight_fd=('tp', fsdp),
         rms_norm_weight=('tp',),
         act_btd=('fsdp', None, None if is_sampling else 'tp'),
-        act_btf=('fsdp', None, None),
         act_btnh=('fsdp', None, 'tp', None),
         exp_weight_cdf=('fsdp', None, 'tp'),
         exp_weight_cfd=('fsdp', 'tp', None),
@@ -539,8 +537,8 @@ def __call__(self, x):
       activations = nnx.silu(
           jnp.einsum('BTD,DF->BTF', expert_input, self.gate_proj[i])
       ) * jnp.einsum('BTD,DF->BTF', expert_input, self.up_proj[i])
-      activations = shard(activations, self.shd_config.act_btf)
       expert_output = jnp.einsum('BTF,FD->BTD', activations, self.down_proj[i])
+      expert_output = shard(expert_output, self.shd_config.act_btd)
       expert_outputs.append(expert_output)

     stacked_outputs = jnp.stack(expert_outputs, axis=2)  # [B, T, E, D]
@@ -595,8 +593,8 @@ def __init__(

   @jax.named_scope('feed_forward')
   def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
     activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x)
-    activations = shard(activations, self.shd_config.act_btf)
     outputs = self.down_proj(activations)
+    outputs = shard(outputs, self.shd_config.act_btd)
     return outputs
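Every feed-forward block in the patch changes the same way: the `act_btf` constraint on the [B, T, F] gate/up intermediate is removed, and the [B, T, D] result of `down_proj` is constrained with `act_btd` instead. The following is a minimal sketch of the post-patch dataflow in plain JAX, not tunix code: the single-device mesh, the weight shapes, and the local `shard` helper (a thin wrapper over `jax.lax.with_sharding_constraint`) are made up for illustration, and the spec ('fsdp', None, 'tp') mirrors the non-sampling `act_btd` default shown in the diff.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Single-device mesh so the sketch runs anywhere; real runs would lay the
# 'fsdp' and 'tp' axes across many devices.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('fsdp', 'tp'))


def shard(x, spec):
  """Stand-in for tunix's shard helper: attach a named layout constraint."""
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(*spec)))


@jax.jit
def feed_forward(x, gate_w, up_w, down_w):
  """Gated MLP following the post-patch flow: the [B, T, F] intermediate is
  left unannotated; only the [B, T, D] output of down_proj is constrained."""
  gate_value = jax.nn.silu(x @ gate_w)          # [B, T, F]
  activations = gate_value * (x @ up_w)         # [B, T, F], no act_btf shard
  outputs = activations @ down_w                # [B, T, D]
  return shard(outputs, ('fsdp', None, 'tp'))   # act_btd replaces act_btf


x = jnp.ones((2, 8, 16))                        # [B, T, D]
gate_w = up_w = jnp.ones((16, 64))              # D=16, F=64
down_w = jnp.ones((64, 16))
print(feed_forward(x, gate_w, up_w, down_w).shape)  # (2, 8, 16)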