diff --git a/tunix/models/gemma/model.py b/tunix/models/gemma/model.py
index e58c2d99..5186cee5 100644
--- a/tunix/models/gemma/model.py
+++ b/tunix/models/gemma/model.py
@@ -68,6 +68,7 @@ class ShardingConfig:
   @staticmethod
   def get_default_sharding(is_sampling: bool = False):
     fsdp = 'fsdp' if not is_sampling else None
+    b_sharding = tuple(filter(None, ('dp', fsdp)))
 
     return ShardingConfig(
         emb_vd=('tp', fsdp),
@@ -78,9 +79,9 @@ def get_default_sharding(is_sampling: bool = False):
         ffw_weight_df=(fsdp, 'tp'),
         ffw_weight_fd=('tp', fsdp),
         rms_norm_weight=('tp',),
-        act_btd=('fsdp', None, None if is_sampling else 'tp'),
-        act_btf=('fsdp', None, 'tp'),
-        act_btnh=('fsdp', None, 'tp', None),
+        act_btd=(b_sharding, None, None if is_sampling else 'tp'),
+        act_btf=(b_sharding, None, 'tp'),
+        act_btnh=(b_sharding, None, 'tp', None),
         score_weight_d1=(fsdp, None),
     )
diff --git a/tunix/models/gemma3/model.py b/tunix/models/gemma3/model.py
index 80cf7313..71084f78 100644
--- a/tunix/models/gemma3/model.py
+++ b/tunix/models/gemma3/model.py
@@ -61,6 +61,7 @@ class ShardingConfig:
   @staticmethod
   def get_default_sharding(is_sampling: bool = False):
     fsdp = 'fsdp' if not is_sampling else None
+    b_sharding = tuple(filter(None, ('dp', fsdp)))
 
     return ShardingConfig(
         emb_vd=('tp', fsdp),
@@ -71,9 +72,9 @@ def get_default_sharding(is_sampling: bool = False):
         ffw_weight_df=(fsdp, 'tp'),
         ffw_weight_fd=('tp', fsdp),
         rms_norm_weight=('tp',),
-        act_btd=('fsdp', None, None if is_sampling else 'tp'),
-        act_btf=('fsdp', None, 'tp'),
-        act_btnh=('fsdp', None, 'tp', None),
+        act_btd=(b_sharding, None, None if is_sampling else 'tp'),
+        act_btf=(b_sharding, None, 'tp'),
+        act_btnh=(b_sharding, None, 'tp', None),
     )
diff --git a/tunix/models/llama3/model.py b/tunix/models/llama3/model.py
index 9bbcf475..a7c26073 100644
--- a/tunix/models/llama3/model.py
+++ b/tunix/models/llama3/model.py
@@ -62,6 +62,7 @@ class ShardingConfig:
   @staticmethod
   def get_default_sharding(is_sampling: bool = False):
     fsdp = 'fsdp' if not is_sampling else None
+    b_sharding = tuple(filter(None, ('dp', fsdp)))
 
     return ShardingConfig(
         emb_vd=('tp', fsdp),
@@ -72,9 +73,9 @@ def get_default_sharding(is_sampling: bool = False):
         ffw_weight_df=(fsdp, 'tp'),
         ffw_weight_fd=('tp', fsdp),
         rms_norm_weight=('tp',),
-        act_btd=('fsdp', None, None if is_sampling else 'tp'),
-        act_btf=('fsdp', None, 'tp'),
-        act_btnh=('fsdp', None, 'tp', None),
+        act_btd=(b_sharding, None, None if is_sampling else 'tp'),
+        act_btf=(b_sharding, None, 'tp'),
+        act_btnh=(b_sharding, None, 'tp', None),
     )
diff --git a/tunix/models/qwen2/model.py b/tunix/models/qwen2/model.py
index 94a7f066..64f1fcd5 100644
--- a/tunix/models/qwen2/model.py
+++ b/tunix/models/qwen2/model.py
@@ -64,6 +64,7 @@ class ShardingConfig:
   @staticmethod
   def get_default_sharding(is_sampling: bool = False):
     fsdp = 'fsdp' if not is_sampling else None
+    b_sharding = tuple(filter(None, ('dp', fsdp)))
 
     return ShardingConfig(
         emb_vd=('tp', fsdp),
@@ -74,9 +75,9 @@ def get_default_sharding(is_sampling: bool = False):
         ffw_weight_df=(fsdp, 'tp'),
         ffw_weight_fd=('tp', fsdp),
         rms_norm_weight=('tp',),
-        act_btd=('fsdp', None, None if is_sampling else 'tp'),
-        act_btf=('fsdp', None, 'tp'),
-        act_btnh=('fsdp', None, 'tp', None),
+        act_btd=(b_sharding, None, None if is_sampling else 'tp'),
+        act_btf=(b_sharding, None, 'tp'),
+        act_btnh=(b_sharding, None, 'tp', None),
         exp_weight_cdf=('fsdp', None, 'tp'),
         exp_weight_cfd=('fsdp', 'tp', None),
         qkv_bias=('tp',),
diff --git a/tunix/models/qwen3/model.py b/tunix/models/qwen3/model.py
index b9eeba9f..b81e4e19 100644
--- a/tunix/models/qwen3/model.py
+++ b/tunix/models/qwen3/model.py
@@ -64,6 +64,7 @@ class ShardingConfig:
   @staticmethod
   def get_default_sharding(is_sampling: bool = False):
     fsdp = 'fsdp' if not is_sampling else None
+    b_sharding = tuple(filter(None, ('dp', fsdp)))
 
     return ShardingConfig(
         emb_vd=('tp', fsdp),
@@ -74,9 +75,9 @@ def get_default_sharding(is_sampling: bool = False):
         ffw_weight_df=(fsdp, 'tp'),
         ffw_weight_fd=('tp', fsdp),
         rms_norm_weight=('tp',),
-        act_btd=('fsdp', None, None if is_sampling else 'tp'),
-        act_btf=('fsdp', None, 'tp'),
-        act_btnh=('fsdp', None, 'tp', None),
+        act_btd=(b_sharding, None, None if is_sampling else 'tp'),
+        act_btf=(b_sharding, None, 'tp'),
+        act_btnh=(b_sharding, None, 'tp', None),
         exp_weight_cdf=('fsdp', None, 'tp'),
         exp_weight_cfd=('fsdp', 'tp', None),
     )
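
Note: the diff applies the same change to every model's get_default_sharding(): the batch dimension of the activation specs (act_btd, act_btf, act_btnh) now uses a tuple built with tuple(filter(None, ('dp', fsdp))) instead of the bare 'fsdp' axis, so the batch is sharded across both the 'dp' and 'fsdp' mesh axes during training and across 'dp' alone during sampling (when fsdp is None and filter drops it). A minimal standalone sketch of that expression follows; the batch_sharding helper name is illustrative and not part of the diff.

# Mirrors the b_sharding expression added in each get_default_sharding() above.
def batch_sharding(is_sampling: bool = False):
  fsdp = 'fsdp' if not is_sampling else None
  # filter(None, ...) drops the None entry when sampling, keeping only 'dp'.
  return tuple(filter(None, ('dp', fsdp)))

assert batch_sharding(is_sampling=False) == ('dp', 'fsdp')  # training
assert batch_sharding(is_sampling=True) == ('dp',)          # sampling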