diff --git a/scripts/generate_tiny_models/for_causal_lm/glm4_moe_for_causal_lm.py b/scripts/generate_tiny_models/for_causal_lm/glm4_moe_for_causal_lm.py
index b072179529..8b580279d0 100644
--- a/scripts/generate_tiny_models/for_causal_lm/glm4_moe_for_causal_lm.py
+++ b/scripts/generate_tiny_models/for_causal_lm/glm4_moe_for_causal_lm.py
@@ -33,14 +33,23 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 generation_config = GenerationConfig.from_pretrained(MODEL_ID)
 
 config = Glm4MoeConfig(
-    vocab_size=len(tokenizer.vocab),
+    vocab_size=151365,
     hidden_size=8,
     num_attention_heads=4,
     num_key_value_heads=2,
     num_hidden_layers=2,
     intermediate_size=32,
+    moe_intermediate_size=32,
+    head_dim=2,
     n_routed_experts=4,
     num_experts_per_tok=2,
+    attention_bias=True,
+    eos_token_id=[151329, 151336, 151338],
+    pad_token_id=151329,
+    rope_theta=1000000,
+    routed_scaling_factor=2.5,
+    use_qk_norm=True,
+    num_nextn_predict_layers=1,
 )
 model = Glm4MoeForCausalLM(config).to(dtype=torch.bfloat16)
 init_weights_tiny_model(model)
diff --git a/tests/conftest.py b/tests/conftest.py
index f071b789ff..582f6321fb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -37,6 +37,7 @@
 
 MODEL_REVISIONS = {
     # Add model_id: revision mappings here to test PRs
+    "trl-internal-testing/tiny-Glm4MoeForCausalLM": "refs/pr/1",
 }
 
 
diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py
index 79c38ba1ea..9828da63a8 100644
--- a/tests/test_chat_template_utils.py
+++ b/tests/test_chat_template_utils.py
@@ -193,14 +193,7 @@ class TestSupportsToolCalling:
                 reason="Gemma4 models were introduced in transformers-5.5.0",
             ),
         ),
-        pytest.param(
-            "trl-internal-testing/tiny-Glm4MoeForCausalLM",
-            id="glm4moe",
-            marks=pytest.mark.skipif(
-                Version(transformers.__version__) < Version("5.0.0"),
-                reason="GLM4 tokenizer requires transformers>=5.0.0",
-            ),
-        ),
+        pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
         pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
         pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.1", id="llama3.1"),
         pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.2", id="llama3.2"),
@@ -442,14 +435,7 @@ def test_prefix_preserving_template_processor(self):
         pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM", id="deepseekv3"),
         pytest.param("trl-internal-testing/tiny-GemmaForCausalLM", id="gemma"),
         pytest.param("trl-internal-testing/tiny-Gemma2ForCausalLM", id="gemma2"),
-        pytest.param(
-            "trl-internal-testing/tiny-Glm4MoeForCausalLM",
-            id="glm4moe",
-            marks=pytest.mark.skipif(
-                Version(transformers.__version__) < Version("5.0.0"),
-                reason="GLM4 tokenizer requires transformers>=5.0.0",
-            ),
-        ),
+        pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
         pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
         pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3", id="llama3"),
         pytest.param("trl-internal-testing/tiny-Phi3ForCausalLM-3", id="phi3"),
diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index 68a60b9bb0..b14be20b83 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -530,13 +530,7 @@ class TestApplyChatTemplate(TrlTestCase):
             "trl-internal-testing/tiny-Gemma2ForCausalLM",
             "trl-internal-testing/tiny-GemmaForCausalLM",
             "trl-internal-testing/tiny-GptOssForCausalLM",
-            pytest.param(
-                "trl-internal-testing/tiny-Glm4MoeForCausalLM",
-                marks=pytest.mark.skipif(
-                    Version(transformers.__version__) < Version("5.0.0"),
-                    reason="GLM4 tokenizer requires transformers>=5.0.0",
-                ),
-            ),
+            "trl-internal-testing/tiny-Glm4MoeForCausalLM",
             "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
             "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
             "trl-internal-testing/tiny-LlamaForCausalLM-3",
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 700dcd1d6a..851fc830dd 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -355,13 +355,7 @@ def test_init_with_training_arguments(self):
         "model_id",
         [
             "trl-internal-testing/tiny-Cohere2ForCausalLM",
-            pytest.param(
-                "trl-internal-testing/tiny-Glm4MoeForCausalLM",
-                marks=pytest.mark.skipif(
-                    Version(transformers.__version__) < Version("5.0.0"),
-                    reason="GLM4 tokenizer requires transformers>=5.0.0",
-                ),
-            ),
+            "trl-internal-testing/tiny-Glm4MoeForCausalLM",
             "trl-internal-testing/tiny-GptOssForCausalLM",
             "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
             "trl-internal-testing/tiny-Qwen3MoeForCausalLM",