diff --git a/src/optimum/nvidia/models/auto.py b/src/optimum/nvidia/models/auto.py
index 0a092c42..8ea71c7a 100644
--- a/src/optimum/nvidia/models/auto.py
+++ b/src/optimum/nvidia/models/auto.py
@@ -20,6 +20,7 @@
 from optimum.nvidia.errors import UnsupportedModelException
 from optimum.nvidia.models.gemma import GemmaForCausalLM
 from optimum.nvidia.models.llama import LlamaForCausalLM
+from optimum.nvidia.models.qwen import QwenForCausalLM
 from optimum.nvidia.utils import model_type_from_known_config
 
 
@@ -36,7 +37,8 @@ class AutoModelForCausalLM(ModelHubMixin):
         "mistral": LlamaForCausalLM,
         "mixtral": LlamaForCausalLM,
         "gemma": GemmaForCausalLM,
-        # "phi": PhiForCausalLM
+        "qwen": QwenForCausalLM,
+        "qwen2": QwenForCausalLM,
     }
 
     def __init__(self):
@@ -63,12 +65,11 @@ def _from_pretrained(
         if config is None:
             raise ValueError("Unable to determine the model type with config = None")
 
-        model_type = model_type_from_known_config(config)
+        if not (model_type := model_type_from_known_config(config)):
+            print(f"Config: {config}")
+            raise ValueError(f"Unable to determine the model type from undefined model_type '{model_type}'")
 
-        if (
-            not model_type
-            or model_type not in AutoModelForCausalLM._SUPPORTED_MODEL_CLASS
-        ):
+        if model_type not in AutoModelForCausalLM._SUPPORTED_MODEL_CLASS:
             raise UnsupportedModelException(model_type)
 
         model_clazz = AutoModelForCausalLM._SUPPORTED_MODEL_CLASS[model_type]
diff --git a/src/optimum/nvidia/models/qwen.py b/src/optimum/nvidia/models/qwen.py
new file mode 100644
index 00000000..7d7d910f
--- /dev/null
+++ b/src/optimum/nvidia/models/qwen.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# http://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from logging import getLogger
+
+from tensorrt_llm.models.qwen.model import QWenForCausalLM
+from transformers import Qwen2ForCausalLM as TransformersQwen2ForCausalLM
+
+from optimum.nvidia.models import SupportsTransformersConversion
+from optimum.nvidia.runtime import CausalLM
+
+
+LOGGER = getLogger(__name__)
+
+
+class QwenForCausalLM(CausalLM, SupportsTransformersConversion):
+    HF_LIBRARY_TARGET_MODEL_CLASS = TransformersQwen2ForCausalLM
+    TRT_LLM_TARGET_MODEL_CLASSES = QWenForCausalLM
diff --git a/tests/integration/test_causal_lm.py b/tests/integration/test_causal_lm.py
index cfb29622..cf5f7fe6 100644
--- a/tests/integration/test_causal_lm.py
+++ b/tests/integration/test_causal_lm.py
@@ -28,7 +28,9 @@
     "meta-llama/Llama-2-7b-chat-hf",
     "mistralai/Mistral-7B-Instruct-v0.2",
     "meta-llama/Meta-Llama-3-8B",
-    # "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "Qwen/Qwen1.5-0.5B-Chat",
+    "Qwen/Qwen2-1.5B",
+    "Qwen/Qwen2.5-3B",
 }
 
 MODEL_KWARGS_MAPS = {"Mixtral-8x7B-Instruct-v0.1": {"tp": 4}}
diff --git a/tests/test_hub.py b/tests/test_hub.py
index 9a23fe23..1f6a868b 100644
--- a/tests/test_hub.py
+++ b/tests/test_hub.py
@@ -65,6 +65,9 @@ def test_folder_list_engines(rank: int):
         ("meta-llama/Llama-2-7b-chat-hf", 1),
         ("google/gemma-2b", 1),
         ("mistralai/Mistral-7B-v0.3", 4),
+        ("Qwen/Qwen1.5-0.5B-Chat", 1),
+        ("Qwen/Qwen2-1.5B", 1),
+        ("Qwen/Qwen2.5-3B", 2),
     ],
 )
 def test_save_engine_locally_and_reload(model_id: Tuple[str, int]):
@@ -125,14 +128,15 @@ def _reload():
 @pytest.mark.parametrize(
     "type",
     (
-        ("llama", "LlamaForCausalLM"),
-        ("gemma", "GemmaForCausalLM"),
-        ("mistral", "MistralForCausalLM"),
-        ("mixtral", "MixtralForCausalLM"),
+        ("llama", "llama", "LlamaForCausalLM"),
+        ("gemma", "gemma", "GemmaForCausalLM"),
+        ("mistral", "mistral", "MistralForCausalLM"),
+        ("mixtral", "mixtral", "MixtralForCausalLM"),
+        ("qwen2", "qwen", "QwenForCausalLM"),
     ),
 )
 def test_model_type_from_known_config(type: Tuple[str, str]):
-    transformers_type, trtllm_type = type
+    transformers_type, trtllm_model_type, trtllm_type = type
 
     # transformers config
     transformers_config = {"model_type": transformers_type}
@@ -140,7 +144,7 @@ def test_model_type_from_known_config(type: Tuple[str, str]):
 
     # trtllm engine config
     tensorrt_llm_config = {"pretrained_config": {"architecture": trtllm_type}}
-    assert model_type_from_known_config(tensorrt_llm_config) == transformers_type
+    assert model_type_from_known_config(tensorrt_llm_config) == trtllm_model_type
 
 
 def test_model_type_from_known_config_fail():
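
For reviewers, a minimal usage sketch of the new mapping (not part of the diff): it relies on the `AutoModelForCausalLM` class and module path shown in `auto.py` above and on the `from_pretrained` entry point that `ModelHubMixin` exposes; the `generate()` signature and return shape below are assumptions shown only to illustrate how the new "qwen"/"qwen2" entries are expected to be exercised.

```python
# Usage sketch for the new Qwen mapping; generate() usage is assumed, not
# taken from this change.
from transformers import AutoTokenizer

from optimum.nvidia.models.auto import AutoModelForCausalLM

model_id = "Qwen/Qwen2-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# model_type_from_known_config() resolves "qwen2" from the checkpoint config,
# so _SUPPORTED_MODEL_CLASS dispatches to QwenForCausalLM and a TensorRT-LLM
# engine is built (or reloaded) for the model.
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
output_ids = model.generate(inputs["input_ids"], max_new_tokens=32)  # assumed transformers-like signature
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))  # assumed return shape
```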