13 changes: 7 additions & 6 deletions src/optimum/nvidia/models/auto.py
@@ -20,6 +20,7 @@
from optimum.nvidia.errors import UnsupportedModelException
from optimum.nvidia.models.gemma import GemmaForCausalLM
from optimum.nvidia.models.llama import LlamaForCausalLM
from optimum.nvidia.models.qwen import QwenForCausalLM
from optimum.nvidia.utils import model_type_from_known_config


@@ -36,7 +37,8 @@ class AutoModelForCausalLM(ModelHubMixin):
"mistral": LlamaForCausalLM,
"mixtral": LlamaForCausalLM,
"gemma": GemmaForCausalLM,
# "phi": PhiForCausalLM
"qwen": QwenForCausalLM,
"qwen2": QwenForCausalLM,
}

def __init__(self):
@@ -63,12 +65,11 @@ def _from_pretrained(
if config is None:
raise ValueError("Unable to determine the model type with config = None")

model_type = model_type_from_known_config(config)
if not (model_type := model_type_from_known_config(config)):
print(f"Config: {config}")
raise ValueError(f"Unable to determine the model type from undefined model_type '{model_type}'")

if (
not model_type
or model_type not in AutoModelForCausalLM._SUPPORTED_MODEL_CLASS
):
if model_type not in AutoModelForCausalLM._SUPPORTED_MODEL_CLASS:
raise UnsupportedModelException(model_type)

model_clazz = AutoModelForCausalLM._SUPPORTED_MODEL_CLASS[model_type]
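A standalone sketch of the resolution flow that _from_pretrained follows after this change: resolve the model type from the config, fail early if it cannot be determined, then dispatch through the registry. The dict and helper below are illustrative stand-ins, not the library's _SUPPORTED_MODEL_CLASS or its model classes.

# Illustration only: a simplified stand-in for the dispatch logic in
# AutoModelForCausalLM._from_pretrained, using plain dicts and strings.
SUPPORTED = {"llama": "LlamaForCausalLM", "qwen": "QwenForCausalLM", "qwen2": "QwenForCausalLM"}

def resolve_model_class(config: dict) -> str:
    # Fail fast when the model type cannot be determined from the config ...
    if not (model_type := config.get("model_type")):
        raise ValueError(f"Unable to determine the model type from config {config}")
    # ... or when it is recognised but not supported.
    if model_type not in SUPPORTED:
        raise KeyError(f"Unsupported model type: {model_type}")
    return SUPPORTED[model_type]

assert resolve_model_class({"model_type": "qwen2"}) == "QwenForCausalLM"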
29 changes: 29 additions & 0 deletions src/optimum/nvidia/models/qwen.py
@@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger

from tensorrt_llm.models.qwen.model import QWenForCausalLM
from transformers import Qwen2ForCausalLM as TransformersQwen2ForCausalLM

from optimum.nvidia.models import SupportsTransformersConversion
from optimum.nvidia.runtime import CausalLM


LOGGER = getLogger(__name__)


class QwenForCausalLM(CausalLM, SupportsTransformersConversion):
HF_LIBRARY_TARGET_MODEL_CLASS = TransformersQwen2ForCausalLM
TRT_LLM_TARGET_MODEL_CLASSES = QWenForCausalLM
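With the qwen and qwen2 registry entries above and this class providing the Transformers-to-TensorRT-LLM mapping, a Qwen checkpoint loads like the other supported architectures. A minimal usage sketch, assuming an NVIDIA GPU with TensorRT-LLM installed; the checkpoint id is one of those used in the tests below, and the generation arguments and exact return format of generate are illustrative and may differ between optimum-nvidia versions.

# Usage sketch (not part of this PR); engines are built on first load.
from transformers import AutoTokenizer

from optimum.nvidia import AutoModelForCausalLM

model_id = "Qwen/Qwen2-1.5B"  # also exercised by the integration tests below
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_id)  # model_type "qwen2" resolves to QwenForCausalLM

inputs = tokenizer("TensorRT-LLM accelerates", return_tensors="pt")
outputs = model.generate(inputs["input_ids"], max_new_tokens=32)  # return format may vary by version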
4 changes: 3 additions & 1 deletion tests/integration/test_causal_lm.py
@@ -28,7 +28,9 @@
"meta-llama/Llama-2-7b-chat-hf",
"mistralai/Mistral-7B-Instruct-v0.2",
"meta-llama/Meta-Llama-3-8B",
# "mistralai/Mixtral-8x7B-Instruct-v0.1",
"Qwen/Qwen1.5-0.5B-Chat",
"Qwen/Qwen2-1.5B",
"Qwen/Qwen2.5-3B",
}

MODEL_KWARGS_MAPS = {"Mixtral-8x7B-Instruct-v0.1": {"tp": 4}}
16 changes: 10 additions & 6 deletions tests/test_hub.py
@@ -65,6 +65,9 @@ def test_folder_list_engines(rank: int):
("meta-llama/Llama-2-7b-chat-hf", 1),
("google/gemma-2b", 1),
("mistralai/Mistral-7B-v0.3", 4),
("Qwen/Qwen1.5-0.5B-Chat", 1),
("Qwen/Qwen2-1.5B", 1),
("Qwen/Qwen2.5-3B", 2),
],
)
def test_save_engine_locally_and_reload(model_id: Tuple[str, int]):
@@ -125,22 +128,23 @@ def _reload():
@pytest.mark.parametrize(
"type",
(
("llama", "LlamaForCausalLM"),
("gemma", "GemmaForCausalLM"),
("mistral", "MistralForCausalLM"),
("mixtral", "MixtralForCausalLM"),
("llama", "llama", "LlamaForCausalLM"),
("gemma", "gemma", "GemmaForCausalLM"),
("mistral", "mistral", "MistralForCausalLM"),
("mixtral", "mixtral", "MixtralForCausalLM"),
("qwen2", "qwen", "QwenForCausalLM"),
),
)
def test_model_type_from_known_config(type: Tuple[str, str]):
transformers_type, trtllm_type = type
transformers_type, trtllm_model_type, trtllm_type = type

# transformers config
transformers_config = {"model_type": transformers_type}
assert model_type_from_known_config(transformers_config) == transformers_type

# trtllm engine config
tensorrt_llm_config = {"pretrained_config": {"architecture": trtllm_type}}
assert model_type_from_known_config(tensorrt_llm_config) == transformers_type
assert model_type_from_known_config(tensorrt_llm_config) == trtllm_model_type


def test_model_type_from_known_config_fail():
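The Qwen row is the one where the two expected values now differ: a Transformers config reports model_type "qwen2", while a TensorRT-LLM engine config reports the architecture "QwenForCausalLM", which maps back to the "qwen" model type; hence the second assertion now checks trtllm_model_type instead of transformers_type. Below is a sketch of the behaviour being assumed of model_type_from_known_config; the real helper lives in optimum.nvidia.utils and this is an illustration, not its implementation.

# Illustration only: an assumed, simplified model of the behaviour the test
# above exercises in optimum.nvidia.utils.model_type_from_known_config.
def model_type_from_known_config_sketch(config: dict) -> str:
    # Transformers-style config: the model type is stored directly.
    if "model_type" in config:
        return config["model_type"]
    # TensorRT-LLM engine config: derive the type from the architecture name,
    # e.g. "QwenForCausalLM" -> "qwen", "LlamaForCausalLM" -> "llama".
    architecture = config["pretrained_config"]["architecture"]
    return architecture.removesuffix("ForCausalLM").lower()

assert model_type_from_known_config_sketch({"model_type": "qwen2"}) == "qwen2"
assert model_type_from_known_config_sketch({"pretrained_config": {"architecture": "QwenForCausalLM"}}) == "qwen"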